Training an LSTM model with the Penn Treebank (PTB) dataset

 

This post mainly explains train_ptb.py, which is uploaded on GitHub.

 

We have already learned the RNN and LSTM network architectures, so let’s apply them to the PTB dataset.

It is quite similar to train_simple_sequence.py, explained in Training RNN with simple sequence dataset, so not much explanation is necessary.

 

Train code

The whole training code for PTB is available as train_ptb.py on GitHub.
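As a rough, condensed sketch of what such training looks like (this is not the actual train_ptb.py, which uses Chainer’s Trainer together with a custom iterator and updater as discussed below; the model structure, the manual truncated-BPTT loop, and the hyperparameter values here are illustrative assumptions on my part):

import numpy as np
import chainer
import chainer.links as L


class RNNForLM(chainer.Chain):
    """LSTM language model: word embedding -> LSTM -> vocabulary logits."""

    def __init__(self, n_vocab, n_units):
        super(RNNForLM, self).__init__()
        with self.init_scope():
            self.embed = L.EmbedID(n_vocab, n_units)
            self.lstm = L.LSTM(n_units, n_units)
            self.out = L.Linear(n_units, n_vocab)

    def reset_state(self):
        self.lstm.reset_state()

    def __call__(self, x):
        return self.out(self.lstm(self.embed(x)))


# Illustrative hyperparameters (not necessarily those of train_ptb.py)
batch_size, n_units, bprop_len = 20, 650, 35

train, val, test = chainer.datasets.get_ptb_words()
n_vocab = int(train.max()) + 1

model = L.Classifier(RNNForLM(n_vocab, n_units))  # adds softmax cross entropy
optimizer = chainer.optimizers.SGD(lr=1.0)
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.GradientClipping(5.0))

# Lay the corpus out as batch_size parallel streams for truncated BPTT
n_steps = len(train) // batch_size
data = train[:n_steps * batch_size].reshape(batch_size, n_steps)

model.predictor.reset_state()
for i in range(0, n_steps - 1, bprop_len):
    loss, count = 0, 0
    for j in range(i, min(i + bprop_len, n_steps - 1)):
        x, t = data[:, j], data[:, j + 1]  # input word and next word
        loss += model(x, t)
        count += 1
    model.cleargrads()
    loss.backward()
    loss.unchain_backward()  # truncate the computational history (BPTT)
    optimizer.update()
    print('perplexity: {:.2f}'.format(float(np.exp(loss.data / count))))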

 

In the following, I will explain the points that differ from the simple_sequence dataset example.

 

PTB dataset preparation: train, validation and test

Dataset preparation is done by the get_ptb_words method provided by Chainer.
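A minimal usage sketch (the variable names are mine):

import chainer

# Each split is a flat numpy array of word IDs (np.int32); the corpus is
# concatenated into one long sequence rather than split into sentences.
train, val, test = chainer.datasets.get_ptb_words()
print(train.shape, val.shape, test.shape)  # roughly 930k / 74k / 82k words
print(train[:10])  # the first ten word IDs

If you need to map the IDs back to words, chainer.datasets.get_ptb_words_vocabulary() returns the word-to-ID dictionary.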

 

Note that the PTB dataset consists of train, validation, and test data, while previous projects like MNIST, CIFAR-10, and CIFAR-100 consisted only of train and test data.

In the above training code, we use the train dataset to train the model, the validation dataset to monitor the validation loss during training (for example, you may tune hyperparameters based on the validation loss), and the test dataset only after training is completely finished, just to check/evaluate the model’s performance.
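In Trainer-based code, this division of roles looks roughly like the following sketch; trainer, model, val_iter, and test_iter are assumed to be set up already (the evaluation iterators with repeat=False):

import numpy as np
from chainer.training import extensions

# During training: evaluate on the validation split (once per epoch by default)
trainer.extend(extensions.Evaluator(val_iter, model))

trainer.run()

# Only after training has completely finished: evaluate once on the test
# split. An Evaluator can be called directly; it returns a dict of means.
evaluator = extensions.Evaluator(test_iter, model)
result = evaluator()
print('test perplexity:', np.exp(float(result['main/loss'])))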

 

Monitor the loss by perplexity

In NLP, it is common to measure a model’s performance by perplexity, instead of by the softmax cross entropy loss or classification accuracy.

Perplexity of a probability distribution

The perplexity of a discrete probability distribution p is defined as 2^H(p) = 2^(−Σx p(x) log2 p(x)), where H(p) is the entropy of the distribution in bits. For example, a uniform distribution over k outcomes has perplexity exactly k.

Perplexity per word

In natural language processing, perplexity is a way of evaluating language models. A language model is a probability distribution over entire sentences or texts.

(cited from https://en.wikipedia.org/wiki/Perplexity)

For a test corpus, the per-word perplexity is the exponential of the mean negative log-probability assigned to each word, so it is calculated easily by just taking the exponential of the mean softmax cross entropy loss:

result['perplexity'] = np.exp(result['main/loss'])

 

In Chainer, we can show it with the LogReport extension.

This is done by passing the post-processing function compute_perplexity as the postprocess argument of LogReport.

 

LogReport’s postprocess argument takes a function, and that function receives an argument “result”, which is a dictionary containing the reported values.

Since ‘main/loss’ and ‘validation/main/loss’ are reported by Classifier and Evaluator, we can extract these values from result to calculate perplexity and val_perplexity. Once they are set in the result dictionary, they can be shown by PrintReport under the same key names.
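Putting this together, a sketch of the wiring (assuming the trainer object from above):

import numpy as np
from chainer.training import extensions


def compute_perplexity(result):
    # result holds the mean reported values for this log interval;
    # exponentiating the mean cross entropy loss yields the perplexity.
    result['perplexity'] = np.exp(result['main/loss'])
    if 'validation/main/loss' in result:
        result['val_perplexity'] = np.exp(result['validation/main/loss'])


trainer.extend(extensions.LogReport(postprocess=compute_perplexity))
trainer.extend(extensions.PrintReport(
    ['epoch', 'iteration', 'perplexity', 'val_perplexity']))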
