Define your own trainer extensions in Chainer

 

 

So how do you implement a custom extension for the trainer in Chainer? There are three main approaches.

  1. Define function
  2. Use decorator, @chainer.training.extension.make_extension
  3. Define class

In most cases, approach 1 (define a function) is the easiest way to quickly implement your extension.

 

1. Define function

A plain function can serve as a trainer extension. Simply define a function that takes one argument (named “t” in this post), which is the trainer instance.

 

1-1. Define function

 

The argument of the my_extension function, t, is the trainer instance. You can obtain a lot of information about the training procedure from the trainer. For example, because the trainer holds the updater instance, the current epoch is available through the updater's property t.updater.epoch_detail.
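As a minimal sketch of such a function, the snippet below prints the current epoch. So that it runs without a full training setup, it uses a stand-in object in place of a real trainer; fake_trainer and describe_epoch are hypothetical names introduced only for this illustration, and the registration call is shown as a comment.

```python
from types import SimpleNamespace


def describe_epoch(t):
    # Hypothetical helper: build the log message from the trainer instance.
    return 'Epoch {}'.format(t.updater.epoch_detail)


def my_extension(t):
    # t is the trainer instance that Chainer passes to the extension.
    print(describe_epoch(t))


# Stand-in for a real trainer (assumption, for illustration only):
# it exposes just the attribute this extension reads.
fake_trainer = SimpleNamespace(updater=SimpleNamespace(epoch_detail=1.0))
my_extension(fake_trainer)

# With a real chainer.training.Trainer instance you would register it as:
# trainer.extend(my_extension, trigger=(1, 'epoch'))
```

Calling the function by hand like this is also a quick way to check an extension's logic before attaching it to a real training run.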

 

The extension is invoked according to its trigger configuration. Registering it with trainer.extend(my_extension, trigger=(1, 'epoch')) means that the extension is invoked once every epoch.

Try changing the code from trainer.extend(my_extension, trigger=(1, 'epoch')) to trainer.extend(my_extension, trigger=(1, 'iteration')). The extension is then invoked every iteration. (Caution: this prints log output very frequently, so stop the run once you have checked the behavior.)
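To get a feel for how much more often an iteration trigger fires, here is a simplified stand-alone simulation. It is not Chainer's trigger implementation; should_fire and count_calls are hypothetical helpers, and the sketch assumes a fixed number of iterations per epoch.

```python
def should_fire(trigger, iteration, iters_per_epoch):
    # trigger is a (period, unit) tuple, e.g. (1, 'epoch') or (1, 'iteration').
    period, unit = trigger
    if unit == 'iteration':
        return iteration % period == 0
    # Simplification: an epoch ends every iters_per_epoch iterations.
    return iteration % (period * iters_per_epoch) == 0


def count_calls(trigger, total_iterations, iters_per_epoch):
    # Count how many times an extension with this trigger would be invoked.
    return sum(
        should_fire(trigger, i, iters_per_epoch)
        for i in range(1, total_iterations + 1)
    )


# 30 iterations in total, 10 iterations per epoch (3 epochs).
epoch_calls = count_calls((1, 'epoch'), 30, 10)
iteration_calls = count_calls((1, 'iteration'), 30, 10)
print(epoch_calls, iteration_calls)
```

In this toy setting the epoch trigger fires 3 times while the iteration trigger fires 30 times, which is why the log grows so quickly.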

 

1-2. Use lambda

Instead of defining a function explicitly, you can use a lambda function if the extension's logic is simple.
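A minimal sketch of the lambda form, again exercised with a stand-in object rather than a real trainer (fake_trainer and log_epoch are hypothetical names used only here):

```python
from types import SimpleNamespace

# The lambda takes the trainer instance, just like a named extension function.
log_epoch = lambda t: print('Epoch {}'.format(t.updater.epoch_detail))

# Stand-in for a real trainer (assumption, for illustration only).
fake_trainer = SimpleNamespace(updater=SimpleNamespace(epoch_detail=2.0))
log_epoch(fake_trainer)

# With a real trainer, the lambda can be passed inline:
# trainer.extend(lambda t: print('Epoch {}'.format(t.updater.epoch_detail)),
#                trigger=(1, 'epoch'))
```

For anything longer than a single expression, a named function is usually easier to read and debug.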

 

 

2. Use make_extension decorator on function

TBD…

 

 

 

3. Define as a class

 

 

 

 
