julian121266 / RecurrentHighwayNetworks
- Sunday, 30 October 2016, 03:14:51
Python
Recurrent Highway Networks - Author implementation for Tensorflow and Torch
What?
This repository contains code accompanying the paper Recurrent Highway Networks (RHNs). RHNs are an extension of Long Short-Term Memory (LSTM) networks with forget gates that enables the learning of deep recurrent state transitions. We provide implementations for the Tensorflow, Torch7 and Brainstorm libraries, and welcome additional implementations from the community.
Why?
The recurrent state transition in typical recurrent networks is modeled with a single-step non-linear function. In principle, this can be very inefficient for modeling complicated transitions and may require very large networks. Increased recurrence depth allows RHNs to model complex transitions more efficiently, achieving substantially improved results.
Moreover, using depth d in the recurrent state transition is much more powerful than stacking d recurrent layers. The figures below illustrate that, for functions mapping one hidden state to another T time steps apart, the maximum depth scales as the product of d and T instead of the sum. In general, RHNs can of course also be stacked to get the best of both worlds.
[Figures: Stacked RNN (left) and Deep Transition RNN (right)]
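As a rough illustration of the scaling argument above, the following sketch compares how the maximum depth between two hidden states grows in each design. The function names are hypothetical, and the formulas capture only the stated scaling (sum versus product), ignoring constant offsets.

```python
def stacked_rnn_depth(d: int, T: int) -> int:
    """Approximate maximum depth between hidden states T steps apart
    when d single-step recurrent layers are stacked (grows like d + T)."""
    return d + T


def deep_transition_depth(d: int, T: int) -> int:
    """Approximate maximum depth between hidden states T steps apart
    when every recurrent transition is itself d layers deep (grows like d * T)."""
    return d * T


if __name__ == "__main__":
    d = 10  # recurrence depth (or number of stacked layers)
    for T in (1, 10, 35):
        print(T, stacked_rnn_depth(d, T), deep_transition_depth(d, T))
```

The gap widens quickly with sequence length, which is the intuition behind preferring a deep transition over stacking alone.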
The perplexity (lower is better) of word-level language models on the Penn Treebank dataset improves dramatically as recurrence depth increases while the model size is kept fixed. WT refers to tying the input and output weights for regularization. See the paper for details.
Rec. depth | #Units/Layer | Best Validation | Test | Best Validation (WT) | Test (WT) |
---|---|---|---|---|---|
1 | 1275 | 92.4 | 89.2 | 93.2 | 90.6 |
2 | 1180 | 79.0 | 76.3 | 76.9 | 75.1 |
3 | 1110 | 75.0 | 72.6 | 72.7 | 70.6 |
4 | 1050 | 73.3 | 70.9 | 70.8 | 68.6 |
5 | 1000 | 72.0 | 69.8 | 69.7 | 67.7 |
6 | 960 | 71.9 | 69.3 | 69.1 | 66.6 |
7 | 920 | 71.7 | 68.7 | 68.7 | 66.4 |
8 | 890 | 71.2 | 68.5 | 68.2 | 66.1 |
9 | 860 | 71.3 | 68.5 | 68.1 | 66.0 |
10 | 830 | 71.3 | 68.3 | 68.3 | 66.0 |
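For readers unfamiliar with the metric in the table above, perplexity is the exponential of the average negative log-probability that the model assigns to each observed word. A minimal sketch follows; the function name and the probabilities are made up for illustration and are not taken from this repository.

```python
import math


def perplexity(token_probs):
    """Perplexity = exp(mean negative log-probability of the observed tokens)."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


if __name__ == "__main__":
    # Hypothetical probabilities a model assigned to four observed words.
    print(perplexity([0.1, 0.02, 0.05, 0.01]))
```

A model that always spreads its probability uniformly over k words has perplexity k, so lower values mean the model is less "surprised" by the test text.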
Comparison with other word-level language models on Penn Treebank (perplexity):
Network | Size | Best Validation | Test |
---|---|---|---|
LSTM+dropout | 66 M | 82.2 | 78.4 |
Variational LSTM | 66 M | 77.3 | 75.0 |
Variational LSTM with MC dropout | 66 M | - | 73.4 |
Variational LSTM + WT | 51 M | 75.8 | 73.2 |
Pointer Sentinel LSTM | 21 M | 72.4 | 70.9 |
Ensemble of 38 large LSTMs+dropout | 66 M per LSTM | 71.9 | 68.7 |
Ensemble of 10 large Variational LSTMs | 66 M per LSTM | - | 68.7 |
Variational RHN (depth=8) | 32 M | 71.2 | 68.5 |
Variational RHN + WT (depth=9) | 24 M | 68.1 | 66.0 |
Variational RHN + WT with MC dropout (depth=5)* | 22 M | - | 64.4 |
*We used 1000 samples for MC dropout as done by Gal for LSTMs, but we've only evaluated the depth 5 model so far.
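MC dropout evaluation averages the model's predictive distribution over many randomly sampled dropout masks instead of switching dropout off at test time. The sketch below shows the idea on a toy linear-softmax model, which is an assumption standing in for the language model; the function names, shapes and keep probability are hypothetical and this is not the repository's evaluation code.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over a 1-D array."""
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def mc_dropout_predict(x, W, keep_prob=0.75, n_samples=1000, seed=0):
    """Average the predictive distribution over n_samples dropout masks."""
    rng = np.random.default_rng(seed)
    probs = np.zeros(W.shape[0])
    for _ in range(n_samples):
        # Sample a fresh mask and rescale so the expected input is unchanged.
        mask = rng.random(x.shape) < keep_prob
        probs += softmax(W @ (x * mask / keep_prob))
    return probs / n_samples
```

Each sample requires a full forward pass, so evaluation with 1000 samples costs roughly 1000 times as much as a single deterministic pass.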
Character-level language modeling results on enwik8 (Wikipedia), in test bits per character (BPC):
Network | Network size | Test BPC |
---|---|---|
GF-RNN | 20 M | 1.58 |
Grid-LSTM | 16.8 M | 1.47 |
MI-LSTM | 17 M | 1.44 |
HM-LSTM | 48 M | 1.40 |
HyperLSTM | 17.9 M | 1.38 |
Variational RHN | 27.6 M | 1.32 |
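Bits per character (BPC) is the average negative base-2 log-probability the model assigns to each observed character. A minimal sketch, with a hypothetical function name and made-up probabilities:

```python
import math


def bits_per_character(char_probs):
    """Mean negative log2-probability of the observed characters."""
    return -sum(math.log2(p) for p in char_probs) / len(char_probs)


if __name__ == "__main__":
    # Hypothetical probabilities a model assigned to three observed characters.
    print(bits_per_character([0.5, 0.25, 0.8]))
```

Since 2 raised to the BPC gives the per-character perplexity, a score of 1.32 BPC corresponds to a per-character perplexity of about 2.5.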
Tensorflow code for RHNs is built by heavily extending the LSTM language modeling example provided in Tensorflow. It supports Variational RHNs as used in the paper, which use the same dropout mask at each time step and at all layers inside the recurrence. Note that this implementation uses the same dropout mask for both the H and T non-linear transforms in RHNs while the Torch7 implementation uses different dropout masks for different transformations.
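To make the dropout scheme concrete, here is a minimal NumPy sketch of one RHN time step with coupled carry gates (carry = 1 - T), where a single dropout mask on the recurrent state is shared by the H and T transforms and by all layers. It follows the layer equations from the paper rather than the repository code; the function name, parameter layout and the exact place where the mask is applied are assumptions for illustration.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def rhn_step(x, s, params, state_mask):
    """One time step of a depth-d RHN with coupled gates (illustrative sketch).

    x: input vector, shape (n_in,)
    s: previous recurrent state, shape (n,)
    params: list of d dicts with keys R_H, R_T, b_H, b_T, plus W_H and W_T
            in the first dict only (hypothetical layout)
    state_mask: dropout mask, shape (n,), sampled once per sequence and
                reused for every layer and for both the H and T transforms
    """
    for layer, p in enumerate(params):
        s_in = s * state_mask  # same mask feeds H and T
        pre_h = p["R_H"] @ s_in + p["b_H"]
        pre_t = p["R_T"] @ s_in + p["b_T"]
        if layer == 0:  # the input only enters the first recurrence layer
            pre_h = pre_h + p["W_H"] @ x
            pre_t = pre_t + p["W_T"] @ x
        h = np.tanh(pre_h)   # H: candidate state
        t = sigmoid(pre_t)   # T: transform gate
        s = h * t + s * (1.0 - t)  # coupled carry gate: C = 1 - T
    return s
```

Using separate masks for H and T, as the Torch7 implementation does, would amount to passing two masks and applying one to each pre-activation.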
We recommend installing Tensorflow in a virtual environment. In addition to the usual Tensorflow dependencies, the code uses Sacred, which you can install with:
$ pip install sacred
To reproduce SOTA results on Penn Treebank:
$ python rhn_train.py with ptb_sota
To reproduce SOTA results on enwik8 (Wikipedia), first download the dataset from http://mattmahoney.net/dc/enwik8.zip and unzip it into the data directory, then run:
$ python rhn_train.py with enwik8_sota
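If you prefer to script the download step, a small standard-library helper along these lines could fetch and unpack the archive. The helper names are hypothetical, and the target paths assume the data directory sits next to the training script.

```python
import urllib.request
import zipfile
from pathlib import Path

ENWIK8_URL = "http://mattmahoney.net/dc/enwik8.zip"


def download(url, dest):
    """Download url to dest unless the file already exists; return the path."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    if not dest.exists():
        urllib.request.urlretrieve(url, str(dest))
    return dest


def unzip(zip_path, target_dir):
    """Extract all members of zip_path into target_dir; return their names."""
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(target_dir)
        return zf.namelist()


# Example (performs a network download):
# unzip(download(ENWIK8_URL, "data/enwik8.zip"), "data")
```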
To change hyperparameters, append the overrides to the command:
$ python rhn_train.py with ptb_sota depth=20
This is a Sacred experiment, so you can check the hyperparameter options using the print_config command, e.g.
$ python rhn_train.py print_config with ptb_sota
Torch7 code is based on Yarin Gal's adaptation of Wojciech Zaremba's code implementing variational dropout.
The main additions to Gal's code are the Recurrent Highway Network layer, the initial biasing of T-gate activations to facilitate learning, and a few adjustments to other network parameters such as rnn_size and dropout probabilities.
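The effect of biasing the T gate can be seen with a quick calculation: a negative bias pushes the initial gate activation toward zero, so each layer initially mostly carries its state through unchanged. The sketch below assumes coupled carry gates (carry = 1 - T) and zero pre-activations apart from the bias; the bias value of -2 is an illustrative assumption, not a value taken from this repository.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def initial_carry_fraction(t_bias, depth):
    """Fraction of the incoming state carried unchanged through `depth`
    coupled-gate layers when the T pre-activation equals the bias."""
    t = sigmoid(t_bias)
    return (1.0 - t) ** depth


if __name__ == "__main__":
    for bias in (0.0, -2.0):
        print(bias, initial_carry_fraction(bias, depth=10))
```

With a zero bias the carried fraction halves at every layer, whereas a negative bias keeps far more of the state intact across a deep transition at the start of training.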
We recommend installing Torch from the official website. The code also requires the following packages:
$ luarocks install nngraph
$ luarocks install cutorch
$ luarocks install nn
$ luarocks install hdf5
To train on Penn Treebank, run:
$ th torch_rhn_ptb.lua
To run on the enwik8 dataset, first download and prepare the data (see data/README for details):
$ cd data
$ python create_enwik8.py
Then you can train by running:
$ th torch_rhn_enwik8.lua
An RHN layer implementation is also provided in Brainstorm.
This implementation does not use variational dropout.
It can be used in a Brainstorm experiment by importing HighwayRNNCoupledGates from brainstorm_rhn.py.
If you use RHNs in your work, please cite us:
@article{zilly2016recurrent,
title="{Recurrent Highway Networks}",
author={Zilly, Julian Georg and Srivastava, Rupesh Kumar and Koutn{\'\i}k, Jan and Schmidhuber, J{\"u}rgen},
journal={arXiv preprint arXiv:1607.03474},
year={2016}
}
MIT License.