-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_plot_loss_acc.py
32 lines (22 loc) · 1.2 KB
/
main_plot_loss_acc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
from matplotlib import pyplot as plt
import pandas as pd
# model_sizes = [8, 16, 32, 64, 128, 256]
# for model_size in model_sizes:
model = f'policy_vae_256'
os.makedirs(f'output/{model}/plots', exist_ok=True)
train_log = pd.read_csv(f'output/{model}/training_info.csv')
val_log = pd.read_csv(f'output/{model}/validation_info.csv')
logs = pd.merge(train_log, val_log, on='epoch', suffixes=['_train', '_val'])
logs[['mean_total_loss_train', 'mean_progs_loss_train', 'mean_a_h_loss_train', 'mean_latent_loss_train']].plot.line(grid=True)
plt.savefig(f'output/{model}/plots/mean_loss_train.png')
plt.close()
logs[['mean_total_loss_val', 'mean_progs_loss_val', 'mean_a_h_loss_val', 'mean_latent_loss_val']].plot.line(grid=True)
plt.savefig(f'output/{model}/plots/mean_loss_val.png')
plt.close()
logs[['mean_progs_t_accuracy_train', 'mean_progs_s_accuracy_train', 'mean_a_h_t_accuracy_train', 'mean_a_h_s_accuracy_train']].plot.line(grid=True)
plt.savefig(f'output/{model}/plots/mean_acc_train.png')
plt.close()
logs[['mean_progs_t_accuracy_val', 'mean_progs_s_accuracy_val', 'mean_a_h_t_accuracy_val', 'mean_a_h_s_accuracy_val']].plot.line(grid=True)
plt.savefig(f'output/{model}/plots/mean_acc_val.png')
plt.close()