So how do you implement a custom extension for the trainer in Chainer? There are three main approaches.
- Define function
- Use decorator, @chainer.training.extension.make_extension
- Define class
In most cases, the first approach, defining a function, is the easiest way to quickly implement your extension.
1. Define function
A plain function can serve as a trainer extension. Simply define a function that takes one argument (“t” in the example below), which is the trainer instance.
1-1. Define function
# 1-1. Define function for trainer extension
def my_extension(t):
    print('my_extension function is called at epoch {}!'
          .format(t.updater.epoch_detail))
    # Change optimizer's learning rate
    optimizer.lr *= 0.99
    print('Updated optimizer.lr to {}'.format(optimizer.lr))

trainer.extend(my_extension, trigger=(1, 'epoch'))
Here the argument of the my_extension function, t, is the trainer instance. You can obtain a lot of information about the training procedure from the trainer. In this case, I get the current epoch information by accessing the updater’s property (the trainer holds the updater instance), t.updater.epoch_detail.
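The function above relies on an `optimizer` variable defined elsewhere in the script. As a hedged alternative, the optimizer can also be reached through the trainer argument itself, using the updater's `get_optimizer` method. The sketch below assumes a single optimizer registered under Chainer's default name `'main'` and an optimizer that exposes a writable `lr` attribute (for example SGD). The names `decay_lr`, `rate`, `fake_optimizer` and `fake_trainer` are hypothetical, and the stand-in objects only exist so the snippet runs without a real training setup.

```python
from types import SimpleNamespace


def decay_lr(t, rate=0.99):
    """Extension sketch: scale the learning rate of the 'main' optimizer."""
    # Look the optimizer up through the trainer instead of a global variable.
    optimizer = t.updater.get_optimizer('main')
    optimizer.lr *= rate


# Stand-in objects (NOT Chainer classes) so the sketch runs on its own.
fake_optimizer = SimpleNamespace(lr=0.01)
fake_trainer = SimpleNamespace(
    updater=SimpleNamespace(get_optimizer=lambda name: fake_optimizer))

decay_lr(fake_trainer)
print('Updated lr to {}'.format(fake_optimizer.lr))
```

With a real trainer, the same function would be registered in the usual way, trainer.extend(decay_lr, trigger=(1, 'epoch')).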
The extension is invoked according to its trigger configuration. In the code above, trigger=(1, 'epoch') means that the extension is invoked once every epoch.
Try changing the code from trainer.extend(my_extension, trigger=(1, 'epoch')) to trainer.extend(my_extension, trigger=(1, 'iteration')). The extension is then invoked every iteration (Caution: it outputs the log very frequently, so stop the run once you have checked the behavior).
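To see why the 'iteration' trigger is so much noisier than the 'epoch' trigger, here is a toy loop that mimics an interval trigger. It is only an illustration, not Chainer's actual implementation: `run_toy_loop`, `n_epochs` and `iters_per_epoch` are hypothetical names, and the extension receives the iteration count instead of a trainer.

```python
def run_toy_loop(extension, trigger, n_epochs=2, iters_per_epoch=3):
    """Toy loop that calls `extension` every `period` iterations or epochs."""
    period, unit = trigger
    calls = []
    iteration = 0
    for epoch in range(1, n_epochs + 1):
        for _ in range(iters_per_epoch):
            iteration += 1
            # Iteration-based trigger: checked after every single update.
            if unit == 'iteration' and iteration % period == 0:
                calls.append(extension(iteration))
        # Epoch-based trigger: checked only when an epoch is complete.
        if unit == 'epoch' and epoch % period == 0:
            calls.append(extension(iteration))
    return calls


per_epoch = run_toy_loop(lambda it: it, trigger=(1, 'epoch'))
per_iteration = run_toy_loop(lambda it: it, trigger=(1, 'iteration'))
print(len(per_epoch), len(per_iteration))  # prints: 2 6
```

With 2 epochs of 3 iterations each, the epoch trigger fires twice while the iteration trigger fires six times, which is why the log grows so quickly in the second case.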
1-2. Use lambda
Instead of defining a function explicitly, you can simply use a lambda function if the extension’s logic is simple.
# Use lambda function for extension
trainer.extend(lambda t: print('lambda function called at epoch {}!'
                               .format(t.updater.epoch_detail)),
               trigger=(1, 'epoch'))
2. Use make_extension decorator on function
TBD…