Source code for hosa.callbacks.early_stopping

"""
Utilities for implementing early stopping callbacks.
"""
import numpy as np
import tensorflow as tf


class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):
    """
    This class implements early stopping for avoiding overfitting the model. The training is
    stopped when the monitored metric has stopped improving.

    The monitored metric is the first value returned by ``class_model.score`` on the
    validation data. For regression models a lower value counts as an improvement; for
    classification models a higher value does. A change that lies within the tolerances given
    by ``rtol`` and ``atol`` (see `numpy.isclose`) is not counted as an improvement.

    Args:
        class_model: Model object to be optimized. Available options are:
            :class:`.RNNClassification`, :class:`.RNNRegression`, :class:`.CNNClassification`
            and :class:`.CNNRegression`. It must expose the underlying Keras model through its
            ``model`` attribute and provide a ``score`` method.
        patience (int): Number of epochs with no improvement after which training will be
            stopped.
        validation_data (tuple): Pair ``(x_validation, y_validation)`` extracted from the
            validation dataset (which was itself extracted from the training dataset).
        imbalance_correction (bool): `True` if correction for imbalance should be applied to
            the metrics; `False` otherwise.
        rtol (float): The relative tolerance parameter, as used in `numpy.isclose`. See
            `numpy.isclose <https://numpy.org/doc/stable/reference/generated/numpy.isclose.html>`_.
        atol (float): The absolute tolerance parameter, as used in `numpy.isclose`. See
            `numpy.isclose <https://numpy.org/doc/stable/reference/generated/numpy.isclose.html>`_.

    Attributes:
        best_weights: Weights of the model at the epoch with the best monitored value, or
            `None` if no improvement has been recorded in the current training run.
        stopped_epoch (int): Index of the epoch at which training was stopped; only
            meaningful when ``early_stopping`` is `True`.
        early_stopping (bool): `True` if the most recent training run was stopped early.
    """

    def __init__(self, class_model, patience, validation_data, imbalance_correction=False,
                 rtol=1e-03, atol=1e-04):
        super().__init__()
        self.class_model = class_model
        # ``set_model`` works on tf.keras 2.x (where ``model`` is a plain attribute) as well as
        # on Keras 3 (where ``model`` is a read-only property and direct assignment fails).
        self.set_model(self.class_model.model)
        self.patience = patience
        self.imbalance_correction = imbalance_correction
        self.x_validation, self.y_validation = validation_data
        # Training state; properly initialized in ``on_train_begin``.
        self.best_weights = None
        self.wait = None
        self.stopped_epoch = None
        self.best_metric_value = None
        self.compare_function = None
        self.rtol, self.atol = rtol, atol
        self.early_stopping = False
        # Set when patience is exhausted; unlike ``stopped_epoch > 0`` it also records a stop
        # that happens at epoch 0.
        self._stop_triggered = False

    def on_train_begin(self, logs=None):
        """
        Called at the beginning of training to (re)initialize the variables for early
        stopping.

        Resets the patience counter, the stored best weights and the early-stopping flag, so
        that a callback reused across several ``fit`` calls does not carry over state, and
        selects the comparison direction from the type of ``class_model``.

        Args:
            logs (dict): Unused by this callback.

        Raises:
            ValueError: If the class name of ``class_model`` contains neither
                ``'Regression'`` nor ``'Classification'``.
        """
        self.wait = 0
        self.stopped_epoch = 0
        self.best_weights = None
        self.early_stopping = False
        self._stop_triggered = False
        # Match on the class name only, so that the module path cannot cause a false match.
        model_name = type(self.class_model).__name__
        if 'Regression' in model_name:
            # Regression scores are errors: lower is better.
            self.best_metric_value = np.inf
            self.compare_function = np.less
        elif 'Classification' in model_name:
            # Classification scores are accuracies: higher is better.
            self.best_metric_value = -np.inf
            self.compare_function = np.greater
        else:
            raise ValueError(
                'The class of the model is invalid. Only regression and classification models '
                'are currently available.')

    def on_epoch_end(self, epoch, logs=None):
        """
        Checks, based on the patience value, if the training should stop.

        The first value returned by ``class_model.score`` on the validation data is compared
        with the best value seen so far. Once the number of consecutive epochs without
        improvement reaches ``patience``, training is stopped and, if an improvement was
        recorded during this run, the model's weights from the epoch with the best value of
        the monitored quantity are restored.

        Args:
            epoch (int): Index of epoch.
            logs (dict): Unused by this callback.
        """
        current_metric, *_ = self.class_model.score(
            self.x_validation, self.y_validation,
            imbalance_correction=self.imbalance_correction)
        # An improvement must go in the right direction AND be larger than the tolerance.
        improved = (self.compare_function(current_metric, self.best_metric_value)
                    and not np.isclose(current_metric, self.best_metric_value,
                                       rtol=self.rtol, atol=self.atol))
        if improved:
            self.best_metric_value = current_metric
            self.wait = 0
            self.best_weights = self.model.get_weights()
            return

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self._stop_triggered = True
            self.model.stop_training = True
            # No improvement may have been recorded (e.g. NaN metric from the first epoch);
            # in that case there is nothing to restore and set_weights(None) would fail.
            if self.best_weights is not None:
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        """
        This function is called when the training is finished, and it is used to set a flag
        for early stopping.

        Args:
            logs (dict): Unused by this callback.
        """
        if self._stop_triggered:
            self.early_stopping = True