hosa.models.rnn.rnn_models.RNNClassification.fit
- RNNClassification.fit(x, y, validation_size=0.33, shuffle=True, rtol=0.001, atol=0.0001, class_weights=None, imbalance_correction=False, **kwargs)
Fits the model to data matrix x and target(s) y.
- Parameters
x (numpy.ndarray) – Input data.
y (numpy.ndarray) – Target values (i.e., class labels).
validation_size (float or int) – Proportion of the training dataset to include in the validation split.
shuffle (bool) – Whether to shuffle the data before splitting.
atol (float) – Absolute tolerance used for early stopping based on the performance metric.
rtol (float) – Relative tolerance used for early stopping based on the performance metric.
class_weights (None or dict) – Dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only).
imbalance_correction (bool) – True if correction for imbalance should be applied to the metrics; False otherwise.
**kwargs – Extra arguments passed to the fit method of the underlying TensorFlow model (tensorflow.keras.Model.fit).
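The class_weights argument expects a dictionary mapping integer class indices to float weights. One common way to build such a dictionary for imbalanced data is the inverse-frequency ("balanced") heuristic n_samples / (n_classes * count), the same formula scikit-learn uses for class_weight="balanced". The helper below, balanced_class_weights, is a hypothetical sketch, not part of hosa, and the label vector is illustrative only.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Hypothetical helper (not part of hosa): build a {class_index: weight} dict.

    Each class gets weight n_samples / (n_classes * count), so rare classes
    are up-weighted and the weighted sample count still equals n_samples.
    """
    counts = Counter(int(label) for label in labels)
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}


# Illustrative, imbalanced labels: six samples of class 0, two of class 1.
labels = [0, 0, 0, 0, 0, 0, 1, 1]
weights = balanced_class_weights(labels)
print(weights)  # {0: 0.6666666666666666, 1: 2.0}
```

The resulting dictionary could then be passed as fit(x, y, class_weights=weights) on an already constructed RNNClassification instance; per the description above, it only affects the loss during training.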
- Returns
tensorflow.keras.Sequential – Returns a trained TensorFlow model.
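The atol and rtol parameters control tolerance-based early stopping on the performance metric, but this page does not state the exact rule. The sketch below assumes an isclose-style criterion: an epoch only counts as an improvement if it beats the best previous value by more than atol + rtol * |best|. The function should_stop and its patience and higher_is_better arguments are hypothetical illustrations, not hosa's documented logic.

```python
def should_stop(history, rtol=0.001, atol=0.0001, patience=1, higher_is_better=True):
    """Hypothetical early-stopping check (assumed rule, not hosa's implementation).

    history: per-epoch values of the monitored performance metric.
    An epoch improves on the best value only if the gain exceeds
    atol + rtol * |best|; training stops after `patience` consecutive
    epochs without such an improvement.
    """
    if len(history) <= patience:
        return False
    # Flip the sign for metrics where lower is better (e.g. a loss),
    # so that "larger score" always means "better".
    sign = 1.0 if higher_is_better else -1.0
    best = sign * history[0]
    epochs_without_improvement = 0
    for value in history[1:]:
        score = sign * value
        if score - best > atol + rtol * abs(best):
            best = score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
    return epochs_without_improvement >= patience


print(should_stop([0.70, 0.80, 0.85, 0.8502]))  # True: last gain is below tolerance
print(should_stop([0.70, 0.80, 0.85]))          # False: last epoch still improved
```

Combining an absolute and a relative term keeps the threshold meaningful both when the metric is near zero (atol dominates) and when it is large (rtol dominates).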