Training LSTM model with Penn Treebank (PTB) dataset

This post mainly explains train_ptb.py, uploaded on GitHub.

We have already learned the RNN and LSTM network architectures; now let's apply them to the PTB dataset.

It is quite similar to train_simple_sequence.py, explained in Training RNN with simple sequence dataset, so not much explanation is necessary.

Training code

I will just paste the whole training code for PTB at first:

"""
RNN Training code with Penn Treebank (ptb) dataset
Ref: https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb.py
"""
from __future__ import print_function

import os
import sys
import argparse

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, iterators, serializers, optimizers
from chainer.training import extensions

sys.path.append(os.pardir)
from RNN import RNN
from RNN2 import RNN2
from RNN3 import RNN3
from RNNForLM import RNNForLM
from parallel_sequential_iterator import ParallelSequentialIterator
from bptt_updater import BPTTUpdater



# Routine to rewrite the result dictionary of LogReport to add perplexity
# values
def compute_perplexity(result):
    result['perplexity'] = np.exp(result['main/loss'])
    if 'validation/main/loss' in result:
        result['val_perplexity'] = np.exp(result['validation/main/loss'])


def main():
    archs = {
        'rnn': RNN,
        'rnn2': RNN2,
        'rnn3': RNN3,
        'lstm': RNNForLM
    }

    parser = argparse.ArgumentParser(description='RNN example')
    parser.add_argument('--arch', '-a', choices=archs.keys(),
                        default='rnn', help='Net architecture')
    parser.add_argument('--unit', '-u', type=int, default=100,
                        help='Number of RNN units in each layer')
    parser.add_argument('--bproplen', '-l', type=int, default=20,
                        help='Number of words in each mini-batch '
                             '(= length of truncated BPTT)')
    parser.add_argument('--batchsize', '-b', type=int, default=10,
                        help='Number of sequences in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=10,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# Architecture: {}'.format(args.arch))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))

    # 1. Load dataset: Penn Tree Bank long word sequence dataset
    train, val, test = chainer.datasets.get_ptb_words()
    n_vocab = max(train) + 1  # train is just an array of integers
    print('# vocab: {}'.format(n_vocab))
    print('')

    # 2. Setup model
    model = archs[args.arch](n_vocab=n_vocab,
                             n_units=args.unit)  # , activation=F.tanh
    classifier_model = L.Classifier(model)
    classifier_model.compute_accuracy = False  # we only want the perplexity

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        classifier_model.to_gpu()  # Copy the model to the GPU

    eval_classifier_model = classifier_model.copy()  # Model with shared params and distinct states
    eval_model = classifier_model.predictor

    # 3. Setup an optimizer
    optimizer = optimizers.Adam(alpha=0.001)
    #optimizer = optimizers.MomentumSGD()
    optimizer.setup(classifier_model)

    # 4. Setup an Iterator
    train_iter = ParallelSequentialIterator(train, args.batchsize)
    val_iter = ParallelSequentialIterator(val, 1, repeat=False)
    test_iter = ParallelSequentialIterator(test, 1, repeat=False)

    # 5. Setup an Updater
    updater = BPTTUpdater(train_iter, optimizer, args.bproplen, args.gpu)
    # 6. Setup a trainer (and extensions)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the validation dataset at each epoch
    trainer.extend(extensions.Evaluator(val_iter, eval_classifier_model,
                                        device=args.gpu,
                                        # Reset the RNN state at the beginning of each evaluation
                                        eval_hook=lambda _: eval_model.reset_state())
                   )

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))
    interval = 500
    trainer.extend(extensions.LogReport(postprocess=compute_perplexity,
                                        trigger=(interval, 'iteration')))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'iteration', 'perplexity', 'val_perplexity', 'elapsed_time']
    ), trigger=(interval, 'iteration'))
    trainer.extend(extensions.PlotReport(
        ['perplexity', 'val_perplexity'],
        x_key='epoch', file_name='perplexity.png'))

    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Resume from a snapshot
    if args.resume:
        serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
    serializers.save_npz('{}/{}_ptb.model'
                         .format(args.out, args.arch), model)

    # Evaluate the final model
    print('test')
    eval_model.reset_state()
    evaluator = extensions.Evaluator(test_iter, eval_classifier_model, device=args.gpu)
    result = evaluator()
    print('test perplexity:', np.exp(float(result['main/loss'])))


if __name__ == '__main__':
    main()


In the following, I will explain the points that differ from the simple_sequence dataset example.

PTB dataset preparation: train, validation and test

Dataset preparation is done by the get_ptb_words method provided by Chainer:

    # 1. Load dataset: Penn Tree Bank long word sequence dataset
    train, val, test = chainer.datasets.get_ptb_words()
    n_vocab = max(train) + 1  # train is just an array of integers
    print('# vocab: {}'.format(n_vocab))
    print('')


Note that the PTB dataset consists of train, validation and test data, while previous projects like MNIST, CIFAR-10 and CIFAR-100 consisted of only train and test data.

In the above training code, we use the train dataset to train the model, the validation dataset to monitor the validation loss during training (for example, you may tune hyperparameters using the validation loss), and the test dataset only after training is completely finished, just to check/evaluate the model's performance.
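
For reference, here is a minimal sketch of how to inspect these splits, assuming Chainer can download the dataset; chainer.datasets.get_ptb_words_vocabulary is the companion function that returns the word-to-ID dictionary, so the integer IDs can be mapped back to words.

import chainer

# Each split is a 1-D numpy array of word IDs forming one long sequence
train, val, test = chainer.datasets.get_ptb_words()
print(len(train), len(val), len(test))  # roughly 929k / 73k / 82k words

# Companion function: dictionary mapping each word to its integer ID
vocab = chainer.datasets.get_ptb_words_vocabulary()
id_to_word = {i: w for w, i in vocab.items()}
print(' '.join(id_to_word[int(i)] for i in train[:10]))  # first 10 words of the training text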

Monitor the loss by perplexity

In NLP, it is common to measure a model's performance by perplexity, instead of softmax cross entropy loss or accuracy.

Perplexity of a probability distribution

The perplexity of a discrete probability distribution p is defined as

2^{H(p)} = 2^{-\sum_{x} p(x) \log_2 p(x)}
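
As a quick sanity check of this definition, a uniform distribution over k outcomes has entropy \log_2 k, so its perplexity is exactly k, while a more peaked distribution has lower perplexity. A tiny sketch with made-up probabilities:

import numpy as np

p = np.array([0.25, 0.25, 0.25, 0.25])  # uniform over 4 outcomes
H = -np.sum(p * np.log2(p))             # entropy = 2.0 bits
print(2 ** H)                           # perplexity = 4.0

q = np.array([0.7, 0.1, 0.1, 0.1])      # more peaked distribution
H = -np.sum(q * np.log2(q))             # entropy ~ 1.36 bits
print(2 ** H)                           # perplexity ~ 2.56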

Perplexity per word

In natural language processing, perplexity is a way of evaluating language models. A language model is a probability distribution over entire sentences or texts.

Cited from https://en.wikipedia.org/wiki/Perplexity
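
Concretely, following the same Wikipedia article, the per-word perplexity of a language model q over a held-out sequence of N words x_1, ..., x_N can be written as

b^{-\frac{1}{N}\sum_{i=1}^{N} \log_b q(x_i \mid x_1, \dots, x_{i-1})}

With base b = e, the exponent is just the per-word mean of the softmax cross entropy loss.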

It can therefore be calculated easily by just taking the exponential of the mean softmax cross entropy loss,

result['perplexity'] = np.exp(result['main/loss'])

and in Chainer, we can show it with the LogReport extension.

This is done by passing the post-processing function compute_perplexity as the postprocess argument of LogReport.

# Routine to rewrite the result dictionary of LogReport to add perplexity
# values
def compute_perplexity(result):
    result['perplexity'] = np.exp(result['main/loss'])
    if 'validation/main/loss' in result:
        result['val_perplexity'] = np.exp(result['validation/main/loss'])


    ...

    interval = 500
    trainer.extend(extensions.LogReport(postprocess=compute_perplexity,
                                        trigger=(interval, 'iteration')))

LogReport's postprocess argument takes a function, which receives an argument result: a dictionary containing the reported values.

Since 'main/loss' and 'validation/main/loss' are reported by Classifier and Evaluator respectively, we can extract these values from result to calculate perplexity and val_perplexity. Once they are set in the result dictionary, they can be shown by PrintReport under the same key names.
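
To make this concrete, here is a small illustration (with made-up loss values) of what happens when LogReport passes its aggregated result dictionary through compute_perplexity:

import numpy as np

def compute_perplexity(result):
    result['perplexity'] = np.exp(result['main/loss'])
    if 'validation/main/loss' in result:
        result['val_perplexity'] = np.exp(result['validation/main/loss'])

# Hypothetical averaged losses, as LogReport would collect them over the interval
result = {'main/loss': 5.0, 'validation/main/loss': 5.5}
compute_perplexity(result)
print(result['perplexity'])      # ~148.4
print(result['val_perplexity'])  # ~244.7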
