@@ -1,15 +1,14 @@
 import torch
 import av
 import numpy as np
+import fractions
 
-from typing import Any, Dict, Optional, Union
+from av import AudioFrame
+from typing import Any, Dict, Optional, Union, List
 from comfystream.client import ComfyStreamClient
 
 WARMUP_RUNS = 5
-
-# TODO: remove, was just for temp UI
 import logging
-
 display_logger = logging.getLogger('display_logger')
 display_logger.setLevel(logging.INFO)
 handler = logging.FileHandler('display_logs.txt')
@@ -57,46 +56,47 @@ async def __call__(self, frame: av.VideoFrame) -> av.VideoFrame:
 class AudioPipeline:
     def __init__(self, **kwargs):
         self.client = ComfyStreamClient(**kwargs, type="audio")
-        self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=16000)
+        self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=48000)
+        self.sample_rate = 48000
+        self.frame_size = int(self.sample_rate * 0.02)
+        self.time_base = fractions.Fraction(1, self.sample_rate)
+        self.curr_pts = 0
 
     async def warm(self):
-        dummy_audio = torch.randn(16000)
+        dummy_audio = np.random.randint(-32768, 32767, 48000 * 1, dtype=np.int16)
         for _ in range(WARMUP_RUNS):
             await self.predict(dummy_audio)
 
     def set_prompt(self, prompt: Dict[Any, Any]):
         self.client.set_prompt(prompt)
 
-    def preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
-        resampled_frame = self.resampler.resample(frame)[0]
-        samples = resampled_frame.to_ndarray()
-        samples = samples.astype(np.float32) / 32768.0
-        return samples
-
-    def postprocess(self, output: torch.Tensor) -> Optional[Union[av.AudioFrame, str]]:
-        out_np = output.cpu().numpy()
-        out_np = np.clip(out_np * 32768.0, -32768, 32767).astype(np.int16)
-        audio_frame = av.AudioFrame.from_ndarray(out_np, format="s16", layout="stereo")
-        return audio_frame
+    def preprocess(self, frames: List[av.AudioFrame]) -> np.ndarray:
+        audio_arrays = []
+        for frame in frames:
+            audio_arrays.append(self.resampler.resample(frame)[0].to_ndarray())
+        return np.concatenate(audio_arrays, axis=1).flatten()
+
+    def postprocess(self, out_np) -> List[av.AudioFrame]:
+        frames = []
+        for idx in range(0, len(out_np), self.frame_size):
+            frame_samples = out_np[idx:idx + self.frame_size]
+            frame_samples = frame_samples.reshape(1, -1)
+            frame = AudioFrame.from_ndarray(frame_samples, layout="mono")
+            frame.sample_rate = self.sample_rate
+            frame.pts = self.curr_pts
+            frame.time_base = self.time_base
+            self.curr_pts += self.frame_size
+
+            frames.append(frame)
+        return frames
 
-    async def predict(self, frame: torch.Tensor) -> torch.Tensor:
+    async def predict(self, frame):
         return await self.client.queue_prompt(frame)
 
-    async def __call__(self, frame: av.AudioFrame):
-        # TODO: clean this up later for audio-to-text and audio-to-audio
-        pre_output = self.preprocess(frame)
-        pred_output = await self.predict(pre_output)
-        if type(pred_output) == tuple:
-            if pred_output[0] is not None:
-                await self.log_text(f"{pred_output[0]} {pred_output[1]} {pred_output[2]}")
-            return frame
-        else:
-            post_output = self.postprocess(pred_output)
-            post_output.sample_rate = frame.sample_rate
-            post_output.pts = frame.pts
-            post_output.time_base = frame.time_base
-            return post_output
-
-    async def log_text(self, text: str):
-        # TODO: remove, was just for temp UI
-        display_logger.info(text)
+    async def __call__(self, frames: List[av.AudioFrame]):
+        pre_audio = self.preprocess(frames)
+        pred_audio, text = await self.predict(pre_audio)
+        if text[-1] != "":
+            display_logger.info(f"{text[0]} {text[1]} {text[2]}")
+        pred_audios = self.postprocess(pred_audio)
+        return pred_audios
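
Note for reviewers: below is a minimal usage sketch, not part of the diff, of the reworked AudioPipeline contract. It assumes the class is importable from a module named pipeline, that ComfyStreamClient can be constructed from the forwarded kwargs, and that an empty prompt dict is acceptable for illustration; those details are placeholders. The point is the new interface: __call__ takes a list of 48 kHz av.AudioFrame objects, and the model output is re-chunked into 20 ms (960-sample) mono s16 frames whose pts advances by 960 per frame with a time_base of 1/48000.

import asyncio
import av
import numpy as np

from pipeline import AudioPipeline  # assumed import path for this file

async def main():
    # kwargs are forwarded straight to ComfyStreamClient; real values depend on the deployment
    pipeline = AudioPipeline()
    pipeline.set_prompt({})  # placeholder prompt dict
    await pipeline.warm()

    # One second of 48 kHz silence, pre-chunked into 50 x 20 ms mono s16 input frames
    in_frames = []
    for _ in range(50):
        samples = np.zeros((1, 960), dtype=np.int16)
        frame = av.AudioFrame.from_ndarray(samples, format="s16", layout="mono")
        frame.sample_rate = 48000
        in_frames.append(frame)

    out_frames = await pipeline(in_frames)
    for f in out_frames:
        # Each output frame is 20 ms of mono s16 audio with monotonically increasing pts
        print(f.pts, f.samples, f.sample_rate)

asyncio.run(main())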