train_phoneme_wise_mean_contour.py
####################################################################################################
#
# Train the phoneme-wise mean contour phoneme-to-articulation
#
####################################################################################################
import argparse
import logging
import mlflow
import os
import pandas as pd
import shutil
import tempfile
import ujson
import yaml
from helpers import sequences_from_dict
from phoneme_recognition import UNKNOWN
from phoneme_to_articulation.encoder_decoder.dataset import ArtSpeechDataset
from phoneme_to_articulation.phoneme_wise_mean_contour import train, test
from settings import BASE_DIR

TMPFILES = os.path.join(BASE_DIR, "tmp")
TMP_DIR = tempfile.mkdtemp(dir=TMPFILES)
RESULTS_DIR = os.path.join(TMP_DIR, "results")
if not os.path.exists(RESULTS_DIR):
    os.makedirs(RESULTS_DIR)


def main(
    database_name,
    datadir,
    train_seq_dict,
    test_seq_dict,
    vocab_filepath,
    articulators,
    state_dict_filepath=None,
    clip_tails=True,
    weighted=False,
):
    # Build the vocabulary: the unknown token takes index 0, and the tokens
    # listed in the vocab file fill the remaining indices.
    default_tokens = [UNKNOWN]
    vocabulary = {token: i for i, token in enumerate(default_tokens)}
    with open(vocab_filepath) as f:
        tokens = ujson.load(f)
    for i, token in enumerate(tokens, start=len(vocabulary)):
        vocabulary[token] = i
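    # A minimal sketch of the expected vocabulary file: a flat JSON list of
    # phoneme symbols. The symbols below are hypothetical; the real token set
    # depends on the corpus:
    #
    #   ["a", "e", "i", "o", "u", "p", "t", "k", "#"]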

    train_sequences = sequences_from_dict(datadir, train_seq_dict)
    train_dataset = ArtSpeechDataset(
        datadir,
        database_name,
        train_sequences,
        vocabulary,
        articulators,
        clip_tails=clip_tails,
    )

    if state_dict_filepath is None:
        # Train from scratch and log the fitted per-phoneme contours.
        save_to = os.path.join(RESULTS_DIR, "phoneme_wise_articulators.csv")
        df = train(train_dataset, save_to=save_to, weighted=weighted)
        mlflow.log_artifact(save_to)
    else:
        # Load previously fitted contours. Each articulator column stores a
        # stringified array, which eval turns back into a Python object.
        df = pd.read_csv(state_dict_filepath)
        for articulator in train_dataset.articulators:
            df[articulator] = df[articulator].apply(eval)
        mlflow.log_artifact(state_dict_filepath)
    logging.info("Finished training phoneme wise mean contour")
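    # Hypothetical sketch of what such a state-dict CSV might look like,
    # inferred from the eval-based deserialization above (the phoneme labels,
    # articulator names, and coordinates are made up):
    #
    #   phoneme,tongue,lower-lip
    #   a,"[[0.1, 0.2], [0.3, 0.4]]","[[0.5, 0.6], [0.7, 0.8]]"
    #   p,"[[0.2, 0.1], [0.4, 0.3]]","[[0.6, 0.5], [0.8, 0.7]]"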

    test_sequences = sequences_from_dict(datadir, test_seq_dict)
    test_dataset = ArtSpeechDataset(
        datadir,
        database_name,
        test_sequences,
        vocabulary,
        articulators,
        clip_tails=clip_tails,
    )

    test_outputs_dir = os.path.join(RESULTS_DIR, "test_outputs")
    if not os.path.exists(test_outputs_dir):
        os.makedirs(test_outputs_dir)

    test_results = test(
        test_dataset,
        df,
        test_outputs_dir,
        weighted=weighted,
    )
    mlflow.log_artifact(test_outputs_dir)

    test_results_filepath = os.path.join(RESULTS_DIR, "test_results.json")
    with open(test_results_filepath, "w") as f:
        ujson.dump(test_results, f)
    mlflow.log_artifact(test_results_filepath)

    # Flatten the test metrics into a single-row CSV: the overall loss plus
    # the per-articulator x/y correlation scores.
    results_item = {
        "loss": test_results["loss"],
    }
    for articulator in test_dataset.articulators:
        results_item[f"x_corr_{articulator}"] = test_results[articulator]["x_corr"]
        results_item[f"y_corr_{articulator}"] = test_results[articulator]["y_corr"]

    results_df = pd.DataFrame([results_item])
    results_df_filepath = os.path.join(RESULTS_DIR, "test_results.csv")
    results_df.to_csv(results_df_filepath, index=False)
    mlflow.log_artifact(results_df_filepath)
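    # For example, with articulators ["tongue", "lower-lip"] (hypothetical,
    # as are the metric values), the CSV would hold a header and one row:
    #
    #   loss,x_corr_tongue,y_corr_tongue,x_corr_lower-lip,y_corr_lower-lip
    #   0.0123,0.97,0.95,0.98,0.96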


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", dest="config_filepath", required=True)
    parser.add_argument("--mlflow", dest="mlflow_tracking_uri", default=None)
    parser.add_argument("--experiment", dest="experiment_name", default="phoneme_wise_mean_contour")
    parser.add_argument("--run_id", dest="run_id", default=None)
    parser.add_argument("--run_name", dest="run_name", default=None)
    args = parser.parse_args()

    if args.mlflow_tracking_uri is not None:
        mlflow.set_tracking_uri(args.mlflow_tracking_uri)

    with open(args.config_filepath) as f:
        cfg = yaml.safe_load(f)

    experiment = mlflow.set_experiment(args.experiment_name)
    with mlflow.start_run(
        run_id=args.run_id,
        experiment_id=experiment.experiment_id,
        run_name=args.run_name,
    ) as run:
        print(f"Experiment ID: {experiment.experiment_id}\nRun ID: {run.info.run_id}")

        try:
            mlflow.log_artifact(args.config_filepath)
        except shutil.SameFileError:
            logging.info("Skipping logging config file since it already exists.")

        try:
            main(**cfg)
        finally:
            # Remove the temporary working directory even if training fails.
            shutil.rmtree(TMP_DIR)
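
# A minimal sketch of a config file for this script. The keys mirror the
# parameters of main(); all values are hypothetical, and the exact database
# name, paths, articulator list, and sequence-dict layout depend on your setup:
#
#   database_name: artspeech
#   datadir: /path/to/database
#   vocab_filepath: /path/to/vocabulary.json
#   articulators:
#     - tongue
#     - lower-lip
#   train_seq_dict:
#     subject1: [sequence1, sequence2]
#   test_seq_dict:
#     subject2: [sequence3]
#   clip_tails: true
#   weighted: false
#
# Example invocation (the tracking URI is hypothetical):
#
#   python train_phoneme_wise_mean_contour.py --config config.yaml --mlflow http://localhost:5000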