hosa.models.cnn.cnn_models.CNNRegression.fit

CNNRegression.fit(x, y, validation_size=0.33, shuffle=True, atol=0.0001, rtol=0.001, **kwargs)

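Fits the model to the data matrix x and the target(s) y, holding out part of the data for validation and using the tolerances below for early stopping.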

Parameters
  • x (numpy.ndarray) – Input data matrix.

  • y (numpy.ndarray) – Target values (i.e., real numbers).

  • validation_size (float or int) – Amount of the training data to hold out for the validation split; a float is interpreted as a proportion of the samples.

  • shuffle (bool) – Whether to shuffle the data before splitting.

  • atol (float) – Absolute tolerance used for early stopping based on the performance metric.

  • rtol (float) – Relative tolerance used for early stopping based on the performance metric.

  • **kwargs – Extra keyword arguments passed on to TensorFlow's Model.fit method (see the Keras Model.fit documentation).
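The documentation above does not spell out how validation_size, shuffle, atol, and rtol are applied internally. The following is a minimal NumPy sketch under stated assumptions, not HOSA's implementation. split_train_validation, has_improved, and the seed argument are hypothetical. It assumes that a float validation_size is a proportion and an int is an absolute sample count (the scikit-learn train_test_split convention). It also assumes that the tolerances combine as in numpy.isclose, and that lower metric values are better.

```python
import numpy as np


def split_train_validation(x, y, validation_size=0.33, shuffle=True, seed=None):
    """Hypothetical helper: hold out part of (x, y) for validation.

    Assumption: a float validation_size is a proportion of the samples
    (rounded up), an int is an absolute number of validation samples.
    """
    n = len(x)
    if isinstance(validation_size, float):
        n_val = int(np.ceil(validation_size * n))
    else:
        n_val = int(validation_size)
    indices = np.arange(n)
    if shuffle:
        # Permute sample indices so x and y stay paired row by row.
        np.random.default_rng(seed).shuffle(indices)
    # The last n_val (possibly shuffled) samples become the validation set.
    train_idx, val_idx = indices[: n - n_val], indices[n - n_val:]
    return x[train_idx], x[val_idx], y[train_idx], y[val_idx]


def has_improved(previous, current, atol=0.0001, rtol=0.001):
    """Hypothetical check for a lower-is-better metric such as a loss.

    Assumption: the decrease must exceed atol + rtol * |previous|
    (numpy.isclose-style tolerances) to count as an improvement.
    """
    return previous - current > atol + rtol * abs(previous)
```

With the default tolerances, a drop from 1.0 to 0.9 counts as an improvement, while a drop to 0.9995 falls within atol + rtol * 1.0 = 0.0011 and would count toward stopping early.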

Returns

tensorflow.keras.Sequential – The trained TensorFlow model.