initial version of serialize test case #82

Merged: 3 commits, Apr 21, 2021
howl/data/dataset/base.py: 87 changes (43 additions, 44 deletions)
@@ -10,22 +10,24 @@

 from .phone import Phone, PhonePhrase

-__all__ = ['AudioClipExample',
-           'AudioClipMetadata',
-           'AudioDatasetStatistics',
-           'ClassificationBatch',
-           'ClassificationClipExample',
-           'DatasetType',
-           'EmplacableExample',
-           'WakeWordClipExample',
-           'SequenceBatch',
-           'FrameLabelData',
-           'UNKNOWN_TRANSCRIPTION',
-           'NEGATIVE_CLASS']
-
-
-UNKNOWN_TRANSCRIPTION = '[UNKNOWN]'
-NEGATIVE_CLASS = '[NEGATIVE]'
+__all__ = [
+    "AudioClipExample",
+    "AudioClipMetadata",
+    "AudioDatasetStatistics",
+    "ClassificationBatch",
+    "ClassificationClipExample",
+    "DatasetType",
+    "EmplacableExample",
+    "WakeWordClipExample",
+    "SequenceBatch",
+    "FrameLabelData",
+    "UNKNOWN_TRANSCRIPTION",
+    "NEGATIVE_CLASS",
+]
+
+
+UNKNOWN_TRANSCRIPTION = "[UNKNOWN]"
+NEGATIVE_CLASS = "[NEGATIVE]"


 @dataclass
@@ -43,17 +45,18 @@ class AudioDatasetStatistics:


 class AudioClipMetadata(BaseModel):
-    path: Optional[Path] = Path('.')
+    path: Optional[Path] = Path(".")
     phone_strings: Optional[List[str]]
     words: Optional[List[str]]
     phone_end_timestamps: Optional[List[float]]
     word_end_timestamps: Optional[List[float]]
     end_timestamps: Optional[List[float]]  # TODO: remove, backwards compat right now
-    transcription: Optional[str] = ''
+    transcription: Optional[str] = ""

+    # TODO:: id should be an explicit variable in order to support datasets creation with the audio data in memory
     @property
     def audio_id(self) -> str:
-        return self.path.name.split('.', 1)[0]
+        return self.path.name.split(".", 1)[0]

     @property
     def phone_phrase(self) -> Optional[PhonePhrase]:
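For orientation, a minimal usage sketch of AudioClipMetadata as it stands after this hunk. It assumes pydantic v1 semantics, where Optional fields without explicit defaults fall back to None; the clip path is hypothetical:

from pathlib import Path

from howl.data.dataset.base import AudioClipMetadata

# Only path and transcription are supplied; the other Optional fields default to None.
metadata = AudioClipMetadata(path=Path("clips/hey_ff_001.wav"), transcription="hey firefox")

# audio_id is the file name up to the first ".", i.e. the extension is stripped.
assert metadata.audio_id == "hey_ff_001"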
@@ -63,15 +66,13 @@ def phone_phrase(self) -> Optional[PhonePhrase]:
 class EmplacableExample:
     audio_data: torch.Tensor

-    def emplaced_audio_data(self,
-                            audio_data: torch.Tensor,
-                            scale: float = 1,
-                            bias: float = 0,
-                            new: bool = False) -> 'EmplacableExample':
+    def emplaced_audio_data(
+        self, audio_data: torch.Tensor, scale: float = 1, bias: float = 0, new: bool = False
+    ) -> "EmplacableExample":
         raise NotImplementedError


-T = TypeVar('T', bound=AudioClipMetadata)
+T = TypeVar("T", bound=AudioClipMetadata)


 class AudioClipExample(EmplacableExample, Generic[T]):
@@ -83,15 +84,13 @@ def __init__(self, metadata: T, audio_data: torch.Tensor, sample_rate: int):
     def pin_memory(self):
         self.audio_data.pin_memory()

-    def emplaced_audio_data(self,
-                            audio_data: torch.Tensor,
-                            scale: float = 1,
-                            bias: float = 0,
-                            new: bool = False) -> 'AudioClipExample':
+    def emplaced_audio_data(
+        self, audio_data: torch.Tensor, scale: float = 1, bias: float = 0, new: bool = False
+    ) -> "AudioClipExample":
         metadata = self.metadata
         if new:
             metadata = deepcopy(metadata)
-            metadata.transcription = ''
+            metadata.transcription = ""
         return AudioClipExample(metadata, audio_data, self.sample_rate)

@@ -102,15 +101,15 @@ class ClassificationBatch:
     lengths: torch.Tensor

     @classmethod
-    def from_single(cls, audio_clip: torch.Tensor, label: int) -> 'ClassificationBatch':
+    def from_single(cls, audio_clip: torch.Tensor, label: int) -> "ClassificationBatch":
         return cls(audio_clip.unsqueeze(0), torch.tensor([label]), torch.tensor([audio_clip.size(-1)]))

     def pin_memory(self):
         self.audio_data.pin_memory()
         self.labels.pin_memory()
         self.lengths.pin_memory()

-    def to(self, device: torch.device) -> 'ClassificationBatch':
+    def to(self, device: torch.device) -> "ClassificationBatch":
         self.audio_data = self.audio_data.to(device)
         if self.labels is not None:
             self.labels = self.labels.to(device)
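A short sketch of the two ClassificationBatch helpers touched in this hunk, assuming a mono waveform tensor; the shapes follow from the unsqueeze(0) in from_single:

import torch

from howl.data.dataset.base import ClassificationBatch

audio = torch.zeros(16000)  # hypothetical one-second clip at 16 kHz
batch = ClassificationBatch.from_single(audio, label=1)
print(batch.audio_data.shape)  # torch.Size([1, 16000])
print(batch.labels, batch.lengths)  # tensor([1]) tensor([16000])

# to() reassigns audio_data, labels, and lengths onto the target device.
batch.to(torch.device("cpu"))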
@@ -139,7 +138,7 @@ def pin_memory(self):
         self.audio_lengths.pin_memory()
         self.label_lengths.pin_memory()

-    def to(self, device: torch.device) -> 'SequenceBatch':
+    def to(self, device: torch.device) -> "SequenceBatch":
         self.audio_data = self.audio_data.to(device)
         self.labels = self.labels.to(device)
         self.audio_lengths = self.audio_lengths.to(device)
@@ -153,17 +152,17 @@ def __init__(self, label_data: FrameLabelData, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.label_data = label_data

-    def emplaced_audio_data(self,
-                            audio_data: torch.Tensor,
-                            scale: float = 1,
-                            bias: float = 0,
-                            new: bool = False) -> 'WakeWordClipExample':
+    def emplaced_audio_data(
+        self, audio_data: torch.Tensor, scale: float = 1, bias: float = 0, new: bool = False
+    ) -> "WakeWordClipExample":
         ex = super().emplaced_audio_data(audio_data, scale, bias, new)
         label_data = {} if new else {scale * k + bias: v for k, v in self.label_data.timestamp_label_map.items()}
-        return WakeWordClipExample(FrameLabelData(label_data, self.label_data.start_timestamp, self.label_data.char_indices),
-                                   ex.metadata,
-                                   audio_data,
-                                   self.sample_rate)
+        return WakeWordClipExample(
+            FrameLabelData(label_data, self.label_data.start_timestamp, self.label_data.char_indices),
+            ex.metadata,
+            audio_data,
+            self.sample_rate,
+        )


 @dataclass
@@ -172,7 +171,7 @@ def __init__(self, label, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.label = label

-    def emplaced_audio_data(self, audio_data: torch.Tensor, **kwargs) -> 'ClassificationClipExample':
+    def emplaced_audio_data(self, audio_data: torch.Tensor, **kwargs) -> "ClassificationClipExample":
         return ClassificationClipExample(self.label, self.metadata, audio_data, self.sample_rate)
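Finally, the emplaced_audio_data methods reformatted throughout this diff swap a new audio tensor into an existing example; with new=True the metadata is deep-copied and its transcription cleared (WakeWordClipExample additionally drops its frame labels). A minimal sketch under those visible semantics, with hypothetical audio values:

import torch

from howl.data.dataset.base import AudioClipExample, AudioClipMetadata

example = AudioClipExample(
    metadata=AudioClipMetadata(transcription="hey firefox"),
    audio_data=torch.zeros(16000),
    sample_rate=16000,
)

# Default path: the metadata object is reused, so the transcription survives.
moved = example.emplaced_audio_data(torch.randn(16000))
assert moved.metadata.transcription == "hey firefox"

# new=True deep-copies the metadata and blanks the transcription.
fresh = example.emplaced_audio_data(torch.randn(16000), new=True)
assert fresh.metadata.transcription == ""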