Define your own trainer extensions in Chainer



So how do you implement custom trainer extensions in Chainer? There are three main approaches.

  1. Define a function
  2. Use a decorator
  3. Define a class

In most cases, defining a function (approach 1) is the easiest way to quickly implement your extension.


1. Define function

A plain function can serve as a trainer extension. Simply define a function that takes one argument (“t” in the example below), which is the trainer instance.


1-1. Define function


Here the argument of the my_extension function, t, is the trainer instance. You can obtain a lot of information about the training procedure from the trainer. In this case, I took the current epoch by accessing a property of the updater (the trainer holds the updater instance): t.updater.epoch_detail.


The extension is invoked according to its trigger configuration. In the code above, trigger=(1, 'epoch') means that the extension is invoked once every epoch.

Try changing the code from trainer.extend(my_extension, trigger=(1, 'epoch')) to trainer.extend(my_extension, trigger=(1, 'iteration')). The extension is then invoked every iteration. (Caution: it prints a log line very frequently, so stop the run once you have checked the behavior.)


1-2. Use lambda

Instead of defining a function explicitly, you can simply use a lambda function if the extension’s logic is simple.



2. Use the make_extension decorator on a function





3. Define as a class




