Define your own trainer extensions in Chainer

So how do you implement a custom trainer extension in Chainer? There are three main approaches.

  1. Define a function
  2. Use the make_extension decorator
  3. Define a class

In most cases, approach 1 (defining a function) is the quickest way to implement your extension.

1. Define function

A plain function can serve as a trainer extension. Simply define a function that takes one argument (“t” in the example below), which is the trainer instance.

1-1. Define function

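    # 1-1. Define function for trainer extension
    def my_extension(t):
        print('my_extension function is called at epoch {}!'
              .format(t.updater.epoch_detail))
        # Change optimizer's learning rate *= 0.99
        # (assumes the optimizer is registered as 'main' and has an `lr`
        # attribute, as SGD does; Adam uses `alpha` instead)
        optimizer = t.updater.get_optimizer('main')
        optimizer.lr *= 0.99
        print('Updated to {}'.format(optimizer.lr))

    trainer.extend(my_extension, trigger=(1, 'epoch'))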

Here the argument of my_extension, t, is the trainer instance. You can obtain a lot of information about the training procedure from the trainer. In this case, I read the current epoch through a property of the updater (the trainer holds the updater instance): t.updater.epoch_detail.
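To give a feel for what else can be read through t, here is a minimal sketch. It uses the updater's epoch, epoch_detail and iteration attributes and the trainer's observation dictionary. The names describe_progress and progress_extension are made up for this illustration, and the SimpleNamespace object is only a stand-in for a real trainer so that the snippet runs without any training.

```python
from types import SimpleNamespace


def describe_progress(t):
    """Build a one-line progress summary from a trainer-like object."""
    updater = t.updater
    summary = 'epoch {} (detail {:.2f}), iteration {}'.format(
        updater.epoch, updater.epoch_detail, updater.iteration)
    # trainer.observation holds the values reported in the latest iteration;
    # only the keys are listed here, the values are left untouched.
    summary += ', observed: {}'.format(sorted(t.observation))
    return summary


def progress_extension(t):
    """Extension function: print the summary each time it is triggered."""
    print(describe_progress(t))


# Stand-in for a real trainer (hypothetical values), not a Chainer object.
fake_trainer = SimpleNamespace(
    updater=SimpleNamespace(epoch=2, epoch_detail=2.5, iteration=500),
    observation={'main/loss': 0.1234},
)
progress_extension(fake_trainer)
```

With a real trainer, the function would be registered in the same way as before: trainer.extend(progress_extension, trigger=(1, 'epoch')).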

The extension is invoked according to the trigger configuration. In the code above, trigger=(1, 'epoch') means that the extension is invoked once every epoch.

Try changing the code from trainer.extend(my_extension, trigger=(1, 'epoch')) to trainer.extend(my_extension, trigger=(1, 'iteration')). The extension is then invoked at every iteration. (Caution: this prints log messages very frequently, so stop the run once you have checked the behavior.)
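To see how much the two trigger settings differ, the following sketch counts how many times an extension would fire. It is a simplified stand-in for an interval trigger, not Chainer's actual implementation; the function count_calls and the numbers (3 epochs of 100 iterations) are assumptions made for this illustration.

```python
def count_calls(trigger, n_epochs, iters_per_epoch):
    """Count how often an extension with a (period, unit) trigger fires.

    Simplified stand-in for an interval trigger, not Chainer's implementation.
    """
    period, unit = trigger
    calls = 0
    iteration = 0
    for epoch in range(1, n_epochs + 1):
        for i in range(iters_per_epoch):
            iteration += 1
            epoch_finished = (i == iters_per_epoch - 1)
            if unit == 'iteration' and iteration % period == 0:
                calls += 1
            elif unit == 'epoch' and epoch_finished and epoch % period == 0:
                calls += 1
    return calls


# 3 epochs of 100 iterations each (hypothetical numbers)
print(count_calls((1, 'epoch'), 3, 100))      # 3
print(count_calls((1, 'iteration'), 3, 100))  # 300
```

Even in this small setting, the per-iteration trigger fires a hundred times more often, which is why the log fills up so quickly.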

1-2. Use lambda

Instead of defining a function explicitly, you can simply use a lambda function if the extension’s logic is simple.

    # Use lambda function for extension
    trainer.extend(lambda t: print('lambda function called at epoch {}!'
                                   .format(t.updater.epoch_detail)),
                   trigger=(1, 'epoch'))

2. Use make_extension decorator on function


3. Define as a class
