
Commit fbe59f4

feat: whisper workflow
1 parent 53fa913 commit fbe59f4

8 files changed (+175, -2 lines)

audio_example.py

+24
@@ -0,0 +1,24 @@
+import json
+import asyncio
+import torchaudio
+
+from comfystream.client import ComfyStreamClient
+
+async def main():
+    cwd = "/home/user/ComfyUI"
+    client = ComfyStreamClient(cwd=cwd)
+
+    with open("./workflows/audio-whisper-example-workflow.json", "r") as f:
+        prompt = json.load(f)
+
+    client.set_prompt(prompt)
+
+    # Downmix multi-channel audio to mono; the graph expects a 1-D waveform.
+    waveform, _ = torchaudio.load("harvard.wav")
+    if waveform.ndim > 1:
+        audio_tensor = waveform.mean(dim=0)
+    else:
+        audio_tensor = waveform
+
+    output = await client.queue_prompt(audio_tensor)
+    print(output)
+
+if __name__ == "__main__":
+    asyncio.run(main())
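
One caveat for this example: the ApplyWhisper node below hard-codes self.sample_rate = 16000, so the script assumes harvard.wav is already a 16 kHz recording. If the source file has a different rate, a resampling step along these lines would be needed first (a minimal sketch using torchaudio; the 16000 target comes from the node, the rest is illustrative):

import torchaudio
import torchaudio.functional as F

# Sketch: resample to the 16 kHz rate ApplyWhisper assumes before queueing.
waveform, sr = torchaudio.load("harvard.wav")
if sr != 16000:
    waveform = F.resample(waveform, orig_freq=sr, new_freq=16000)
audio_tensor = waveform.mean(dim=0) if waveform.ndim > 1 else waveform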

nodes/audio_utils/__init__.py

+7
@@ -0,0 +1,7 @@
+from .apply_whisper import ApplyWhisper
+from .load_audio_tensor import LoadAudioTensor
+from .save_asr_response import SaveASRResponse
+
+NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper}
+
+__all__ = ["NODE_CLASS_MAPPINGS"]

nodes/audio_utils/apply_whisper.py

+60
@@ -0,0 +1,60 @@
+import torch
+import whisper
+
+class ApplyWhisper:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "audio": ("AUDIO",),
+                "model": (["base", "tiny", "small", "medium", "large"],),
+            }
+        }
+
+    RETURN_TYPES = ("DICT",)
+    FUNCTION = "apply_whisper"
+
+    def __init__(self):
+        self.model = None
+        self.audio_buffer = []
+        # TODO: expose these as node parameters
+        self.sample_rate = 16000
+        self.min_duration = 1.0
+
+    def apply_whisper(self, audio, model):
+        # Lazily load the Whisper model on first use.
+        if self.model is None:
+            self.model = whisper.load_model(model).cuda()
+
+        # Buffer incoming chunks until at least min_duration seconds of audio.
+        self.audio_buffer.append(audio)
+        total_duration = sum(chunk.shape[0] / self.sample_rate for chunk in self.audio_buffer)
+        if total_duration < self.min_duration:
+            return ({"text": "", "segments_alignment": [], "words_alignment": []},)
+
+        concatenated_audio = torch.cat(self.audio_buffer, dim=0).cuda()
+        self.audio_buffer = []
+        result = self.model.transcribe(concatenated_audio.float(), fp16=True, word_timestamps=True)
+        segments = result["segments"]
+        segments_alignment = []
+        words_alignment = []
+
+        # Collect segment- and word-level timestamps from the transcription.
+        for segment in segments:
+            segment_dict = {
+                "value": segment["text"].strip(),
+                "start": segment["start"],
+                "end": segment["end"]
+            }
+            segments_alignment.append(segment_dict)
+
+            for word in segment["words"]:
+                word_dict = {
+                    "value": word["word"].strip(),
+                    "start": word["start"],
+                    "end": word["end"]
+                }
+                words_alignment.append(word_dict)
+
+        return ({
+            "text": result["text"].strip(),
+            "segments_alignment": segments_alignment,
+            "words_alignment": words_alignment
+        },)
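
Because the node buffers input until min_duration seconds have accumulated, calls with short chunks return an empty result before the first transcription fires. A minimal sketch of driving the node directly, outside a ComfyUI graph; the chunk size and "tiny" model choice are illustrative, and a CUDA device plus the openai-whisper package are assumed:

import torch

node = ApplyWhisper()
chunk = torch.zeros(8000)  # 0.5 s of silence at 16 kHz, below min_duration

(first,) = node.apply_whisper(chunk, "tiny")
print(first["text"])   # "" -- still buffering, under 1.0 s of audio

(second,) = node.apply_whisper(chunk, "tiny")
print(second["text"])  # buffer reached 1.0 s, so transcription ran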

nodes/audio_utils/load_audio_tensor.py

+18

@@ -0,0 +1,18 @@
+from comfystream import tensor_cache
+
+class LoadAudioTensor:
+    CATEGORY = "tensor_utils"
+    RETURN_TYPES = ("AUDIO",)
+    FUNCTION = "execute"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {}
+
+    @classmethod
+    def IS_CHANGED(s):
+        return float("nan")
+
+    def execute(self):
+        # Pop the audio tensor the client pushed into the cache.
+        audio = tensor_cache.inputs.pop()
+        return (audio,)

nodes/audio_utils/save_asr_response.py

+24

@@ -0,0 +1,24 @@
+from comfystream import tensor_cache
+
+class SaveASRResponse:
+    CATEGORY = "tensor_utils"
+    RETURN_TYPES = ()
+    FUNCTION = "execute"
+    OUTPUT_NODE = True
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "data": ("DICT",),
+            }
+        }
+
+    @classmethod
+    def IS_CHANGED(s):
+        return float("nan")
+
+    def execute(self, data: dict):
+        # Resolve the pending future so the client's queue_prompt returns.
+        fut = tensor_cache.outputs.pop()
+        fut.set_result(data)
+        return data
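
These two nodes are the graph-side halves of a handoff through comfystream's tensor_cache: LoadAudioTensor pops the tensor the client pushed, and SaveASRResponse resolves a future the client is awaiting. A sketch of the client-side contract as inferred from the pop() calls above; queue_once and run_prompt are hypothetical names, not comfystream API:

import asyncio
from comfystream import tensor_cache

async def queue_once(audio_tensor, run_prompt):
    # Pending future that SaveASRResponse will resolve via set_result(data).
    fut = asyncio.get_running_loop().create_future()
    tensor_cache.inputs.append(audio_tensor)  # popped by LoadAudioTensor
    tensor_cache.outputs.append(fut)          # popped by SaveASRResponse
    await run_prompt()                        # executes the ComfyUI graph
    return await fut                          # the ASR response dict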

src/comfystream/utils.py

+2-2
@@ -29,9 +29,9 @@ def convert_prompt(prompt: PromptDictInput) -> Prompt:
                 "class_type": "SaveTensor",
                 "_meta": {"title": "SaveTensor"},
             }
-        elif node.get("class_type") == "LoadTensor":
+        elif node.get("class_type") in ["LoadTensor", "LoadAudioTensor"]:
             num_inputs += 1
-        elif node.get("class_type") == "SaveTensor":
+        elif node.get("class_type") in ["SaveTensor", "SaveASRResponse"]:
             num_outputs += 1
 
     # Only handle single input for now

ui/src/components/webcam.tsx

+6
@@ -28,6 +28,12 @@ export function Webcam({ onStreamReady }: WebcamProps) {
         width: { exact: 512 },
         height: { exact: 512 },
       },
+      audio: {
+        noiseSuppression: true,
+        echoCancellation: true,
+        sampleRate: 16000,
+        sampleSize: 16,
+      },
     });
 
     if (videoRef.current) videoRef.current.srcObject = stream;
workflows/audio-whisper-example-workflow.json

+34

@@ -0,0 +1,34 @@
+{
+  "1": {
+    "inputs": {},
+    "class_type": "LoadAudioTensor",
+    "_meta": {
+      "title": "Load Audio Tensor"
+    }
+  },
+  "2": {
+    "inputs": {
+      "audio": [
+        "1",
+        0
+      ],
+      "model": "large"
+    },
+    "class_type": "ApplyWhisper",
+    "_meta": {
+      "title": "Apply Whisper"
+    }
+  },
+  "3": {
+    "inputs": {
+      "data": [
+        "2",
+        0
+      ]
+    },
+    "class_type": "SaveASRResponse",
+    "_meta": {
+      "title": "Save ASR Response"
+    }
+  }
+}
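
For readers new to ComfyUI's API prompt format: a two-element list value under "inputs" is a link of the form [source node id, source output index], so node 2's audio comes from node 1's first output. A small sketch that prints the wiring, assuming the JSON is saved at the path the example script loads:

import json

with open("./workflows/audio-whisper-example-workflow.json") as f:
    prompt = json.load(f)

# List-valued inputs are links: [source node id, source output index].
for node_id, node in prompt.items():
    for name, value in node["inputs"].items():
        if isinstance(value, list):
            src, idx = value
            print(f"node {node_id} ({node['class_type']}).{name} <- "
                  f"node {src} ({prompt[src]['class_type']}) output {idx}")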
