hosa.models.rnn.rnn_models.RNNClassification.fit

RNNClassification.fit(x, y, validation_size=0.33, shuffle=True, rtol=0.001, atol=0.0001, class_weights=None, imbalance_correction=False, **kwargs)

Fits the model to data matrix x and target(s) y.

Parameters
  • x (numpy.ndarray) – Input data.

  • y (numpy.ndarray) – Target values (i.e., class labels).

  • class_weights (None or dict) – Dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only).

  • validation_size (float or int) – Proportion of the training dataset to include in the validation split.

  • shuffle (bool) – Whether to shuffle the data before splitting.

  • atol (float) – Absolute tolerance used for early stopping based on the performance metric.

  • rtol (float) – Relative tolerance used for early stopping based on the performance metric.

  • imbalance_correction (bool) – True if correction for imbalance should be applied to the metrics; False otherwise.

  • **kwargs – Extra arguments passed to the TensorFlow model's fit function. See the TensorFlow documentation for tf.keras.Model.fit.
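
The class_weights parameter expects a dictionary from integer class indices to float weights. As a minimal sketch of one common way to build such a dictionary, the snippet below weights each class inversely to its frequency. The helper name inverse_frequency_class_weights and the weighting rule are assumptions made for illustration; they are not part of hosa, and the library may expect or compute something different.

```python
import numpy as np


def inverse_frequency_class_weights(y):
    """Hypothetical helper: map each class index to a weight that is
    inversely proportional to how often the class appears in y."""
    y = np.asarray(y).ravel()
    classes, counts = np.unique(y, return_counts=True)
    # "Balanced" heuristic: n_samples / (n_classes * samples_in_class).
    weights = y.size / (classes.size * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}


# Toy labels: class 0 appears three times, class 1 once.
y = np.array([0, 0, 0, 1])
class_weights = inverse_frequency_class_weights(y)
```

A dictionary built this way could then be supplied as class_weights=class_weights when calling fit, so that errors on the rarer class contribute more to the training loss.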

Returns

tensorflow.keras.Sequential – Returns a trained TensorFlow model.
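
The rtol and atol parameters listed above control early stopping, but this page does not state exactly how they are combined. The sketch below shows one plausible interpretation, modelled on the threshold used by numpy.isclose (atol + rtol * |reference|): a new metric value counts as an improvement only if it exceeds the best value so far by more than that threshold. The function name has_improved, the assumption that a higher metric is better, and the combination rule itself are all assumptions, not the library's confirmed behaviour.

```python
def has_improved(current, best, rtol=0.001, atol=0.0001):
    """Hypothetical check: True if `current` beats `best` by more than
    the combined tolerance atol + rtol * |best| (higher is better)."""
    threshold = atol + rtol * abs(best)
    return (current - best) > threshold


# With best = 0.90 the threshold is 0.0001 + 0.001 * 0.90 = 0.001.
small_gain = has_improved(0.9005, 0.90)  # gain of 0.0005, below threshold
large_gain = has_improved(0.92, 0.90)    # gain of 0.02, above threshold
```

Under this reading, larger tolerances make training stop sooner, because small gains in the validation metric no longer count as progress.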