hosa.models.rnn.rnn_models.RNNClassification

class hosa.models.rnn.rnn_models.RNNClassification(n_outputs, n_neurons_dense_layer, n_units, n_subs_layers, is_bidirectional=False, model_type='lstm', optimizer='adam', dropout_percentage=0.1, metrics=None, activation_function_dense='relu', kernel_initializer='normal', batch_size=1000, epochs=50, patience=5, **kwargs)

Bases: hosa.models.rnn.rnn_models.BaseRNN

Recurrent Neural Network (RNN) model classifier.

The model comprises an input layer (an RNN or a bidirectional RNN cell), n_subs_layers subsequent layers (similar to the input cell), a dropout layer, a dense layer, and an output layer.
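The layer order described above can be illustrated with a small, library-free sketch. The helper layer_plan and its labels are hypothetical and are not part of hosa; the sketch only lists the stack in the order the description gives it and makes no claim about how hosa builds the layers internally.

```python
def layer_plan(n_subs_layers, is_bidirectional=False, model_type="lstm"):
    """Return the layer order from the description as a list of labels.

    Hypothetical helper for illustration only; it is not hosa's code.
    """
    if model_type not in ("lstm", "gru"):
        raise ValueError("model_type must be 'lstm' or 'gru'")
    cell = model_type.upper()
    recurrent = f"Bidirectional({cell})" if is_bidirectional else cell
    # Input layer, followed by n_subs_layers layers of the same kind.
    plan = [recurrent] * (1 + n_subs_layers)
    # A dropout layer, the penultimate dense layer, and the output layer.
    plan += ["Dropout", "Dense", "Output"]
    return plan


print(layer_plan(2, is_bidirectional=True))
```

Under this reading, a model always has n_subs_layers + 4 layers: one input layer, the subsequent recurrent layers, and the three closing layers.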

Parameters
  • n_outputs (int) – Number of class labels to predict.

  • n_neurons_dense_layer (int) – Number of neurons in the penultimate dense layer (i.e., the layer before the output layer).

  • n_units (int) – Dimensionality of the output space, i.e., the dimensionality of the hidden state.

  • n_subs_layers (int) – Number of subsequent layers between the input and output layers.

  • is_bidirectional (bool) – If true, then bidirectional layers will be used to build the RNN model.

  • model_type (str) – Type of RNN model to be used. Available options are lstm, for a Long Short-Term Memory model, or gru, for a Gated Recurrent Unit model.

  • optimizer (str) – Name of the optimizer. See tensorflow.keras.optimizers.

  • dropout_percentage (float) – Fraction of the input units to drop, given as a value between 0 and 1.

  • metrics (list) – List of metrics to be evaluated by the model during training and testing. Each item of the list can be a string (the name of a TensorFlow built-in function), a function, or a tf.keras.metrics.Metric instance. If None, metrics will default to ['accuracy'].

  • activation_function_dense (str) – Activation function to use on the penultimate dense layer. The default is 'relu'. If None is passed, no activation is applied (i.e., the linear activation function is used). See tensorflow.keras.activations.

  • kernel_initializer (str) – Initializer for the kernel weights matrix, used for the linear transformation of the inputs.

  • batch_size (int or None) – Number of samples per batch of computation. If None, batch_size will default to 32.

  • epochs (int) – Maximum number of epochs to train the model.

  • patience (int) – Number of epochs with no improvement after which training will be stopped.

  • **kwargs – Ignored. Extra arguments that are accepted for compatibility's sake.
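How batch_size, epochs, and patience bound the amount of training can be sketched in plain Python. The helpers steps_per_epoch and epochs_run are hypothetical, and the monitored quantity (a per-epoch validation loss, lower being better) is an assumption; this page does not state which quantity hosa tracks or how it compares values.

```python
import math


def steps_per_epoch(n_samples, batch_size=1000):
    """Number of batches needed to cover n_samples once."""
    if batch_size is None:
        batch_size = 32  # documented fallback when batch_size is None
    return math.ceil(n_samples / batch_size)


def epochs_run(val_losses, epochs=50, patience=5):
    """Number of epochs executed before training stops.

    Assumes early stopping on a loss that should decrease; this is an
    illustration, not hosa's implementation.
    """
    best = float("inf")
    stale = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses[:epochs], start=1):
        if loss < best:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return epoch  # stopped early
    return min(len(val_losses), epochs)  # ran out of epochs


losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.73, 0.74, 0.6]
print(steps_per_epoch(25000), epochs_run(losses))
```

In this sketch the best loss is reached at epoch 3, the next five epochs bring no improvement, and training stops at epoch 8 without ever seeing the lower value at epoch 9. A larger patience trades longer training for a lower risk of stopping on a temporary plateau.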

Examples

from keras.datasets import imdb
from keras_preprocessing.sequence import pad_sequences

from hosa.models.rnn import RNNClassification
from hosa.aux import create_overlapping

# 1 - Load and split the data
max_sequence_length = 50
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=50)
# 2 - Prepare the data for RNN input
x_train = pad_sequences(x_train, maxlen=max_sequence_length, value=0.0)
x_test = pad_sequences(x_test, maxlen=max_sequence_length, value=0.0)
x_train, y_train = create_overlapping(x_train, y_train, RNNClassification,
                                      'central', 3, stride=1, timesteps=2)
x_test, y_test = create_overlapping(x_test, y_test, RNNClassification,
                                    'central', 3, stride=1, timesteps=2)
# 3 - Create and fit the model
# n_units and n_subs_layers have no defaults in the signature above, so they
# are passed explicitly; the values used here are illustrative, not tuned.
clf = RNNClassification(n_outputs=2, n_neurons_dense_layer=10, n_units=10,
                        n_subs_layers=1, is_bidirectional=True)
clf.prepare(x_train, y_train)
clf.compile()
clf.fit(x_train, y_train)
# 4 - Calculate predictions
predictions = clf.predict(x_test)
# 5 - Compute the score
score = clf.score(x_test, y_test)

Methods

aux_fit(x, y, callback, validation_size[, ...])

Auxiliary function for compatibility between classification and regression models.

compile()

Compiles the model for training.

fit(x, y[, validation_size, shuffle, rtol, ...])

Fits the model to data matrix x and target(s) y.

predict(x, **kwargs)

Predicts the target values using the input data in the trained model.

prepare(x, y)

Prepares the model by adding the layers to the estimator: input layer, n_subs_layers subsequent layers, a dropout layer, a dense layer, and an output layer.

score(x, y[, imbalance_correction])

Computes the performance metrics on the given input data and target values.