So how do you implement a custom extension for the trainer in Chainer? There are three main approaches:
1. Define a function
2. Use the decorator `@chainer.training.extension.make_extension`
3. Define a class
In most cases, approach 1 (defining a function) is the easiest way to implement an extension quickly.
1. Define function
Any plain function can serve as a trainer extension. Simply define a function that takes one argument (named `t` in the example below), which is the trainer instance.
1-1. Define function
```python
# 1-1. Define function for trainer extension
# (assumes `optimizer` and `trainer` are already defined in the training script)
def my_extension(t):
    print('my_extension function is called at epoch {}!'
          .format(t.updater.epoch_detail))
    # Change optimizer's learning rate
    optimizer.lr *= 0.99
    print('Updated optimizer.lr to {}'.format(optimizer.lr))

trainer.extend(my_extension, trigger=(1, 'epoch'))
```
Here the argument of the `my_extension` function, `t`, is the trainer instance. You can obtain a lot of information about the training procedure from the trainer. In this case, I took the current epoch by accessing the updater's property `t.updater.epoch_detail` (the trainer holds the updater instance).
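Because an extension is only a callable that receives the trainer, its logic can be checked without a full training run by passing in stand-in objects. The sketch below is a variant of `my_extension` that looks the optimizer up through the updater instead of using a global; it assumes the optimizer is registered under the default name `'main'` and that its learning-rate attribute is `lr` (as on SGD-style optimizers). `decay_lr_extension` and the `SimpleNamespace` stubs are hypothetical illustrations, not Chainer classes.

```python
from types import SimpleNamespace


def decay_lr_extension(t):
    """Hypothetical extension: shrink the 'main' optimizer's lr by 1% per call."""
    optimizer = t.updater.get_optimizer('main')
    optimizer.lr *= 0.99


# Hypothetical stand-ins exposing only what the extension touches;
# get_optimizer mirrors the method of the same name on Chainer's StandardUpdater.
optimizer_stub = SimpleNamespace(lr=0.01)
updater_stub = SimpleNamespace(get_optimizer=lambda name: optimizer_stub)
trainer_stub = SimpleNamespace(updater=updater_stub)

# Simulate three invocations, e.g. at the end of epochs 1, 2 and 3
for _ in range(3):
    decay_lr_extension(trainer_stub)

print(optimizer_stub.lr)  # approximately 0.01 * 0.99 ** 3
```

After n invocations the learning rate is lr_0 * 0.99 ** n, so with an epoch trigger the decay happens once per epoch.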
The extension is invoked according to the `trigger` configuration. In the code above, `trigger=(1, 'epoch')` means that the extension is invoked once every epoch.
Try changing `trainer.extend(my_extension, trigger=(1, 'epoch'))` to `trainer.extend(my_extension, trigger=(1, 'iteration'))`. The extension is then invoked at every iteration. (Caution: it prints logs very frequently and also shrinks the learning rate at every iteration, so stop the run once you have checked the behavior.)
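To see why the trigger changes how often an extension runs, here is a toy training loop that mimics interval triggers. It is a simplified sketch that assumes every epoch has exactly `iters_per_epoch` iterations; `ToyUpdater` and `ToyTrainer` are hypothetical classes, not Chainer's `Trainer` or its trigger implementation, which handle more cases.

```python
class ToyUpdater:
    """Hypothetical stand-in: counts iterations and derives a fractional epoch."""

    def __init__(self, iters_per_epoch):
        self.iters_per_epoch = iters_per_epoch
        self.iteration = 0

    @property
    def epoch_detail(self):
        # e.g. 1.5 means halfway through the second epoch
        return self.iteration / self.iters_per_epoch


class ToyTrainer:
    """Hypothetical stand-in for a trainer with interval-style triggers."""

    def __init__(self, iters_per_epoch):
        self.updater = ToyUpdater(iters_per_epoch)
        self.extensions = []

    def extend(self, extension, trigger=(1, 'epoch')):
        self.extensions.append((extension, trigger))

    def fires(self, trigger):
        period, unit = trigger
        it = self.updater.iteration
        if unit == 'iteration':
            return it % period == 0
        # unit == 'epoch': fire after every `period`-th complete epoch
        return it % (period * self.updater.iters_per_epoch) == 0

    def run(self, max_epoch):
        for _ in range(max_epoch * self.updater.iters_per_epoch):
            self.updater.iteration += 1  # one parameter update happens here
            for extension, trigger in self.extensions:
                if self.fires(trigger):
                    extension(self)


epoch_calls, iteration_calls = [], []
trainer = ToyTrainer(iters_per_epoch=100)
trainer.extend(lambda t: epoch_calls.append(t.updater.epoch_detail),
               trigger=(1, 'epoch'))
trainer.extend(lambda t: iteration_calls.append(t.updater.epoch_detail),
               trigger=(1, 'iteration'))
trainer.run(max_epoch=3)

print(epoch_calls)           # [1.0, 2.0, 3.0]
print(len(iteration_calls))  # 300
```

With 100 iterations per epoch, the epoch-triggered extension runs 3 times over 3 epochs, while the iteration-triggered one runs 300 times, which is why the iteration trigger floods the log.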
1-2. Use lambda
Instead of defining a function explicitly, you can simply use a lambda function if the extension's logic is simple.
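```python
# Use lambda function for extension
# (print() inside a lambda needs Python 3, or
#  `from __future__ import print_function` on Python 2)
trainer.extend(lambda t: print('lambda function called at epoch {}!'
                               .format(t.updater.epoch_detail)),
               trigger=(1, 'epoch'))
```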
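"Simple" here has a concrete meaning: a lambda body must be a single expression. The sketch below, using a hypothetical `SimpleNamespace` stand-in for the trainer, shows that a `def` function and a lambda are interchangeable callables when the logic is one expression; they return strings instead of printing only so the behavior is easy to compare.

```python
from types import SimpleNamespace

# Hypothetical stand-in for a trainer: only updater.epoch_detail is provided.
stub = SimpleNamespace(updater=SimpleNamespace(epoch_detail=2.0))


# A def-based extension and an equivalent lambda; both are plain callables
# that take the trainer as their only argument.
def report(t):
    return 'called at epoch {}!'.format(t.updater.epoch_detail)


report_lambda = lambda t: 'called at epoch {}!'.format(t.updater.epoch_detail)

print(report(stub))         # called at epoch 2.0!
print(report_lambda(stub))  # called at epoch 2.0!

# A statement such as `optimizer.lr *= 0.99` cannot appear in a lambda body,
# so logic like the learning-rate update in 1-1 needs a def function.
```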
2. Use make_extension decorator on function
TBD…