Filter out bad stitched samples by checking the speech-to-text transcription #72

Merged · 3 commits · Apr 9, 2021
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
@@ -14,10 +14,11 @@ repos:
       - id: trailing-whitespace
       - id: check-json
       - id: name-tests-test
-  - repo: https://github.com/timothycrosley/isort
-    rev: 5.2.0
+  - repo: https://github.com/pycqa/isort
+    rev: 5.5.4
     hooks:
       - id: isort
+        args: ["--profile", "black"]
   - repo: https://github.com/psf/black
     rev: 19.10b0
     hooks:
@@ -27,6 +28,7 @@ repos:
     rev: 3.8.3
     hooks:
       - id: flake8
+        args: [--max-line-length=120]
   - repo: https://github.com/pre-commit/mirrors-pylint
     rev: v2.6.0
     hooks:
108 changes: 80 additions & 28 deletions howl/data/stitcher.py
@@ -1,18 +1,24 @@
 import itertools
 import random
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
 
 import soundfile
 import torch
-from howl.data.dataset import (AudioClipDataset, AudioClipExample,
-                               AudioClipMetadata, AudioDataset, DatasetType)
-from tqdm import tqdm
 
+from howl.data.dataset import (
+    AudioClipDataset,
+    AudioClipExample,
+    AudioClipMetadata,
+    AudioDataset,
+    DatasetType,
+)
 from howl.data.tokenize import Vocab
 from howl.settings import SETTINGS
+from tqdm import tqdm
+from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector
 
-__all__ = ['WordStitcher']
+__all__ = ["WordStitcher"]


@dataclass
@@ -24,18 +30,27 @@ class FrameLabelledSample:
 
 
 class Stitcher:
-    def __init__(self,
-                 vocab: Vocab):
+    def __init__(self, vocab: Vocab, detect_keyword: bool = True):
+        """Base Stitcher class
+
+        Args:
+            vocab (Vocab): vocab containing the wakeword
+            detect_keyword (bool, optional): drop invalid stitched samples through a secondary keyword detection step
+        """
         self.sequence = SETTINGS.inference_engine.inference_sequence
         self.sr = SETTINGS.audio.sample_rate
         self.vocab = vocab
-        self.wakeword = ' '.join(self.vocab[x]
-                                 for x in self.sequence)
+        self.wakeword = " ".join(self.vocab[x] for x in self.sequence)
+
+        self.detect_keyword = detect_keyword
+        self.keyword_detector = []
+        if self.detect_keyword:
+            for x in self.sequence:
+                self.keyword_detector.append(SphinxKeywordDetector(self.vocab[x]))
 
 
 class WordStitcher(Stitcher):
-    def __init__(self,
-                 **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)

     def concatenate_end_timestamps(self, end_timestamps_list: List[List[float]]) -> List[float]:
@@ -60,7 +75,7 @@ def concatenate_end_timestamps(end_timestamps_list: List[List[float]]) ->
 
         return concatnated_timestamps[:-1]  # discard last space timestamp
 
-    def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, * datasets: AudioDataset):
+    def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, *datasets: AudioDataset):
         """collect vocab samples from datasets and generate stitched wakeword samples
 
         Args:
@@ -85,8 +100,14 @@ def stitch(self, num_stitched_samples: int, stitched_dataset_dir: Path, * datasets: AudioDataset):
                 for char_idx in char_indices:
                     adjusted_end_timestamps.append(sample.metadata.end_timestamps[char_idx] - start_timestamp)
 
-                sample_set[label].append(FrameLabelledSample(
-                    sample.audio_data[audio_start_idx:audio_end_idx], end_timestamp-start_timestamp, adjusted_end_timestamps, label))
+                sample_set[label].append(
+                    FrameLabelledSample(
+                        sample.audio_data[audio_start_idx:audio_end_idx],
+                        end_timestamp - start_timestamp,
+                        adjusted_end_timestamps,
+                        label,
+                    )
+                )
 
         audio_dir = stitched_dataset_dir / "audio"
         audio_dir.mkdir(exist_ok=True)
@@ -100,34 +121,63 @@
 
         # generate AudioClipExample for each vocab sample
         self.stitched_samples = []
-        for sample_idx in tqdm(range(num_stitched_samples), desc="Generating stitched samples"):
+
+        pbar = tqdm(total=num_stitched_samples, desc="Generating stitched samples")
+        sample_idx = 0
+        num_skipped_samples = 0
+        while True:
+            if sample_idx == num_stitched_samples:
+                break
+
             sample_set = []
             for sample_list in sample_lists:
                 sample_set.append(random.choice(sample_list))
 
+            audio_data = torch.cat([labelled_data.audio_data for labelled_data in sample_set])
+
+            if self.detect_keyword:
+                temp_audio_file_path = "/tmp/temp.wav"
+                soundfile.write(temp_audio_file_path, audio_data.numpy(), self.sr)
+
+                keyword_exists = True
+                for detector in self.keyword_detector:
+                    # sphinx keyword detection may not be sufficient for audio with repeated words
+                    if len(detector.detect(temp_audio_file_path)) == 0:
+                        keyword_exists = False
+                        break
+
+                if not keyword_exists:
+                    num_skipped_samples += 1
+                    continue
+
             metatdata = AudioClipMetadata(
                 path=Path(audio_dir / f"{sample_idx}").with_suffix(".wav"),
                 transcription=self.wakeword,
                 end_timestamps=self.concatenate_end_timestamps(
-                    [labelled_data.end_timestamps for labelled_data in sample_set])
+                    [labelled_data.end_timestamps for labelled_data in sample_set]
+                ),
             )
 
             # TODO: the dataset writer loads the samples from disk upon write and does not use the data in memory;
             # writer classes need to be refactored to use the audio data if it exists
             audio_data = torch.cat([labelled_data.audio_data for labelled_data in sample_set])
             soundfile.write(metatdata.path, audio_data.numpy(), self.sr)
 
-            stitched_sample = AudioClipExample(
-                metadata=metatdata,
-                audio_data=audio_data,
-                sample_rate=self.sr)
+            stitched_sample = AudioClipExample(metadata=metatdata, audio_data=audio_data, sample_rate=self.sr)
 
             self.stitched_samples.append(stitched_sample)
 
+            sample_idx += 1
+            pbar.update()
+
+        if self.detect_keyword:
+            print(
+                f"While generating {num_stitched_samples} stitched samples, "
+                f"{num_skipped_samples} were filtered out by keyword detection"
+            )
+
-    def load_splits(self,
-                    train_pct: float,
-                    dev_pct: float,
-                    test_pct: float) -> Tuple[AudioClipDataset, AudioClipDataset, AudioClipDataset]:
+    def load_splits(
+        self, train_pct: float, dev_pct: float, test_pct: float
+    ) -> Tuple[AudioClipDataset, AudioClipDataset, AudioClipDataset]:
         """split the generated stitched samples based on the given pct
         first train_pct samples are used to generate train set
         next dev_pct samples are used to generate dev set
@@ -161,6 +211,8 @@ def load_splits(self,
                 test_split.append(sample.metadata)
 
         ds_kwargs = dict(sr=self.sr, mono=SETTINGS.audio.use_mono)
-        return (AudioClipDataset(metadata_list=train_split, set_type=DatasetType.TRAINING, **ds_kwargs),
-                AudioClipDataset(metadata_list=dev_split, set_type=DatasetType.DEV, **ds_kwargs),
-                AudioClipDataset(metadata_list=test_split, set_type=DatasetType.TEST, **ds_kwargs))
+        return (
+            AudioClipDataset(metadata_list=train_split, set_type=DatasetType.TRAINING, **ds_kwargs),
+            AudioClipDataset(metadata_list=dev_split, set_type=DatasetType.DEV, **ds_kwargs),
+            AudioClipDataset(metadata_list=test_split, set_type=DatasetType.TEST, **ds_kwargs),
+        )
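
At its core, the new stitch loop is a generate-until-quota policy: stitch a random candidate clip, write it to a temporary wav, and keep it only if Sphinx re-detects every wakeword token. A minimal sketch of that policy, with `stitch` and `verify` as hypothetical stand-ins for the concatenation and Sphinx steps above:

```python
import random


def generate_verified_samples(num_samples, candidates, stitch, verify):
    """Keep stitching random candidates until num_samples clips pass verification."""
    kept, num_skipped = [], 0
    while len(kept) < num_samples:
        clip = stitch(random.choice(candidates))  # concatenate word-level clips
        if not verify(clip):  # e.g. Sphinx fails to re-detect a wakeword token
            num_skipped += 1
            continue
        kept.append(clip)
    print(f"While generating {num_samples} samples, {num_skipped} were filtered out")
    return kept
```

Because rejected samples are simply retried, the loop always yields exactly the requested number of clips; the trade-off is extra Sphinx passes when the source clips are noisy.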
35 changes: 35 additions & 0 deletions howl/utils/sphinx_keyword_detector.py
@@ -0,0 +1,35 @@
+import os
+
+from pocketsphinx import AudioFile
+
+
+class SphinxKeywordDetector():
+    def __init__(self, target_transcription, threshold=1e-20, verbose=False):
+        self.target_transcription = target_transcription
+        self.verbose = verbose
+        self.kws_config = {
+            'verbose': self.verbose,
+            'keyphrase': self.target_transcription,
+            'kws_threshold': threshold,
+            'lm': False,
+        }
+
+    def detect(self, file_name):
+
+        kws_results = []
+
+        self.kws_config['audio_file'] = file_name
+        audio = AudioFile(**self.kws_config)
+
+        for phrase in audio:
+            result = phrase.segments(detailed=True)
+
+            # TODO: confirm that when multiple keywords are detected, every detection is valid
+            if len(result) == 1:
+                # segment boundaries are reported in 10 ms frames; convert to milliseconds
+                start_time = result[0][2] * 10
+                end_time = result[0][3] * 10
+                if self.verbose:
+                    print('%4sms ~ %4sms' % (start_time, end_time))
+                kws_results.append((start_time, end_time))
+
+        return kws_results
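
For reference, a small usage sketch of the detector above. The wav path mirrors the unit test added below, and this assumes pocketsphinx's default acoustic model (16 kHz mono input):

```python
from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector

# detect() returns a list of (start_ms, end_ms) tuples, one per accepted
# detection; an empty list means the keyword was not heard in the clip.
detector = SphinxKeywordDetector("hello", verbose=True)
segments = detector.detect("test/test_data/sphinx_keyword_detector/hello_world.wav")
print(segments)
```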
5 changes: 3 additions & 2 deletions requirements_training.txt
@@ -1,5 +1,6 @@
 openpyxl
+pocketsphinx==0.1.15
 praat-textgrids==1.3.1
-webrtcvad==2.0.10
-pytest
 pre-commit
+pytest
+webrtcvad==2.0.10
11 changes: 5 additions & 6 deletions test/data/stitcher_test.py
@@ -2,15 +2,13 @@
 import unittest
 from pathlib import Path
 
-import torch
 from howl.data.dataset import WakeWordDatasetLoader, WordFrameLabeler
 from howl.data.stitcher import WordStitcher
 from howl.data.tokenize import Vocab
 from howl.settings import SETTINGS
 
 
 class TestStitcher(unittest.TestCase):
-
     def test_compute_statistics(self):
         random.seed(1)

@@ -20,25 +18,26 @@ def test_compute_statistics(self):
         SETTINGS.training.token_type = "word"
         SETTINGS.inference_engine.inference_sequence = [0, 1, 2]
 
-        vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr='<OOV>')
+        vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr="<OOV>")
         labeler = WordFrameLabeler(vocab)
 
         loader = WakeWordDatasetLoader()
         ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=labeler)
 
-        test_dataset_path = Path("test/test_data")
+        test_dataset_path = Path("test/test_data/stitcher")
         stitched_dataset_path = test_dataset_path / "stitched"
         stitched_dataset_path.mkdir(exist_ok=True)
 
         test_ds, _, _ = loader.load_splits(test_dataset_path, **ds_kwargs)
-        stitcher = WordStitcher(vocab=vocab)
+        stitcher = WordStitcher(vocab=vocab, detect_keyword=True)
         stitcher.stitch(20, stitched_dataset_path, test_ds)
 
         stitched_train_ds, stitched_dev_ds, stitched_test_ds = stitcher.load_splits(0.5, 0.25, 0.25)
 
         self.assertEqual(len(stitched_train_ds), 10)
         self.assertEqual(len(stitched_dev_ds), 5)
         self.assertEqual(len(stitched_test_ds), 5)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
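
One note on the test above: passing detect_keyword=True makes it depend on pocketsphinx and on real audio fixtures. A hypothetical sketch of the faster configuration, reusing the same vocab values, since the flag defaults to True:

```python
from howl.data.stitcher import WordStitcher
from howl.data.tokenize import Vocab

vocab = Vocab({"hey": 0, "fire": 1, "fox": 2}, oov_token_id=3, oov_word_repr="<OOV>")
# detect_keyword=False skips the Sphinx verification pass entirely
stitcher = WordStitcher(vocab=vocab, detect_keyword=False)
```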
Binary file not shown.
Binary file not shown.
@@ -1 +1,2 @@
 stitched
+stitched_dataset
28 changes: 28 additions & 0 deletions test/utils/sphinx_keyword_detector_test.py
@@ -0,0 +1,28 @@
import unittest

from howl.utils.sphinx_keyword_detector import SphinxKeywordDetector


class TestSphinxKeywordDetector(unittest.TestCase):

def test_detect(self):
"""test word detection from an audio file
"""

hello_world_file = "test/test_data/sphinx_keyword_detector/hello_world.wav"
hello_extractor = SphinxKeywordDetector("hello")
self.assertTrue(len(hello_extractor.detect(hello_world_file)) > 0)
world_extractor = SphinxKeywordDetector("world")
self.assertTrue(len(world_extractor.detect(hello_world_file)) > 0)

hey_fire_fox_file = "test/test_data/sphinx_keyword_detector/hey_fire_fox.wav"
hey_extractor = SphinxKeywordDetector("hey")
self.assertTrue(len(hey_extractor.detect(hey_fire_fox_file)) > 0)
fire_extractor = SphinxKeywordDetector("fire")
self.assertTrue(len(fire_extractor.detect(hey_fire_fox_file)) > 0)
fox_extractor = SphinxKeywordDetector("fox")
self.assertTrue(len(fox_extractor.detect(hey_fire_fox_file)) > 0)


if __name__ == '__main__':
unittest.main()