hosa.models.rnn.rnn_models.RNNRegression.aux_fit
- RNNRegression.aux_fit(x, y, callback, validation_size, rtol=0.001, atol=0.0001, class_weights=None, imbalance_correction=None, shuffle=True, **kwargs)
Auxiliary function that provides compatibility between the classification and regression models.
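As a rough illustration of how one shared fitting routine can serve both model types, here is a sketch built on hypothetical classes (BaseRNN, SketchRegression, SketchClassification). They are not hosa's classes. The real aux_fit trains a TensorFlow model, whereas this stand-in only records the arguments it receives. Each child's fit fills in the options that matter for its task and forwards everything else.

```python
class BaseRNN:
    """Hypothetical stand-in for a shared base class; not the hosa implementation."""

    def aux_fit(self, x, y, callback, validation_size, rtol=0.001, atol=0.0001,
                class_weights=None, imbalance_correction=None, shuffle=True, **kwargs):
        # Record what the shared routine received so the delegation can be inspected.
        self.last_call = dict(
            n_samples=len(x), callback=callback, validation_size=validation_size,
            rtol=rtol, atol=atol, class_weights=class_weights,
            imbalance_correction=imbalance_correction, shuffle=shuffle, extra=kwargs,
        )
        return self


class SketchRegression(BaseRNN):
    def fit(self, x, y, validation_size=0.1, shuffle=True, **kwargs):
        # Regression has no class-related options, so they are forwarded as None.
        # callback=None is a placeholder for the early-stopping callback.
        return self.aux_fit(x, y, callback=None, validation_size=validation_size,
                            class_weights=None, imbalance_correction=None,
                            shuffle=shuffle, **kwargs)


class SketchClassification(BaseRNN):
    def fit(self, x, y, validation_size=0.1, class_weights=None,
            imbalance_correction=False, shuffle=True, **kwargs):
        # Classification forwards its class-related options unchanged.
        return self.aux_fit(x, y, callback=None, validation_size=validation_size,
                            class_weights=class_weights,
                            imbalance_correction=imbalance_correction,
                            shuffle=shuffle, **kwargs)
```

In this pattern, users only ever call fit on a concrete model. That matches the recommendation not to call aux_fit directly.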
Warning
This function is not meant to be called directly; it is an auxiliary function called by the child classes’ fit function.
Parameters
x (numpy.ndarray) – Input data.
y (numpy.ndarray) – Target values (class labels in classification, real numbers in regression).
callback (object) – Early stopping callback for halting the model’s training.
validation_size (float or int) – Proportion of the training dataset that will be used as the validation split.
atol (float) – Absolute tolerance used for early stopping based on the performance metric.
rtol (float) – Relative tolerance used for early stopping based on the performance metric.
class_weights (None or dict) – Dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). Only used for classification problems. Ignored for regression.
imbalance_correction (None or bool) – Whether to apply correction to class imbalances. Only used for classification problems. Ignored for regression.
shuffle (bool) – Whether to shuffle the data before splitting.
**kwargs – Extra arguments passed to the TensorFlow model’s fit function. See the TensorFlow documentation for tf.keras.Model.fit.
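To make validation_size and shuffle concrete, the sketch below shows one way to hold out a validation split. The function name split_validation is hypothetical. The int interpretation (an absolute sample count) is also an assumption: it borrows scikit-learn's test_size convention, since the parameter description above only mentions a proportion. Following scikit-learn and Keras, the validation samples are taken from the end of the (possibly shuffled) index order.

```python
import numpy as np


def split_validation(x, y, validation_size, shuffle=True, seed=None):
    """Hold out part of (x, y) for validation (hypothetical helper).

    Assumption: a float is a proportion of the samples and an int an absolute
    count, mirroring scikit-learn's test_size convention.
    """
    n = len(x)
    if isinstance(validation_size, float):
        n_val = int(round(n * validation_size))
    else:
        n_val = int(validation_size)
    if not 0 < n_val < n:
        raise ValueError("validation split must leave samples on both sides")
    idx = np.arange(n)
    if shuffle:
        # Shuffle indices before splitting so the split does not depend on data order.
        np.random.default_rng(seed).shuffle(idx)
    # The last n_val indices form the validation split.
    train_idx, val_idx = idx[: n - n_val], idx[n - n_val:]
    return x[train_idx], x[val_idx], y[train_idx], y[val_idx]
```

With shuffle=False, time-ordered data keeps its most recent samples for validation. Shuffling first instead mixes samples from the whole range into both splits.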
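The atol and rtol parameters suggest an early-stopping rule that ignores changes too small to matter. The sketch below makes one such rule concrete, under stated assumptions. The class ToleranceEarlyStopping and its patience and greater_is_better arguments are hypothetical and are not taken from hosa. A new score counts as an improvement only if it beats the best score so far by more than atol + rtol * |best|, the same tolerance form that numpy.isclose uses.

```python
class ToleranceEarlyStopping:
    """Hypothetical early-stopping rule combining absolute and relative tolerance.

    A new score only counts as an improvement if it beats the best score by
    more than atol + rtol * |best|.
    """

    def __init__(self, patience=5, rtol=0.001, atol=0.0001, greater_is_better=True):
        self.patience = patience
        self.rtol = rtol
        self.atol = atol
        # Flip the sign so that "larger signed difference" always means "better".
        self.sign = 1.0 if greater_is_better else -1.0
        self.best = None
        self.wait = 0

    def update(self, score):
        """Record one epoch's score; return True when training should stop."""
        if self.best is None:
            self.best = score
            return False
        threshold = self.atol + self.rtol * abs(self.best)
        if self.sign * (score - self.best) > threshold:
            self.best = score
            self.wait = 0
        else:
            # Improvements smaller than the tolerance count as stagnation.
            self.wait += 1
        return self.wait >= self.patience
```

The relative term scales the required improvement with the metric's magnitude. The absolute term keeps the threshold above zero when the best score is near zero.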
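To illustrate the dictionary format that class_weights expects, the sketch below computes "balanced" weights with the formula n_samples / (n_classes * count_of_class), the heuristic scikit-learn uses for class_weight="balanced". Whether imbalance_correction applies this formula or another is an assumption, and balanced_class_weights is a hypothetical helper name.

```python
import numpy as np


def balanced_class_weights(y):
    """Map each class index to n_samples / (n_classes * count_of_class).

    This is scikit-learn's "balanced" heuristic; whether hosa's
    imbalance_correction uses the same formula is an assumption.
    """
    classes, counts = np.unique(np.asarray(y), return_counts=True)
    weights = len(y) / (len(classes) * counts)
    # Keys are plain ints and values plain floats, matching the
    # class-index -> weight dictionary described for class_weights.
    return {int(c): float(w) for c, w in zip(classes, weights)}
```

With these weights, every class contributes the same total weight to the loss: count_of_class * weight equals n_samples / n_classes for each class.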