-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
56 lines (43 loc) · 2.23 KB
/
utils.py
File metadata and controls
56 lines (43 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import re
import os
import numpy as np
import matplotlib.pyplot as plt
def _find_log_name(file_names, tag, log_dir):
    """Return the first entry of ``file_names`` whose name contains ``tag``."""
    for file_name in file_names:
        if tag in file_name:
            return file_name
    # IndexError is what the former ``matches[0]`` lookup raised; the type is
    # kept for backward compatibility, but the message now names the culprit.
    raise IndexError(
        "no file containing {!r} found in {!r}".format(tag, log_dir)
    )


def _load_curve(log_dir, file_name):
    """Load one logged curve from ``log_dir`` as an array with at least 1 dimension."""
    # ndmin=1 keeps a log holding a single value indexable; without it
    # np.loadtxt returns a 0-d array and ``shape[0]`` / ``[epoch]`` fail.
    return np.loadtxt(os.path.join(log_dir, file_name), ndmin=1)


def _best_epoch_from_name(file_name):
    """Parse the epoch number out of the last '_'-separated token of ``file_name``.

    Only the part of that token before its first '.' is searched, and the
    first run of digits found there is returned as an int.
    """
    last_token = file_name.split('_')[-1].split('.')[0]
    return int(re.search(r'\d+', last_token).group())


def _plot_curves(epochs, train_values, val_values, best_epoch, quantity,
                 text_offset, out_path):
    """Plot train/validation curves of one quantity, mark the best epoch, save to ``out_path``."""
    best_value = val_values[best_epoch]

    plt.figure()
    plt.plot(epochs, train_values, label="Training " + quantity.lower(), c='b')
    plt.plot(epochs, val_values, label="Validation " + quantity.lower(), c='r')
    plt.plot(best_epoch, best_value, label="Best epoch", c='y', marker='.', markersize=10)
    # Small offset so the annotation does not sit on top of the marker.
    plt.text(best_epoch + text_offset, best_value + text_offset,
             str(best_epoch) + ' - ' + str(round(best_value, 3)), fontsize=8)
    plt.xlabel('Epochs')
    plt.ylabel(quantity)
    plt.title(quantity + ' along epochs')
    plt.legend()
    plt.draw()
    plt.savefig(out_path)


def plot_loss_acc(log_dir):
    """Plot the loss and accuracy curves logged in ``log_dir``.

    The directory is searched for files whose names contain
    ``train_accuracies``, ``val_accuracies``, ``train_losses`` and
    ``val_losses``; each is read with ``np.loadtxt``. The best epoch is
    parsed from the name of the validation-accuracy file (the first run of
    digits in its last '_'-separated token) and is highlighted on both plots.

    Two figures are written into ``log_dir`` as ``loss.png`` and
    ``accuracy.png``, and then ``plt.show()`` is called.

    Args:
        log_dir: Directory holding the four log files; also the output
            directory for the generated images.

    Raises:
        IndexError: If one of the four log files is missing, or if the best
            epoch parsed from the file name is beyond the end of a curve.
        AttributeError: If the validation-accuracy file name carries no
            digits in its last '_'-separated token.
    """
    # Sorted so the file picked is deterministic when several names match
    # (os.listdir returns entries in arbitrary order).
    file_names = sorted(os.listdir(log_dir))

    train_acc_name = _find_log_name(file_names, 'train_accuracies', log_dir)
    val_acc_name = _find_log_name(file_names, 'val_accuracies', log_dir)
    train_loss_name = _find_log_name(file_names, 'train_losses', log_dir)
    val_loss_name = _find_log_name(file_names, 'val_losses', log_dir)

    train_accs = _load_curve(log_dir, train_acc_name)
    validation_accs = _load_curve(log_dir, val_acc_name)
    train_losses = _load_curve(log_dir, train_loss_name)
    validation_losses = _load_curve(log_dir, val_loss_name)

    best_epoch = _best_epoch_from_name(val_acc_name)
    # The x axis is sized from the training-loss curve for both figures.
    epochs = np.arange(train_losses.shape[0])

    _plot_curves(epochs, train_losses, validation_losses, best_epoch,
                 'Loss', .01, os.path.join(log_dir, 'loss.png'))
    _plot_curves(epochs, train_accs, validation_accs, best_epoch,
                 'Accuracy', .001, os.path.join(log_dir, 'accuracy.png'))

    plt.show()