Design patterns for defining a model

 

A machine learning workflow consists of a training phase and a predict/inference phase, and the model needs to calculate different things in each:

  • Training phase: calculate the loss between the output and the target
  • Predict/Inference phase: calculate the output

I often see the following 2 patterns used to manage this.

 

Predictor – Classifier framework

See train_mnist_2_predictor_classifier.py (train_mnist_1_minimum.py and train_mnist_4_trainer.py also use the Predictor – Classifier framework).

This framework uses 2 Chain classes, “Predictor” and “Classifier”.

  • Training phase: the Predictor’s output is fed into the Classifier to calculate the loss.
  • Predict/Inference phase: only the Predictor’s output is used.

 

  • Predictor

Predictor simply calculates the output from the input.
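As a minimal sketch of the idea (plain Python rather than Chainer, with a hypothetical linear Predictor whose weights are set by hand), a predictor only maps x to y and knows nothing about targets or losses:

```python
class Predictor:
    """Hypothetical stand-in for a Predictor Chain: maps input x to output y."""

    def __init__(self, weights, bias):
        self.W = weights  # list of rows, shape (n_out, n_in)
        self.b = bias     # list of length n_out

    def __call__(self, x):
        # y = W x + b, computed row by row; no target and no loss involved
        return [sum(w * xj for w, xj in zip(row, x)) + bi
                for row, bi in zip(self.W, self.b)]


predictor = Predictor([[1.0, 0.0], [0.0, 2.0]], [0.5, -1.0])
y = predictor([3.0, 4.0])
print(y)  # [3.5, 7.0]
```

Because the predictor has no notion of a loss, the same object can be used unchanged at inference time.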

 

 

  • Classifier

Classifier “wraps” the Predictor: it takes the predictor’s output y and calculates the loss between y and the actual target t.
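A rough, framework-free sketch of the wrapping follows. The real chainer.links.Classifier takes a predictor plus a loss function (softmax cross entropy by default) and also reports accuracy; the class below is a simplified stand-in, and softmax_cross_entropy here is a hand-written helper, not Chainer's function.

```python
import math


def softmax_cross_entropy(y, t):
    """Cross entropy of softmax(y) against the integer class label t."""
    m = max(y)  # subtract the max before exp() for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in y))
    return log_sum_exp - y[t]


class Classifier:
    """Simplified Classifier: wraps any predictor and adds a loss on top."""

    def __init__(self, predictor, lossfun=softmax_cross_entropy):
        self.predictor = predictor
        self.lossfun = lossfun

    def __call__(self, x, t):
        y = self.predictor(x)           # forward pass through the predictor
        self.loss = self.lossfun(y, t)  # loss between output y and target t
        return self.loss


# Any callable works as the predictor; this one always returns equal scores.
classifier_model = Classifier(lambda x: [0.0, 0.0])
loss = classifier_model([1.0, 2.0], 1)
print(round(loss, 4))  # 0.6931, i.e. log(2)
```

Since the Classifier only depends on the predictor's output, the same class can wrap any model whose output is a vector of class scores.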

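Training then runs the optimizer's update step (in Chainer, optimizer.update(classifier_model, x, t)), which invokes classifier_model(x, t) internally, calculates the loss and updates the internal parameters by back propagation.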
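To illustrate what that update step does, here is a toy version with hypothetical names (LinearPredictor, MSEClassifier, SGD). Chainer computes gradients automatically with loss.backward(); this sketch hard-codes the gradient of a squared error loss instead, and fits data generated from t = 2x + 1.

```python
class LinearPredictor:
    """Hypothetical 1-D predictor: y = w * x + b."""

    def __init__(self):
        self.w, self.b = 0.0, 0.0

    def __call__(self, x):
        return self.w * x + self.b


class MSEClassifier:
    """Wraps the predictor, returns squared error, stores hand-derived grads."""

    def __init__(self, predictor):
        self.predictor = predictor

    def __call__(self, x, t):
        diff = self.predictor(x) - t
        # d(diff^2)/dw = 2*diff*x and d(diff^2)/db = 2*diff ("backprop" by hand)
        self.grads = {"w": 2 * diff * x, "b": 2 * diff}
        return diff ** 2


class SGD:
    """Minimal stand-in for an optimizer with an update(lossfun, *args) step."""

    def __init__(self, model, lr=0.1):
        self.model, self.lr = model, lr

    def update(self, lossfun, *args):
        loss = lossfun(*args)  # here: classifier_model(x, t)
        p = self.model.predictor
        p.w -= self.lr * self.model.grads["w"]
        p.b -= self.lr * self.model.grads["b"]
        return loss


classifier_model = MSEClassifier(LinearPredictor())
optimizer = SGD(classifier_model, lr=0.1)
for _ in range(200):
    for x, t in [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]:  # samples of t = 2x + 1
        optimizer.update(classifier_model, x, t)

w, b = classifier_model.predictor.w, classifier_model.predictor.b
print(round(w, 3), round(b, 3))  # should end up close to 2.0 and 1.0
```

After training, prediction only needs classifier_model.predictor(x); the wrapper is no longer involved.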

Refer to the source code of Classifier for the details.

 

Train flag framework

See train_mnist_3_train_flag.py.

Both the loss calculation for the training phase and the prediction code for the inference phase are implemented within one model, and the behavior is switched by a “train flag” (or “test flag”).

By default, self.train = True, so the model calculates the loss and the optimizer can update its internal parameters.

To predict values, set the train flag to False; the model then returns its output instead of the loss.
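Below is a minimal sketch of the train flag pattern in plain Python (TrainFlagModel is a hypothetical name, not the code from train_mnist_3_train_flag.py). The same __call__ returns the loss while self.train is True and the raw output once it is set to False. In Chainer v2 and later, the train/test arguments of functions such as F.dropout were replaced by the global setting chainer.config.train, usually toggled with chainer.using_config('train', False).

```python
import math


class TrainFlagModel:
    """One model whose behavior is chosen by the self.train flag."""

    def __init__(self, w, b):
        self.w, self.b = w, b
        self.train = True  # default: training mode, __call__ returns the loss

    def predict(self, x):
        # One score per class for a scalar input x
        return [wi * x + bi for wi, bi in zip(self.w, self.b)]

    def __call__(self, x, t=None):
        y = self.predict(x)
        if self.train:
            # Training phase: softmax cross entropy between y and label t
            m = max(y)
            return m + math.log(sum(math.exp(v - m) for v in y)) - y[t]
        # Predict/Inference phase: return the output itself
        return y


model = TrainFlagModel(w=[1.0, -1.0], b=[0.0, 0.0])
loss = model(0.0, 0)  # train flag is True, so this is the loss, log(2)
model.train = False
y = model(2.0)        # train flag is False, so this is the output
print(y)  # [2.0, -2.0]
```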

 

 

Comparison

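The Predictor – Classifier framework has the advantage that the Classifier module is independent of the predictor, so it can be reused across models. However, if the loss calculation is complicated (for example, it needs more than the final output y and the target t), this framework is difficult to apply.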

In the train flag framework, the training loss calculation and the prediction calculation can be written independently. You can implement any loss calculation, even one that is very different from the prediction calculation.
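To make the comparison concrete, here is a hypothetical two-stage model whose training loss also uses an intermediate value h. A Classifier that only receives the output y and the target t could not add the term on h, whereas the train flag pattern handles it inside the model. The 0.5 * h ** 2 auxiliary term is an arbitrary illustrative choice, not a recommended loss.

```python
class TwoStageModel:
    """Hypothetical model whose training loss depends on an intermediate value."""

    def __init__(self, a, c):
        self.a, self.c = a, c
        self.train = True

    def __call__(self, x, t=None):
        h = self.a * x  # intermediate (hidden) value
        y = self.c * h  # final output
        if not self.train:
            return y    # inference: only y is needed
        # Training: squared error on y plus an auxiliary penalty on h,
        # which a wrapper that only sees (y, t) cannot compute.
        return (y - t) ** 2 + 0.5 * h ** 2


model = TwoStageModel(a=2.0, c=3.0)
loss = model(1.0, 5.0)  # h = 2, y = 6, so (6 - 5)^2 + 0.5 * 2^2 = 3.0
model.train = False
y = model(1.0)
print(loss, y)  # 3.0 6.0
```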

Basically, use the Predictor – Classifier framework if the loss function is typical (such as softmax cross entropy for classification), and use the train flag framework otherwise.

 

 
