Chainer sklearn wrapper

If you worked with machine learning before deep learning became popular, you have probably used sklearn (scikit-learn), a very popular machine learning library for Python.

Its interface has been in use for a long time, and I thought supporting it in Chainer would let users try deep learning more easily. So I wrote the Chainer sklearn wrapper.

Here, I will explain how to use it.

Construct the model

Conventional machine learning tasks can mainly be categorized into the following three:

  • Classification – Predict the class (label) of the input; sometimes the output is the probability of each class (label).
  • Regression – Predict the value of a target feature based on the values of the input features.
  • Clustering – Given only inputs without labels, group together inputs that are similar in the input space.

I want to support both classifier models and regression models for deep learning.

Classifier model

Use the SklearnWrapperClassifier class. It is constructed in almost the same way as Chainer's existing Classifier class: just define your own predictor model and pass it to the classifier.

import chainer
import chainer.functions as F
import chainer.links as L
# SklearnWrapperClassifier is provided by the Chainer sklearn wrapper package

# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

# Construct classifier model
n_unit = 50
model = SklearnWrapperClassifier(MLP(n_unit, 10))
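
Chainer's Classifier link wraps a predictor and, by default, computes softmax cross-entropy loss and accuracy on the predictor's outputs. Assuming SklearnWrapperClassifier does the same (the post only says it is constructed in almost the same way), the NumPy sketch below shows what the classifier adds on top of a predictor such as MLP. ClassifierSketch is a hypothetical name, not part of Chainer or the wrapper, and a fixed linear map stands in for MLP so the example runs without Chainer.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy over a batch, computed stably."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class for each sample.
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    return float((logits.argmax(axis=1) == labels).mean())


class ClassifierSketch:
    """Hypothetical stand-in for a classifier wrapped around a predictor."""

    def __init__(self, predictor):
        self.predictor = predictor  # any callable mapping x -> logits

    def __call__(self, x, t):
        y = self.predictor(x)
        self.loss = softmax_cross_entropy(y, t)
        self.accuracy = accuracy(y, t)
        return self.loss


# Usage: a fixed linear "predictor" for 4 input features and 3 classes.
W = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [0.5, 0.5, 0.5]], dtype=np.float32)
model = ClassifierSketch(lambda x: x @ W)
x = np.array([[5.0, 0.0, 0.0, 0.0],
              [0.0, 5.0, 0.0, 0.0]], dtype=np.float32)
t = np.array([0, 1], dtype=np.int32)
loss = model(x, t)
```

The predictor only has to return one score per class; the loss and accuracy live in the wrapper, which is why the same MLP can be reused with different wrappers.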

Regression model

[WIP] Not implemented yet.

Training the model with fit

Once the model is constructed, you can call the fit method in the same way as in sklearn.

Example 1. Iris data classification

Prepare the input data X and the target data (labels) y, then call fit.

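    import numpy as np
    from sklearn.datasets import load_iris

    # Load the iris dataset
    data, target = load_iris(return_X_y=True)
    X = data.astype(np.float32)  # Chainer expects float32 inputs
    y = target.astype(np.int32)  # and int32 class labels

    # Construct model (iris has 3 classes)
    model = SklearnWrapperClassifier(MLP(args.unit, 3), device=args.gpu)

    # Train the model with fit
    model.fit(X, y, epoch=args.epoch)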

See the example script for the whole training code.
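
"The same way as sklearn" refers to sklearn's estimator convention: fit(X, y) learns from the data and returns the estimator itself, and predict(X) returns one label per row. The toy nearest-centroid estimator below is a hypothetical illustration of that convention only; it is not how the wrapper trains its network.

```python
import numpy as np


class NearestCentroidSketch:
    """Hypothetical toy estimator following sklearn's fit/predict convention."""

    def fit(self, X, y):
        # Learned attributes end with an underscore, as in sklearn.
        self.classes_ = np.unique(y)
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self  # returning self allows model.fit(X, y).predict(X)

    def predict(self, X):
        # Squared distance from every sample to every centroid: (n_samples, n_classes).
        d = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[d.argmin(axis=1)]


X = np.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]], dtype=np.float32)
y = np.array([0, 0, 1, 1], dtype=np.int32)
est = NearestCentroidSketch()
pred = est.fit(X, y).predict(X)
```

Because both objects share this interface, sklearn utilities that only call fit and predict can treat a wrapped Chainer model like any other estimator.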

Example 2. MNIST data classification

You can also call fit with Chainer's dataset classes.

The example below fits the model using Chainer's TupleDataset.

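    # Load the MNIST dataset as TupleDatasets of (image, label) pairs
    train, test = chainer.datasets.get_mnist()

    # Construct model
    model = SklearnWrapperClassifier(MLP(args.unit, 10), device=args.gpu)

    # Train the model with fit, passing the dataset instead of X and y
    model.fit(train, epoch=args.epoch)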

See the example script for the whole training code.
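
A TupleDataset pairs the i-th input with the i-th label, so dataset[i] is the tuple (x_i, t_i) and a slice returns a list of such tuples. Fitting on a dataset and fitting on (X, y) are therefore two views of the same data. The sketch below uses a hypothetical stand-in class so it runs without Chainer; it mimics chainer.datasets.TupleDataset only for integer and slice indexing.

```python
import numpy as np


class TupleDatasetSketch:
    """Hypothetical minimal stand-in for chainer.datasets.TupleDataset."""

    def __init__(self, *arrays):
        if len({len(a) for a in arrays}) != 1:
            raise ValueError("all arrays must have the same length")
        self._arrays = arrays

    def __len__(self):
        return len(self._arrays[0])

    def __getitem__(self, index):
        if isinstance(index, slice):
            # A slice returns a list of (x, t) tuples.
            return [tuple(a[i] for a in self._arrays)
                    for i in range(*index.indices(len(self)))]
        return tuple(a[index] for a in self._arrays)


def to_arrays(dataset):
    """Stack every (x, t) example back into X and y arrays."""
    xs, ts = zip(*(dataset[i] for i in range(len(dataset))))
    return np.stack(xs), np.array(ts)


X = np.arange(12, dtype=np.float32).reshape(4, 3)
y = np.array([0, 1, 0, 1], dtype=np.int32)
ds = TupleDatasetSketch(X, y)
X2, y2 = to_arrays(ds)
```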

Predict with the trained model

Use the predict method to get the classification result, and the predict_proba method to get the probability of each class.

    # --- Example 1. Predict all test data ---
    # retain_inputs=True keeps the inputs actually fed to the model
    outputs = model.predict(test, retain_inputs=True)
    x, t = model.inputs

Set the retain_inputs option to True to retrieve the model inputs afterwards through model.inputs. This is convenient for Chainer datasets: preprocessing such as data augmentation may run every time an example is accessed by index (the get_example method of DatasetMixin), so accessing the same index twice is not guaranteed to return the same input.
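
The problem retain_inputs addresses can be reproduced without Chainer. In the sketch below, a hypothetical dataset adds fresh random noise in get_example, so reading index 0 twice gives two different inputs; a hypothetical helper reads each example once and keeps exactly the inputs it predicted on. This only illustrates the idea and is not the wrapper's implementation.

```python
import random


class AugmentedDatasetSketch:
    """Hypothetical dataset that adds fresh random noise on every access,
    mimicking augmentation inside DatasetMixin.get_example."""

    def __init__(self, xs, ts, seed=0):
        self.xs = xs
        self.ts = ts
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.xs)

    def get_example(self, i):
        noisy = [v + self.rng.gauss(0.0, 0.1) for v in self.xs[i]]
        return noisy, self.ts[i]

    def __getitem__(self, i):
        return self.get_example(i)


def predict_retaining_inputs(predict_fn, dataset):
    """Hypothetical helper: predict on each example and keep the exact inputs used."""
    xs, ts, ys = [], [], []
    for i in range(len(dataset)):
        x, t = dataset[i]  # each example is read exactly once
        xs.append(x)
        ts.append(t)
        ys.append(predict_fn(x))
    return ys, (xs, ts)


ds = AugmentedDatasetSketch([[1.0, 2.0], [3.0, 4.0]], [0, 1])
first = ds[0][0]
second = ds[0][0]  # same index, but a different noisy input
ys, (xs, ts) = predict_retaining_inputs(sum, ds)
```

Re-reading ds after prediction would not recover the inputs behind ys; only the retained xs match the outputs one-to-one.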

You can also predict on only a slice of the data:

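    # --- Example 2. Predict only the first 20 test examples ---
    outputs = model.predict_proba(test[:20], retain_inputs=True)
    x, t = model.inputs
    y = outputs  # class probabilities, one row per example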

See the example script for the whole training code.
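
For a classifier trained with softmax cross-entropy, the two methods are typically related as follows: predict_proba applies a softmax to the predictor's raw outputs, and predict takes the index of the largest probability. The NumPy sketch below assumes the wrapper follows this common pattern; the function names are hypothetical.

```python
import numpy as np


def predict_proba_sketch(logits):
    """Row-wise softmax: turns raw predictor outputs into class probabilities."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)


def predict_sketch(logits):
    """Class label = index of the highest probability."""
    return predict_proba_sketch(logits).argmax(axis=1)


logits = np.array([[2.0, 1.0, 0.1], [0.0, 3.0, 0.5]], dtype=np.float32)
proba = predict_proba_sketch(logits)
labels = predict_sketch(logits)
```

Since softmax is monotonic, the argmax of the probabilities equals the argmax of the raw outputs, so predict does not strictly need the softmax step.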

GridSearchCV, RandomizedSearchCV

It also supports sklearn's GridSearchCV and RandomizedSearchCV for automated hyperparameter search.

One example is as follows:

    from chainer import serializers
    from sklearn.model_selection import GridSearchCV

    predictor = MLP(args.unit, 10)
    model = SklearnWrapperClassifier(predictor, device=args.gpu)
    optimizer1 = chainer.optimizers.SGD()
    optimizer2 = chainer.optimizers.Adam()
    gs = GridSearchCV(model,
                      {
                          # hyperparameter search for predictor
                          #'n_units': [10, 50],
                          # hyperparameter search for different optimizers
                          'optimizer': [optimizer1, optimizer2],
                          # 'batchsize', 'epoch' can be used as hyperparameter
                          'epoch': [args.epoch],
                          'batchsize': [100, 1000],
                          # every grid value must be a list, even a fixed one
                          'progress_report': [False],
                      }, verbose=2)
    # X, y: training inputs and labels (e.g. prepared as in Example 1)
    gs.fit(X, y)

    best_model = gs.best_estimator_
    best_mlp = best_model.predictor

    # Save trained model
    serializers.save_npz('{}/best_mlp.model'.format(args.out), best_mlp)

Hyperparameters used in the predictor's constructor or the optimizer's constructor can be searched as well.
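
The post does not show how a grid key such as 'n_units' reaches the predictor's constructor. One plausible approach, sketched here purely as an assumption and not as the wrapper's actual code, is to inspect the constructor's signature, split each candidate's parameters into constructor arguments and everything else, and rebuild the predictor with the former. PredictorSketch and split_params are hypothetical names.

```python
import inspect


class PredictorSketch:
    """Hypothetical predictor whose constructor takes hyperparameters."""

    def __init__(self, n_units=50, n_out=10):
        self.n_units = n_units
        self.n_out = n_out


def split_params(predictor_cls, params):
    """Split a flat parameter dict into constructor arguments and the rest."""
    accepted = set(inspect.signature(predictor_cls.__init__).parameters) - {'self'}
    ctor_args = {k: v for k, v in params.items() if k in accepted}
    others = {k: v for k, v in params.items() if k not in accepted}
    return ctor_args, others


# One candidate from a grid that mixes predictor and training hyperparameters.
ctor_args, others = split_params(PredictorSketch, {'n_units': 10, 'batchsize': 100})
predictor = PredictorSketch(**ctor_args)  # rebuilt with the searched n_units
```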

See the example script for the whole training code and more details.
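
GridSearchCV trains one candidate for every combination of the listed values, once per cross-validation fold. The pure-Python sketch below mirrors what sklearn's ParameterGrid does (including iterating keys in sorted order) for the grid in the example above. Strings stand in for the optimizer objects, and 20 is a hypothetical value for args.epoch.

```python
from itertools import product


def expand_grid(param_grid):
    """Enumerate every combination of a dict of lists, like sklearn's ParameterGrid."""
    keys = sorted(param_grid)
    for values in product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))


# Placeholders standing in for the optimizer objects and args.epoch.
grid = {
    'optimizer': ['SGD', 'Adam'],
    'epoch': [20],
    'batchsize': [100, 1000],
    'progress_report': [False],
}
candidates = list(expand_grid(grid))
```

The grid yields 2 × 1 × 2 × 1 = 4 candidates. Because every entry is iterated, each value in the grid, even a fixed one like progress_report, has to be given as a list.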
