Write predict code using concat_examples


This tutorial corresponds to the 03_custom_dataset_mlp folder in the source code.


In the previous post we trained the model on our own dataset, MyDataset. Now let’s write the predict code.

Source code:


Prepare test data

Fitting the training data is not difficult for the model, so we will check how well the model fits the test data instead.

I used the same seed (=13) as in the training phase, so the train/test split is identical to the one used during training.
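To see why reusing the seed matters, here is a minimal NumPy sketch of a seeded random split. It assumes the split is a random permutation of the indices, cut at a fixed position. The helper `split_indices` and the sizes are hypothetical and are not the tutorial's own code.

```python
import numpy as np


def split_indices(n_examples, n_train, seed):
    """Return (train_idx, test_idx) from a seeded permutation (hypothetical helper)."""
    order = np.random.RandomState(seed).permutation(n_examples)
    return order[:n_train], order[n_train:]


# Same seed at training time and at prediction time -> same split.
train_a, test_a = split_indices(n_examples=10, n_train=7, seed=13)
train_b, test_b = split_indices(n_examples=10, n_train=7, seed=13)
same_split = np.array_equal(test_a, test_b)

# No example appears in both the train part and the test part.
overlap = np.intersect1d(train_a, test_a)
```

Chainer's chainer.datasets.split_dataset_random accepts a seed argument that plays the same role.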


Load trained model

The procedure to load the trained model is:

  1. Instantiate the model (a subclass of Chain; here, MyMLP).
  2. Send the parameters to the GPU if necessary.
  3. Load the trained parameters with the serializers.load_npz function.


Predict with minibatch

Prepare minibatch from dataset with concat_examples

We need to feed a minibatch, not the dataset itself, into the model. During training, the Iterator constructed the minibatches. In the predict phase, setting up an Iterator may be overkill, so how do we construct a minibatch?

There is a convenient function, concat_examples, for preparing a minibatch from a dataset. It works as shown in this figure.

  • chainer.dataset.concat_examples(batch, device=None, padding=None)



concat_examples converts a list of dataset examples into one minibatch per feature (here, x and y), each of which can be fed into the neural network.

Usually, when we access a dataset by slice indexing, for example dataset[i:j], it returns a list with one entry per example, in order. concat_examples separates the elements of each example and concatenates them feature by feature to generate the minibatch.
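The "separate, then concatenate" step can be written out by hand with NumPy. The sketch below only imitates the behaviour for a list of (x, y) tuples; it is not the library's implementation, and the toy dataset and the helper `concat_by_hand` are made up for illustration.

```python
import numpy as np

# Toy dataset: a list of (x, y) examples, as dataset[i:j] would return.
dataset = [(np.array([float(i)], dtype=np.float32),
            np.array([2.0 * i], dtype=np.float32)) for i in range(5)]
batch = dataset[1:4]


def concat_by_hand(batch):
    """Stack the k-th element of every example into one array per feature."""
    n_features = len(batch[0])
    return tuple(np.stack([example[k] for example in batch])
                 for k in range(n_features))


x, y = concat_by_hand(batch)
# x and y each have shape (3, 1): one row per example in the slice.
```

With Chainer installed, the corresponding call is x, y = chainer.dataset.concat_examples(batch), which can also move the arrays to a GPU through its device argument.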

You can use it as follows,


※ A more detailed usage example of concat_examples is in dataset_introduction.ipynb; see also the official documentation.


Predict code configuration

The predict phase differs from the training phase in two ways:

  1. Function behavior
    – Some functions are expected to behave differently in the training phase and in the validation/predict phase. For example, F.dropout should drop some units during training, but should drop none during validation/prediction.

    This kind of function behavior is controlled by the chainer.config.train configuration.

  2. Back propagation is not necessary
    When back propagation is enabled, the model has to build a computational graph, which requires additional memory. Back propagation is not needed in the validation/predict phase, so we can skip building the graph to reduce memory usage.
    This is controlled by chainer.config.enable_backprop; the chainer.no_backprop_mode() context manager is a convenient way to switch it off.

Taking the above into account, we can write the predict code in the MyMLP model as,


Finally, the predict code can be written as follows,


Plot the result

This is a regression task, so let’s look at the difference between the actual points and the model’s predicted points.

which outputs this figure,



Appendix: Refactoring predict code

Move the predict function into the model class: if you want to simplify the main predict code in predict_custom_dataset1.py, you can move the predict for loop into the model.


In the MyMLP class, define the predict2 method as

Then we can write the main predict code very simply,



The model prediction is written in one line of code,
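The shape of this refactoring can be sketched without any framework. Everything here is an assumption made for illustration: TinyModel is a hypothetical stand-in for MyMLP with a fixed linear rule instead of trained layers, the predict2 signature (dataset, batchsize) is a guess, and np.stack plays the role that concat_examples has in the real code.

```python
import numpy as np


class TinyModel:
    """Hypothetical stand-in for MyMLP: predicts y = 2x + 1."""

    def predict(self, x):
        # Forward pass on one minibatch.
        return 2.0 * x + 1.0

    def predict2(self, dataset, batchsize=32):
        # Loop over the dataset in minibatches and collect the outputs.
        outputs = []
        for i in range(0, len(dataset), batchsize):
            batch = dataset[i:i + batchsize]
            x = np.stack([example[0] for example in batch])
            outputs.append(self.predict(x))
        return np.concatenate(outputs)


# Toy dataset of (x, y) examples; the y values are unused during prediction.
dataset = [(np.array([float(i)], dtype=np.float32),
            np.array([0.0], dtype=np.float32)) for i in range(10)]

model = TinyModel()
y_pred = model.predict2(dataset, batchsize=4)   # the one-line prediction
```

Keeping the loop inside the model means the caller does not need to know the batch size or how minibatches are assembled.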

