Skip to content

Commit 21e4310

Browse files
committed
feat: streaming whisper
1 parent 29f6bb7 commit 21e4310

10 files changed

+1075
-82
lines changed

audio_example.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,36 @@
55
from comfystream.client import ComfyStreamClient
66

77
async def main():
8-
cwd = "/home/user/ComfyUI"
9-
client = ComfyStreamClient(cwd=cwd)
10-
8+
cwd = "/home/user/ComfyUI"
9+
client = ComfyStreamClient(cwd=cwd, type="audio")
1110
with open("./workflows/audio-whsiper-example-workflow.json", "r") as f:
1211
prompt = json.load(f)
1312

1413
client.set_prompt(prompt)
15-
16-
waveform, _ = torchaudio.load("/home/user/harvard.wav")
14+
waveform, sr = torchaudio.load("/home/user/harvard.wav")
15+
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
16+
waveform = resampler(waveform)
17+
sr = 16000
1718
if waveform.ndim > 1:
18-
audio_tensor = waveform.mean(dim=0)
19+
waveform = waveform.mean(dim=0, keepdim=True)
20+
21+
chunk_ms = 20
22+
chunk_size = int(sr * (chunk_ms / 1000.0))
23+
24+
total_samples = waveform.shape[1]
25+
offset = 0
26+
27+
results = []
28+
while offset < total_samples:
29+
end = min(offset + chunk_size, total_samples)
30+
chunk = waveform[:, offset:end]
31+
offset = end
32+
results.append(await client.queue_prompt(chunk.numpy().squeeze()))
1933

20-
output = await client.queue_prompt(audio_tensor)
21-
print(output)
34+
print("Result:")
35+
for result in results:
36+
if result[0] is not None:
37+
print(result[-1])
2238

2339
if __name__ == "__main__":
2440
asyncio.run(main())

nodes/audio_utils/__init__.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from .apply_whisper import ApplyWhisper
21
from .load_audio_tensor import LoadAudioTensor
3-
from .save_asr_response import SaveASRResponse
2+
from .save_result import SaveResult
43
from .save_audio_tensor import SaveAudioTensor
54

6-
NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper, "SaveAudioTensor": SaveAudioTensor}
5+
NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveResult": SaveResult, "SaveAudioTensor": SaveAudioTensor}
76

87
__all__ = ["NODE_CLASS_MAPPINGS"]

nodes/audio_utils/apply_whisper.py

-62
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from comfystream import tensor_cache
22

3-
class SaveASRResponse:
3+
class SaveResult:
44
CATEGORY = "audio_utils"
55
RETURN_TYPES = ()
66
FUNCTION = "execute"
@@ -10,15 +10,15 @@ class SaveASRResponse:
1010
def INPUT_TYPES(s):
1111
return {
1212
"required": {
13-
"data": ("DICT",),
13+
"result": ("RESULT",),
1414
}
1515
}
1616

1717
@classmethod
1818
def IS_CHANGED(s):
1919
return float("nan")
2020

21-
def execute(self, data: dict):
21+
def execute(self, result):
2222
fut = tensor_cache.audio_outputs.pop()
23-
fut.set_result(data)
24-
return data
23+
fut.set_result(result)
24+
return result

nodes/whisper_utils/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .apply_whisper import ApplyWhisper
2+
3+
NODE_CLASS_MAPPINGS = {"ApplyWhisper": ApplyWhisper}
4+
5+
__all__ = ["NODE_CLASS_MAPPINGS"]

