Chainer sklearn wrapper


If you worked with machine learning before deep learning became popular, you have probably used sklearn (scikit-learn), a very popular machine learning library in Python.

Its interface has been in use for a long time, so I thought it would be good to support the same interface for deep learning models and let users try deep learning more easily. That is why I wrote the Chainer sklearn wrapper.

Here, I will explain how to use it.

Construct the model

Conventional machine learning tasks can mainly be categorized into the following three:

  • Classification – Predict the class (label) of the input; sometimes the output is the probability of each class (label).
  • Regression – Predict the value of the target feature from the values of the input features.
  • Clustering – Given only inputs without labels, group the inputs whose features are similar in the input space.

I want to support classification models and regression models for deep learning.

Classifier model

You can use the SklearnWrapperClassifier class, which is constructed in almost the same way as the existing Classifier class in Chainer: just define your own predictor model and pass it to the classifier.
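To illustrate what passing a predictor to a classifier means, here is a minimal pure-Python sketch of the pattern that Chainer's Classifier follows. The predictor maps an input to per-class scores, and the classifier turns those scores into a softmax cross-entropy loss and an accuracy. `ToyClassifier` and `linear_predictor` are hypothetical names used only for illustration. The real classes operate on Chainer Variables and minibatch arrays, not on Python lists.

```python
import math


class ToyClassifier:
    """Hypothetical stand-in for the Classifier pattern: wraps a predictor
    and turns its per-class scores into a loss and an accuracy."""

    def __init__(self, predictor):
        self.predictor = predictor  # any callable: x -> list of class scores
        self.accuracy = None

    def __call__(self, xs, ts):
        total_loss = 0.0
        correct = 0
        for x, t in zip(xs, ts):
            scores = self.predictor(x)
            # Softmax cross-entropy: log-sum-exp of the scores minus the true-class score.
            m = max(scores)
            log_z = m + math.log(sum(math.exp(s - m) for s in scores))
            total_loss += log_z - scores[t]
            # A prediction is correct when the highest score belongs to the true class.
            if max(range(len(scores)), key=scores.__getitem__) == t:
                correct += 1
        self.accuracy = correct / len(xs)
        return total_loss / len(xs)  # mean loss over the batch


def linear_predictor(x):
    # Hypothetical two-class predictor: score -x[0] for class 0, +x[0] for class 1.
    return [-x[0], x[0]]


clf = ToyClassifier(linear_predictor)
loss = clf([[2.0], [-1.0], [0.5]], [1, 0, 0])
print(round(loss, 4), clf.accuracy)
```

Keeping the loss computation in the wrapper means the predictor stays a plain model that can be swapped freely, which is the same reason Chainer separates the predictor from the Classifier.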


Regression model

[WIP] This is not implemented yet.

Training the model with fit

Once the model is constructed, you can call the fit method in the same way as in sklearn.

Example 1. Iris data classification

Prepare the input data X and target data (label) y, and call fit.
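The wrapper's fit trains a Chainer model, and that code is not reproduced here. As a rough, hypothetical picture of the sklearn-style contract it follows (hyperparameters stored in the constructor, training done in fit(X, y), and fit returning the estimator itself), the sketch below trains a linear softmax model with minibatch SGD in plain Python. `ToySoftmaxClassifier` and its `lr`, `batchsize` and `epoch` arguments are assumptions for illustration. They are not the wrapper's API.

```python
import math
import random


class ToySoftmaxClassifier:
    """Hypothetical sklearn-style classifier (not the wrapper's code):
    a linear softmax model trained by minibatch SGD inside fit()."""

    def __init__(self, n_class, lr=0.1, batchsize=4, epoch=50, seed=0):
        # sklearn convention: the constructor only stores hyperparameters.
        self.n_class = n_class
        self.lr = lr
        self.batchsize = batchsize
        self.epoch = epoch
        self.seed = seed

    def _scores(self, x):
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.W, self.b)]

    def _proba(self, x):
        s = self._scores(x)
        m = max(s)
        e = [math.exp(v - m) for v in s]
        z = sum(e)
        return [v / z for v in e]

    def fit(self, X, y):
        rng = random.Random(self.seed)
        n_feat = len(X[0])
        self.W = [[0.0] * n_feat for _ in range(self.n_class)]
        self.b = [0.0] * self.n_class
        order = list(range(len(X)))
        for _ in range(self.epoch):
            rng.shuffle(order)  # a new minibatch split every epoch
            for start in range(0, len(order), self.batchsize):
                batch = order[start:start + self.batchsize]
                gW = [[0.0] * n_feat for _ in range(self.n_class)]
                gb = [0.0] * self.n_class
                for i in batch:
                    p = self._proba(X[i])
                    for k in range(self.n_class):
                        # Gradient of softmax cross-entropy w.r.t. the scores: p - onehot(y).
                        d = p[k] - (1.0 if k == y[i] else 0.0)
                        gb[k] += d
                        for j in range(n_feat):
                            gW[k][j] += d * X[i][j]
                for k in range(self.n_class):
                    self.b[k] -= self.lr * gb[k] / len(batch)
                    for j in range(n_feat):
                        self.W[k][j] -= self.lr * gW[k][j] / len(batch)
        return self  # sklearn convention: fit returns the estimator itself

    def predict(self, X):
        return [max(range(self.n_class), key=self._scores(x).__getitem__) for x in X]


X = [[0.0, 1.0], [0.2, 0.9], [1.0, 0.0], [0.9, 0.1], [0.1, 0.8], [0.8, 0.2]]
y = [0, 0, 1, 1, 0, 1]
model = ToySoftmaxClassifier(n_class=2, lr=0.5, batchsize=2, epoch=100).fit(X, y)
print(model.predict(X))
```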


See for whole training code.


Example 2. MNIST data classification

You can also call fit with Chainer's dataset classes.

For example, the model can be fit on Chainer's TupleDataset.
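A TupleDataset zips several equal-length datasets so that example i is the tuple of their i-th elements. The stand-in below mimics that behavior in plain Python and shows how such a dataset unzips into the X, y form that an sklearn-style fit takes. `ToyTupleDataset` and `to_arrays` are hypothetical. With Chainer installed you would use `chainer.datasets.TupleDataset` itself.

```python
class ToyTupleDataset:
    """Minimal stand-in mimicking chainer.datasets.TupleDataset:
    example i is the tuple (datasets[0][i], datasets[1][i], ...)."""

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError("no datasets are given")
        length = len(datasets[0])
        if any(len(d) != length for d in datasets):
            raise ValueError("all datasets must have the same length")
        self._datasets = datasets
        self._length = length

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        parts = [d[index] for d in self._datasets]
        if isinstance(index, slice):
            # A slice gives a list of tuples, one tuple per example.
            return list(zip(*parts))
        return tuple(parts)


def to_arrays(dataset):
    """Hypothetical helper: unzip a dataset of (x, y) pairs into X and y lists,
    the form an sklearn-style fit(X, y) expects."""
    X = [x for x, _ in dataset]
    y = [t for _, t in dataset]
    return X, y


features = [[5.1, 3.5], [4.9, 3.0], [6.2, 2.9]]
labels = [0, 0, 1]
ds = ToyTupleDataset(features, labels)
print(ds[2])
X, y = to_arrays(ds)
```

Iteration works because indexing past the end raises IndexError, which stops Python's fallback iteration protocol.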


See for whole training code.

Predict with trained model

You can use the predict method to get the classification result, and the predict_proba method to get the probability of each class.
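This sketch assumes that the predictor outputs unnormalized scores (logits), which is the usual setup for a Chainer Classifier. Under that assumption, predict_proba corresponds to a softmax over each row and predict corresponds to the argmax. The functions below are a hypothetical plain-Python sketch of that relationship, not the wrapper's implementation.

```python
import math


def predict_proba(scores_batch):
    """Softmax over each row of raw predictor scores (logits)."""
    probs = []
    for scores in scores_batch:
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


def predict(scores_batch):
    """Index of the highest score. Softmax is monotonic, so this equals
    the argmax of predict_proba."""
    return [max(range(len(s)), key=s.__getitem__) for s in scores_batch]


logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 3.0]]
print(predict(logits))
print([[round(p, 3) for p in row] for row in predict_proba(logits)])
```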


Set the retain_inputs option to True to also retrieve the model inputs. This is useful with Chainer datasets because preprocessing such as data augmentation may run every time an example is accessed by index (the get_example method of DatasetMixin). Accessing the same index twice is therefore not guaranteed to return the same input.
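The sketch below shows why the retained inputs matter. `NoisyDataset` (hypothetical) adds a random shift in `get_example`, so reading the same index twice gives different values, and only the inputs captured during prediction match the predictions. The `predict(model, dataset, retain_inputs=...)` function is an assumed stand-in for the wrapper's option, not its real signature.

```python
import random


class NoisyDataset:
    """Hypothetical dataset whose get_example applies random augmentation,
    in the spirit of a DatasetMixin subclass."""

    def __init__(self, values, seed=0):
        self.values = values
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.values)

    def get_example(self, i):
        # A fresh random shift on every call, so one index can yield different inputs.
        return self.values[i] + self.rng.uniform(-0.5, 0.5)

    def __getitem__(self, i):
        return self.get_example(i)


def predict(model, dataset, retain_inputs=False):
    """Sketch of a predict loop: read each example once and keep what was fed in."""
    inputs = [dataset[i] for i in range(len(dataset))]
    outputs = [model(x) for x in inputs]
    if retain_inputs:
        return outputs, inputs
    return outputs


def threshold_model(x):
    return int(x > 1.0)  # hypothetical model: class 1 when the input exceeds 1.0


ds = NoisyDataset([0.0, 1.0, 2.0])
preds, used = predict(threshold_model, ds, retain_inputs=True)
print(ds[1] == used[1])  # re-reading index 1 draws a new random shift
```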

You may also predict on only a slice of the data.




See for whole training code.

GridSearchCV, RandomizedSearchCV

It also supports sklearn's GridSearchCV and RandomizedSearchCV for automated hyperparameter search.
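GridSearchCV evaluates every combination in a parameter grid with k-fold cross-validation and keeps the best one. The plain-Python sketch below reproduces that loop. `grid_search`, `k_fold_indices` and `ThresholdClassifier` are hypothetical names. The real GridSearchCV also clones the estimator through get_params/set_params, and for classifiers with an integer cv it uses stratified folds. This is why a wrapper has to implement the sklearn estimator interface to work with it.

```python
import itertools


def k_fold_indices(n, k):
    """Split range(n) into k contiguous folds (no shuffling, for simplicity)."""
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def grid_search(make_estimator, param_grid, X, y, k=3):
    """Try every combination in param_grid and return the one with the best
    mean validation accuracy over k folds, mimicking GridSearchCV."""
    names = sorted(param_grid)
    best_score, best_params = float("-inf"), None
    for values in itertools.product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        scores = []
        for val_idx in k_fold_indices(len(X), k):
            held_out = set(val_idx)
            train_idx = [i for i in range(len(X)) if i not in held_out]
            est = make_estimator(**params)  # a fresh estimator per fold
            est.fit([X[i] for i in train_idx], [y[i] for i in train_idx])
            preds = est.predict([X[i] for i in val_idx])
            scores.append(sum(p == y[i] for p, i in zip(preds, val_idx)) / len(val_idx))
        mean = sum(scores) / len(scores)
        if mean > best_score:
            best_score, best_params = mean, params
    return best_params, best_score


class ThresholdClassifier:
    """Hypothetical one-parameter estimator: predicts 1 when x[0] > threshold."""

    def __init__(self, threshold=0.0):
        self.threshold = threshold

    def fit(self, X, y):
        return self  # nothing to learn; the threshold is the hyperparameter

    def predict(self, X):
        return [int(x[0] > self.threshold) for x in X]


X = [[0.1], [0.4], [0.35], [0.8], [0.9], [0.6]]
y = [0, 0, 0, 1, 1, 1]
best, score = grid_search(ThresholdClassifier, {"threshold": [0.2, 0.5, 0.7]}, X, y)
print(best, score)
```

RandomizedSearchCV follows the same outer loop but samples a fixed number of combinations instead of enumerating the full product.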

One example is as follows,


You can also search hyperparameters that are passed to the predictor's constructor or to the optimizer's constructor.
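In sklearn, parameters of nested objects are addressed as `<component>__<parameter>`, for example `predictor__n_units`, and set_params routes each value to the matching component. I am assuming, not confirming, that the wrapper exposes predictor and optimizer constructor arguments through a similar convention. The classes below (`Params`, `MLPConfig`, `SGDConfig`, `WrapperConfig`) are hypothetical and only demonstrate the routing.

```python
class Params:
    """Hypothetical minimal base class following sklearn's get_params/set_params
    convention, including nested '<component>__<param>' keys."""

    _param_names = ()

    def get_params(self, deep=True):
        out = {name: getattr(self, name) for name in self._param_names}
        if deep:
            for name, value in list(out.items()):
                if isinstance(value, Params):
                    for sub, sub_value in value.get_params(deep=True).items():
                        out[f"{name}__{sub}"] = sub_value
        return out

    def set_params(self, **params):
        for key, value in params.items():
            name, _, sub = key.partition("__")
            if sub:
                # Route e.g. 'predictor__n_units' to self.predictor.set_params(n_units=...).
                getattr(self, name).set_params(**{sub: value})
            else:
                setattr(self, name, value)
        return self


class MLPConfig(Params):
    _param_names = ("n_units",)

    def __init__(self, n_units=100):
        self.n_units = n_units


class SGDConfig(Params):
    _param_names = ("lr",)

    def __init__(self, lr=0.01):
        self.lr = lr


class WrapperConfig(Params):
    _param_names = ("predictor", "optimizer", "batchsize")

    def __init__(self, predictor, optimizer, batchsize=32):
        self.predictor = predictor
        self.optimizer = optimizer
        self.batchsize = batchsize


model = WrapperConfig(MLPConfig(), SGDConfig())
model.set_params(predictor__n_units=50, optimizer__lr=0.1, batchsize=64)
print(model.get_params()["predictor__n_units"], model.optimizer.lr)
```

With this convention, a search grid can mix top-level keys (`batchsize`) and nested keys (`predictor__n_units`, `optimizer__lr`) in one dictionary.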

See for whole training code and more details.

