Hands-On Neural Networks with Keras

Choosing a metric to monitor

The ideal choice is always validation loss or validation accuracy, as these metrics best represent the out-of-sample predictive power of our model. This is because we only update our model's weights during a training pass, never during a validation pass. Choosing training accuracy or loss as a metric (as in the following code) is suboptimal because you are benchmarking the model against its own definition of success. Put differently, your model might keep reducing its loss and increasing its accuracy, but it may be doing so by rote memorization, not by learning the general predictive rules we want it to learn. As the following code shows, when we monitor training loss, our model continues to decrease its loss on the training set even though the loss on the validation set starts increasing shortly after the very first epoch:

import matplotlib.pyplot as plt

# Extract the metrics recorded during training
acc = history_dict['acc']
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']

epochs = range(1, len(acc) + 1)
plt.plot(epochs, loss_values, 'r', label='Training loss')
plt.plot(epochs, val_loss_values, 'rD', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

The preceding code generates the following output:
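Beyond eyeballing the plot, you can locate the epoch where validation loss bottoms out directly from the history dictionary. The values below are hypothetical, standing in for a real `history_dict` returned by a training run:

```python
# Hypothetical loss values, standing in for a real training run
history_dict = {
    'loss':     [0.50, 0.35, 0.25, 0.18, 0.12],  # training loss keeps falling
    'val_loss': [0.40, 0.33, 0.36, 0.41, 0.47],  # validation loss bottoms out early
}

# 1-based index of the epoch with the lowest validation loss
best_epoch = min(range(len(history_dict['val_loss'])),
                 key=lambda i: history_dict['val_loss'][i]) + 1
```

Here `best_epoch` is 2: training loss continues to fall after that point, but only through memorization of the training set.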

We used Matplotlib to plot the preceding graph. Similarly, you can clear the previous loss figure and plot a new accuracy graph for our training session, as shown in the following code. If we had used validation accuracy as the metric tracked by our early stopping callback, our training session would have ended after the first epoch, as this is the point at which our model appears most generalizable to unseen data:

plt.clf()   # clear the previous figure
acc_values = history_dict['acc']
val_acc_values = history_dict['val_acc']
plt.plot(epochs, acc_values, 'g', label='Training acc')
plt.plot(epochs, val_acc_values, 'gD', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

The preceding code generates the following output:
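The early stopping behavior described above can be sketched in plain Python. This is a simplified stand-in for Keras's `EarlyStopping` callback, not its actual implementation; the function name and the exact `patience` semantics here are assumptions for illustration:

```python
def early_stop_epoch(metric_values, mode='max', patience=0):
    """Return the 1-based epoch at which training would halt,
    given one metric value per epoch.

    mode='max' treats higher values as better (e.g. accuracy);
    mode='min' treats lower values as better (e.g. loss).
    """
    best = metric_values[0]
    wait = 0  # consecutive epochs without improvement
    for epoch, value in enumerate(metric_values[1:], start=2):
        improved = value > best if mode == 'max' else value < best
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait > patience:
                return epoch  # stop here
    return len(metric_values)  # ran to completion

# Validation accuracy peaks at the first epoch, then degrades
val_acc = [0.87, 0.86, 0.85, 0.84]
stop_at = early_stop_epoch(val_acc, mode='max', patience=0)  # stops at epoch 2
```

With `patience=0`, training halts on the first epoch that fails to improve on the best value seen so far, which is why monitoring validation accuracy would have ended the run above almost immediately.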