nodes/whisper_utils/apply_whisper.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from .whisper_online import FasterWhisperASR, VACOnlineASRProcessor
2+
3+
class ApplyWhisper:
4+
@classmethod
5+
def INPUT_TYPES(s):
6+
return {
7+
"required": {
8+
"audio": ("AUDIO",),
9+
}
10+
}
11+
12+
CATEGORY = "whisper_utils"
13+
RETURN_TYPES = ("RESULT",)
14+
FUNCTION = "apply_whisper"
15+
16+
def __init__(self):
17+
self.asr = FasterWhisperASR(
18+
lan="en",
19+
modelsize="large-v3",
20+
cache_dir=None,
21+
model_dir=None,
22+
logfile=None
23+
)
24+
self.asr.use_vad()
25+
26+
self.online = VACOnlineASRProcessor(
27+
online_chunk_size=0.5,
28+
asr=self.asr,
29+
tokenizer=None,
30+
logfile=None,
31+
buffer_trimming=("segment", 15)
32+
)
33+
34+
def apply_whisper(self, audio):
35+
self.online.insert_audio_chunk(audio)
36+
result = self.online.process_iter()
37+
return (result,)
+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import torch
2+
3+
# This is copied from silero-vad's vad_utils.py:
4+
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
5+
# (except changed defaults)
6+
7+
# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
8+
9+
class VADIterator:
10+
def __init__(self,
11+
model,
12+
threshold: float = 0.5,
13+
sampling_rate: int = 16000,
14+
min_silence_duration_ms: int = 500, # makes sense on one recording that I checked
15+
speech_pad_ms: int = 100 # same
16+
):
17+
18+
"""
19+
Class for stream imitation
20+
21+
Parameters
22+
----------
23+
model: preloaded .jit silero VAD model
24+
25+
threshold: float (default - 0.5)
26+
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
27+
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
28+
29+
sampling_rate: int (default - 16000)
30+
Currently silero VAD models support 8000 and 16000 sample rates
31+
32+
min_silence_duration_ms: int (default - 100 milliseconds)
33+
In the end of each speech chunk wait for min_silence_duration_ms before separating it
34+
35+
speech_pad_ms: int (default - 30 milliseconds)
36+
Final speech chunks are padded by speech_pad_ms each side
37+
"""
38+
39+
self.model = model
40+
self.threshold = threshold
41+
self.sampling_rate = sampling_rate
42+
43+
if sampling_rate not in [8000, 16000]:
44+
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
45+
46+
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
47+
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
48+
self.reset_states()
49+
50+
def reset_states(self):
51+
52+
self.model.reset_states()
53+
self.triggered = False
54+
self.temp_end = 0
55+
self.current_sample = 0
56+
57+
def __call__(self, x, return_seconds=False):
58+
"""
59+
x: torch.Tensor
60+
audio chunk (see examples in repo)
61+
62+
return_seconds: bool (default - False)
63+
whether return timestamps in seconds (default - samples)
64+
"""
65+
66+
if not torch.is_tensor(x):
67+
try:
68+
x = torch.Tensor(x)
69+
except:
70+
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
71+
72+
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
73+
self.current_sample += window_size_samples
74+
75+
speech_prob = self.model(x, self.sampling_rate).item()
76+
77+
if (speech_prob >= self.threshold) and self.temp_end:
78+
self.temp_end = 0
79+
80+
if (speech_prob >= self.threshold) and not self.triggered:
81+
self.triggered = True
82+
speech_start = self.current_sample - self.speech_pad_samples
83+
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
84+
85+
if (speech_prob < self.threshold - 0.15) and self.triggered:
86+
if not self.temp_end:
87+
self.temp_end = self.current_sample
88+
if self.current_sample - self.temp_end < self.min_silence_samples:
89+
return None
90+
else:
91+
speech_end = self.temp_end + self.speech_pad_samples
92+
self.temp_end = 0
93+
self.triggered = False
94+
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
95+
96+
return None
97+
98+
#######################
99+
# because Silero now requires exactly 512-sized audio chunks
100+
101+
import numpy as np
102+
class FixedVADIterator(VADIterator):
103+
'''It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once.
104+
If audio to be processed at once is long and multiple voiced segments detected,
105+
then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment.
106+
'''
107+
108+
def reset_states(self):
109+
super().reset_states()
110+
self.buffer = np.array([],dtype=np.float32)
111+
112+
def __call__(self, x, return_seconds=False):
113+
self.buffer = np.append(self.buffer, x)
114+
ret = None
115+
while len(self.buffer) >= 512:
116+
r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
117+
self.buffer = self.buffer[512:]
118+
if ret is None:
119+
ret = r
120+
elif r is not None:
121+
if 'end' in r:
122+
ret['end'] = r['end'] # the latter end
123+
if 'start' in r and 'end' in ret: # there is an earlier start.
124+
# Remove end, merging this segment with the previous one.
125+
del ret['end']
126+
return ret if ret != {} else None
127+
128+
if __name__ == "__main__":
129+
# test/demonstrate the need for FixedVADIterator:
130+
131+
import torch
132+
model, _ = torch.hub.load(
133+
repo_or_dir='snakers4/silero-vad',
134+
model='silero_vad'
135+
)
136+
vac = FixedVADIterator(model)
137+
# vac = VADIterator(model) # the second case crashes with this
138+
139+
# this works: for both
140+
audio_buffer = np.array([0]*(512),dtype=np.float32)
141+
vac(audio_buffer)
142+
143+
# this crashes on the non FixedVADIterator with
144+
# ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
145+
audio_buffer = np.array([0]*(512-1),dtype=np.float32)
146+
vac(audio_buffer)

0 commit comments

Comments
 (0)