|
| 1 | +import torch |
| 2 | +import whisper |
| 3 | + |
class ApplyWhisper:
    """Node that buffers audio chunks and transcribes them with Whisper.

    The INPUT_TYPES / RETURN_TYPES / FUNCTION layout looks like the ComfyUI
    custom-node convention (NOTE(review): confirm against the host app).

    Chunks passed to ``apply_whisper`` are accumulated until at least
    ``min_duration`` seconds (at ``sample_rate``) are buffered; the buffer
    is then transcribed in one pass and cleared.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the required inputs: an audio chunk and a model size."""
        return {
            "required": {
                "audio": ("AUDIO",),
                "model": (["base", "tiny", "small", "medium", "large"],),
            }
        }

    RETURN_TYPES = ("DICT",)
    FUNCTION = "apply_whisper"

    def __init__(self):
        self.model = None
        # Name of the checkpoint this node loaded itself; stays None when a
        # model was injected from outside so that it is never replaced.
        self._model_name = None
        self.audio_buffer = []
        # Fall back to CPU instead of crashing on hosts without CUDA.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # TODO: accept these as node parameters.
        self.sample_rate = 16000
        self.min_duration = 1.0

    def _ensure_model(self, model):
        """Load the requested Whisper checkpoint unless it is already loaded."""
        stale = self._model_name is not None and self._model_name != model
        if self.model is None or stale:
            # Drop the old weights first so both models are never resident.
            self.model = None
            self.model = whisper.load_model(model, device=self.device)
            self._model_name = model

    @staticmethod
    def _alignment(value, start, end):
        """Build one alignment entry with stripped text and its time span."""
        return {"value": value.strip(), "start": start, "end": end}

    def apply_whisper(self, audio, model):
        """Buffer ``audio`` and transcribe once enough has accumulated.

        Args:
            audio: Audio chunk. Assumed to be a 1-D torch tensor of samples
                at ``self.sample_rate`` (its length is read from
                ``shape[0]`` and chunks are joined with ``torch.cat``) --
                TODO confirm against the producer of the AUDIO type.
            model: Whisper model size name, e.g. ``"base"``.

        Returns:
            A 1-tuple holding a dict with keys ``"text"``,
            ``"segments_alignment"`` and ``"words_alignment"``. While less
            than ``min_duration`` seconds are buffered, the dict holds an
            empty string and empty lists.
        """
        self._ensure_model(model)

        self.audio_buffer.append(audio)
        buffered_samples = sum(chunk.shape[0] for chunk in self.audio_buffer)
        if buffered_samples / self.sample_rate < self.min_duration:
            # Same 1-tuple shape as the normal path, matching RETURN_TYPES.
            return (
                {"text": "", "segments_alignment": [], "words_alignment": []},
            )

        waveform = torch.cat(self.audio_buffer, dim=0).to(self.device)
        self.audio_buffer = []
        result = self.model.transcribe(
            waveform.float(),
            fp16=self.device == "cuda",  # half precision only on GPU
            word_timestamps=True,
        )

        segments_alignment = []
        words_alignment = []
        for segment in result["segments"]:
            segments_alignment.append(
                self._alignment(segment["text"], segment["start"], segment["end"])
            )
            words_alignment.extend(
                self._alignment(word["word"], word["start"], word["end"])
                for word in segment.get("words", [])
            )

        return (
            {
                "text": result["text"].strip(),
                "segments_alignment": segments_alignment,
                "words_alignment": words_alignment,
            },
        )
0 commit comments