-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_phoneme_to_articulation_transformer.py
129 lines (110 loc) · 3.78 KB
/
test_phoneme_to_articulation_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
####################################################################################################
#
# Test the model-free phoneme-to-articulation using a transformer-based network
#
# This experiment is not included in the thesis
#
####################################################################################################
import argparse
import os
import pandas as pd
import torch
import yaml
import ujson
from torch.utils.data import DataLoader
from tqdm import tqdm
from helpers import set_seeds, sequences_from_dict
from phoneme_to_articulation.transformer.evaluation import run_transformer_test
from phoneme_to_articulation.encoder_decoder.dataset import (
ArtSpeechDataset,
pad_sequence_transformer_collate_fn
)
from phoneme_to_articulation.transformer.models import ArtSpeechTransformer
from phoneme_to_articulation.metrics import EuclideanDistance
from settings import UNKNOWN, BLANK
def main(
    datadir,
    database_name,
    batch_size,
    test_seq_dict,
    state_dict_fpath,
    vocab_filepath,
    articulators,
    save_to,
    model_kwargs=None,
    clip_tails=True,
    num_workers=0
):
    """Evaluate a trained ArtSpeechTransformer on a held-out test set.

    Loads the vocabulary and model weights, runs the transformer test loop,
    and writes the metrics both as JSON (``test_results.json``) and as a
    one-row CSV (``test_results.csv``) under ``save_to``.

    Args:
        datadir: Root directory of the dataset.
        database_name: Name of the database to load (passed to ArtSpeechDataset).
        batch_size: Test dataloader batch size.
        test_seq_dict: Mapping describing the test sequences
            (consumed by ``sequences_from_dict``).
        state_dict_fpath: Path to the saved model ``state_dict``.
        vocab_filepath: Path to a JSON file listing vocabulary tokens.
        articulators: List of articulator names to predict.
        save_to: Output directory for results and test artifacts.
        model_kwargs: Optional extra keyword arguments for ArtSpeechTransformer.
        clip_tails: Forwarded to ArtSpeechDataset (contour tail clipping —
            see the dataset implementation for semantics).
        num_workers: Number of dataloader worker processes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Reserved tokens occupy the first vocabulary indices; file tokens follow.
    default_tokens = [BLANK, UNKNOWN]
    vocabulary = {token: i for i, token in enumerate(default_tokens)}
    with open(vocab_filepath) as f:
        tokens = ujson.load(f)
        for i, token in enumerate(tokens, start=len(vocabulary)):
            vocabulary[token] = i

    test_sequences = sequences_from_dict(datadir, test_seq_dict)
    test_dataset = ArtSpeechDataset(
        datadir,
        database_name,
        test_sequences,
        vocabulary,
        articulators,
        clip_tails=clip_tails,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        worker_init_fn=set_seeds,
        collate_fn=pad_sequence_transformer_collate_fn
    )

    num_articulators = len(articulators)
    model_kwargs = model_kwargs or {}
    best_model = ArtSpeechTransformer(
        len(vocabulary),
        num_articulators,
        **model_kwargs,
    )
    # map_location ensures CUDA-saved checkpoints also load on CPU-only hosts.
    state_dict = torch.load(state_dict_fpath, map_location=device)
    best_model.load_state_dict(state_dict)
    best_model.to(device)

    print(f"""
ArtSpeechTransformer -- {best_model.total_parameters} parameters
""")

    test_outputs_dir = os.path.join(save_to, "test_outputs")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(test_outputs_dir, exist_ok=True)

    loss_fn = EuclideanDistance("none")
    test_results = run_transformer_test(
        epoch=0,
        model=best_model,
        dataloader=test_dataloader,
        criterion=loss_fn,
        outputs_dir=test_outputs_dir,
        articulators=articulators,
        device=device,
        regularize_out=True,
    )

    test_results_filepath = os.path.join(save_to, "test_results.json")
    with open(test_results_filepath, "w") as f:
        ujson.dump(test_results, f)

    # Flatten per-articulator metrics into a single row for the CSV summary.
    results_item = {
        "exp": None,
        "loss": test_results["loss"],
    }
    for articulator in test_dataset.articulators:
        results_item[f"p2cp_{articulator}"] = test_results[articulator]["p2cp"]
        results_item[f"p2cp_mm_{articulator}"] = test_results[articulator]["p2cp_mm"]
        results_item[f"med_{articulator}"] = test_results[articulator]["med"]
        results_item[f"med_mm_{articulator}"] = test_results[articulator]["med_mm"]
    df = pd.DataFrame([results_item])
    df_filepath = os.path.join(save_to, "test_results.csv")
    df.to_csv(df_filepath, index=False)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # required=True gives a clean argparse error instead of a later
    # TypeError from open(None) when the flag is omitted.
    parser.add_argument("--config", dest="cfg_filepath", required=True)
    args = parser.parse_args()

    # safe_load accepts a stream directly; no need to read the file first.
    with open(args.cfg_filepath) as f:
        cfg = yaml.safe_load(f)

    main(**cfg)