The prediction code is straightforward; it is implemented in predict_simple_sequence.py.
First, construct the model and load the trained model parameters:
# Model Setup
model = archs[args.arch](n_vocab=N_VOCABULARY, n_units=args.unit)
if args.gpu >= 0:
    chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
    model.to_gpu()  # Copy the model to the GPU
xp = np if args.gpu < 0 else cuda.cupy
serializers.load_npz(args.modelpath, model)
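Then we specify only the first index (which corresponds to a word id), primeindex, and let the model generate the next index. Each newly generated index is fed back in as the input for the following step, so the whole sequence can be generated repeatedly from that single starting index.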
# Predict
predicted_sequence = [prev_index]
for i in range(args.length):
    prev = chainer.Variable(xp.array([prev_index], dtype=xp.int32))
    current = model(prev)
    current_index = np.argmax(cuda.to_cpu(current.data))
    predicted_sequence.append(current_index)
    prev_index = current_index
print('Predicted sequence: ', predicted_sequence)
The result is shown below; the model successfully generates the sequence.
Predicted sequence: [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5]
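Reading a long list by eye is error-prone, so a small check can help. The sketch below assumes the pattern visible in the printed output above, namely that each number n from 1 to 9 appears n times and the cycle then repeats. The helper names `expected_sequence` and `matches_pattern` are hypothetical and are not part of predict_simple_sequence.py.

```python
def expected_sequence(length, max_value=9):
    """Build the cyclic pattern 1, 2, 2, 3, 3, 3, ... where n appears n times.

    Assumption: this pattern is inferred from the printed output, not taken
    from the training script.
    """
    cycle = [n for n in range(1, max_value + 1) for _ in range(n)]
    return [cycle[i % len(cycle)] for i in range(length)]


def matches_pattern(predicted):
    """Return True if `predicted` follows the pattern starting from its first element."""
    return list(predicted) == expected_sequence(len(predicted))


# Example with the first few predicted indices
predicted = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5]
print(matches_pattern(predicted))  # True
```

Passing the full `predicted_sequence` from the script to `matches_pattern` would perform the same comparison over the whole output.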
This is a simple example to check whether an RNN is able to remember a past sequence, so I didn't prepare validation data. I just wanted to check whether the RNN model can "memorize" the training data sequence.
Note that the situation is a little different between the training phase and the inference/prediction phase. In the training phase, the model is trained to generate \(x_t\) based on the correct sequence \([x_0, x_1, \dots, x_{t-1}]\).
In the prediction phase, however, we specify only the first index \(x_0\), and the model generates \(x'_1\) (here the prime ' indicates an output from the model). After that, the model needs to generate \(x'_2\) based on \(x'_1\). Therefore, the model generates \(x'_t\) based on the predicted sequence \([x_0, x'_1, \dots, x'_{t-1}]\).
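The practical consequence of this difference is that a mistake made during prediction is fed back into the model, while a mistake made during training is not. Here is a minimal sketch of that effect. It assumes a hypothetical `toy_model` with one deliberate error, standing in for the trained network. Unlike the RNN, this toy is stateless and looks only at the previous index, so it illustrates only the feedback behaviour, not the model itself.

```python
def toy_model(prev_index):
    """Hypothetical stand-in for the trained model: next index in the cycle 1..9."""
    if prev_index == 3:
        return 5  # deliberate mistake: the correct next index would be 4
    return prev_index % 9 + 1


def teacher_forced(targets):
    """Training-style: predict x_t from the correct x_{t-1} at every step."""
    return [toy_model(prev) for prev in targets[:-1]]


def free_running(first_index, length):
    """Prediction-style: predict x'_t from the model's own previous output."""
    sequence = [first_index]
    for _ in range(length):
        sequence.append(toy_model(sequence[-1]))
    return sequence


targets = [1, 2, 3, 4, 5, 6]
print(teacher_forced(targets))  # [2, 3, 5, 5, 6]    -> one wrong step, then back on track
print(free_running(1, 5))       # [1, 2, 3, 5, 6, 7] -> every step after the error is shifted
```

With the correct inputs supplied at each step, the single mistake stays isolated. In the free-running case the wrong index becomes the next input, so every later output drifts away from the target sequence.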