|
| 1 | +import torch |
| 2 | + |
| 3 | +# This is copied from silero-vad's vad_utils.py: |
| 4 | +# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340 |
| 5 | +# (except changed defaults) |
| 6 | + |
| 7 | +# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE |
| 8 | + |
| 9 | +class VADIterator: |
| 10 | + def __init__(self, |
| 11 | + model, |
| 12 | + threshold: float = 0.5, |
| 13 | + sampling_rate: int = 16000, |
| 14 | + min_silence_duration_ms: int = 500, # makes sense on one recording that I checked |
| 15 | + speech_pad_ms: int = 100 # same |
| 16 | + ): |
| 17 | + |
| 18 | + """ |
| 19 | + Class for stream imitation |
| 20 | +
|
| 21 | + Parameters |
| 22 | + ---------- |
| 23 | + model: preloaded .jit silero VAD model |
| 24 | +
|
| 25 | + threshold: float (default - 0.5) |
| 26 | + Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH. |
| 27 | + It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. |
| 28 | +
|
| 29 | + sampling_rate: int (default - 16000) |
| 30 | + Currently silero VAD models support 8000 and 16000 sample rates |
| 31 | +
|
| 32 | + min_silence_duration_ms: int (default - 100 milliseconds) |
| 33 | + In the end of each speech chunk wait for min_silence_duration_ms before separating it |
| 34 | +
|
| 35 | + speech_pad_ms: int (default - 30 milliseconds) |
| 36 | + Final speech chunks are padded by speech_pad_ms each side |
| 37 | + """ |
| 38 | + |
| 39 | + self.model = model |
| 40 | + self.threshold = threshold |
| 41 | + self.sampling_rate = sampling_rate |
| 42 | + |
| 43 | + if sampling_rate not in [8000, 16000]: |
| 44 | + raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]') |
| 45 | + |
| 46 | + self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 |
| 47 | + self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000 |
| 48 | + self.reset_states() |
| 49 | + |
| 50 | + def reset_states(self): |
| 51 | + |
| 52 | + self.model.reset_states() |
| 53 | + self.triggered = False |
| 54 | + self.temp_end = 0 |
| 55 | + self.current_sample = 0 |
| 56 | + |
| 57 | + def __call__(self, x, return_seconds=False): |
| 58 | + """ |
| 59 | + x: torch.Tensor |
| 60 | + audio chunk (see examples in repo) |
| 61 | +
|
| 62 | + return_seconds: bool (default - False) |
| 63 | + whether return timestamps in seconds (default - samples) |
| 64 | + """ |
| 65 | + |
| 66 | + if not torch.is_tensor(x): |
| 67 | + try: |
| 68 | + x = torch.Tensor(x) |
| 69 | + except: |
| 70 | + raise TypeError("Audio cannot be casted to tensor. Cast it manually") |
| 71 | + |
| 72 | + window_size_samples = len(x[0]) if x.dim() == 2 else len(x) |
| 73 | + self.current_sample += window_size_samples |
| 74 | + |
| 75 | + speech_prob = self.model(x, self.sampling_rate).item() |
| 76 | + |
| 77 | + if (speech_prob >= self.threshold) and self.temp_end: |
| 78 | + self.temp_end = 0 |
| 79 | + |
| 80 | + if (speech_prob >= self.threshold) and not self.triggered: |
| 81 | + self.triggered = True |
| 82 | + speech_start = self.current_sample - self.speech_pad_samples |
| 83 | + return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)} |
| 84 | + |
| 85 | + if (speech_prob < self.threshold - 0.15) and self.triggered: |
| 86 | + if not self.temp_end: |
| 87 | + self.temp_end = self.current_sample |
| 88 | + if self.current_sample - self.temp_end < self.min_silence_samples: |
| 89 | + return None |
| 90 | + else: |
| 91 | + speech_end = self.temp_end + self.speech_pad_samples |
| 92 | + self.temp_end = 0 |
| 93 | + self.triggered = False |
| 94 | + return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)} |
| 95 | + |
| 96 | + return None |
| 97 | + |
| 98 | +####################### |
| 99 | +# because Silero now requires exactly 512-sized audio chunks |
| 100 | + |
| 101 | +import numpy as np |
| 102 | +class FixedVADIterator(VADIterator): |
| 103 | + '''It fixes VADIterator by allowing to process any audio length, not only exactly 512 frames at once. |
| 104 | + If audio to be processed at once is long and multiple voiced segments detected, |
| 105 | + then __call__ returns the start of the first segment, and end (or middle, which means no end) of the last segment. |
| 106 | + ''' |
| 107 | + |
| 108 | + def reset_states(self): |
| 109 | + super().reset_states() |
| 110 | + self.buffer = np.array([],dtype=np.float32) |
| 111 | + |
| 112 | + def __call__(self, x, return_seconds=False): |
| 113 | + self.buffer = np.append(self.buffer, x) |
| 114 | + ret = None |
| 115 | + while len(self.buffer) >= 512: |
| 116 | + r = super().__call__(self.buffer[:512], return_seconds=return_seconds) |
| 117 | + self.buffer = self.buffer[512:] |
| 118 | + if ret is None: |
| 119 | + ret = r |
| 120 | + elif r is not None: |
| 121 | + if 'end' in r: |
| 122 | + ret['end'] = r['end'] # the latter end |
| 123 | + if 'start' in r and 'end' in ret: # there is an earlier start. |
| 124 | + # Remove end, merging this segment with the previous one. |
| 125 | + del ret['end'] |
| 126 | + return ret if ret != {} else None |
| 127 | + |
| 128 | +if __name__ == "__main__": |
| 129 | + # test/demonstrate the need for FixedVADIterator: |
| 130 | + |
| 131 | + import torch |
| 132 | + model, _ = torch.hub.load( |
| 133 | + repo_or_dir='snakers4/silero-vad', |
| 134 | + model='silero_vad' |
| 135 | + ) |
| 136 | + vac = FixedVADIterator(model) |
| 137 | +# vac = VADIterator(model) # the second case crashes with this |
| 138 | + |
| 139 | + # this works: for both |
| 140 | + audio_buffer = np.array([0]*(512),dtype=np.float32) |
| 141 | + vac(audio_buffer) |
| 142 | + |
| 143 | + # this crashes on the non FixedVADIterator with |
| 144 | + # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError") |
| 145 | + audio_buffer = np.array([0]*(512-1),dtype=np.float32) |
| 146 | + vac(audio_buffer) |
0 commit comments