hosa.models.rnn.rnn_models.RNNClassification.predict

RNNClassification.predict(x, **kwargs)

Predicts the target values for the input data using the trained model.

Parameters
  • x (numpy.ndarray) – Input data.

  • **kwargs – Extra keyword arguments passed to TensorFlow's model predict function. See the TensorFlow documentation for tf.keras.Model.predict.

Returns

tuple – A tuple containing the probability estimates and the predicted classes.
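The behavior described above can be sketched as follows. This is not hosa's implementation. It is a minimal illustration, assuming that extra keyword arguments are forwarded unchanged to a TensorFlow-style predict call and that the result comes back as a (probabilities, classes) tuple. The names `predict_sketch` and `fake_predict`, the tuple order, and the rule for deriving classes (argmax for multiclass output, a 0.5 threshold for a single sigmoid output) are all hypothetical. A dummy predict function stands in for a trained model, so the sketch runs with NumPy alone.

```python
import numpy as np


def predict_sketch(model_predict, x, **kwargs):
    """Hypothetical stand-in for RNNClassification.predict.

    Assumptions (not confirmed by the hosa documentation):
    - model_predict behaves like tf.keras.Model.predict and returns class
      probabilities of shape (n_samples, n_classes), or (n_samples, 1)
      for a single sigmoid output unit.
    - Classes come from argmax (multiclass) or a 0.5 threshold (binary).
    """
    # Forward any extra keyword arguments (e.g. batch_size, verbose) unchanged.
    y_prob = np.asarray(model_predict(x, **kwargs))
    if y_prob.ndim == 2 and y_prob.shape[1] > 1:
        # Multiclass: pick the column with the highest probability per row.
        y_pred = np.argmax(y_prob, axis=1)
    else:
        # Binary with one sigmoid unit: threshold the probabilities at 0.5.
        y_pred = (y_prob.reshape(-1) > 0.5).astype(int)
    return y_prob, y_pred


# Hypothetical dummy model: ignores its input and returns fixed
# probabilities for 3 samples over 3 classes.
def fake_predict(x, batch_size=32, verbose=0):
    return np.array([[0.1, 0.7, 0.2],
                     [0.8, 0.1, 0.1],
                     [0.2, 0.3, 0.5]])


x = np.zeros((3, 10, 4))  # (samples, time steps, features), a typical RNN input
y_prob, y_pred = predict_sketch(fake_predict, x, batch_size=8)
print(y_pred)  # [1 0 2]
```

Keeping the raw probabilities alongside the hard class labels lets a caller compute probability-based metrics (for example, ROC AUC) without calling predict a second time.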