victoresque / pytorch-template
- Monday, March 26, 2018 at 09:06:58
Python
PyTorch template project
A simple template project using PyTorch which can be modified to fit many deep learning projects.
The code in this repo is an MNIST example of the template. To try it, run:
python main.py
The default arguments list is shown below:
usage: main.py [-h] [-b BATCH_SIZE] [-e EPOCHS] [--resume RESUME]
               [--verbosity VERBOSITY] [--save-dir SAVE_DIR]
               [--save-freq SAVE_FREQ] [--data-dir DATA_DIR]
               [--validation-split VALIDATION_SPLIT] [--no-cuda]
PyTorch Template
optional arguments:
  -h, --help    show this help message and exit
  -b BATCH_SIZE, --batch-size BATCH_SIZE
                        mini-batch size (default: 32)
  -e EPOCHS, --epochs EPOCHS
                        number of total epochs (default: 32)
  --resume RESUME
                        path to latest checkpoint (default: none)
  --verbosity VERBOSITY
                        verbosity, 0: quiet, 1: per epoch, 2: complete (default: 2)
  --save-dir SAVE_DIR
                        directory of saved model (default: model/saved)
  --save-freq SAVE_FREQ
                        training checkpoint frequency (default: 1)
  --data-dir DATA_DIR
                        directory of training/testing data (default: datasets)
  --validation-split VALIDATION_SPLIT
                        ratio of split validation data, [0.0, 1.0) (default: 0.0)
  --no-cuda   use CPU in case there's no GPU support
You can add your own arguments.
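As a minimal sketch of adding an argument, the snippet below rebuilds two of the options from the usage text with `argparse` and adds one more. The `--lr` flag and its default are hypothetical and not part of the template; the parser in main.py may be organized differently.

```python
import argparse

# Rebuild a small part of the parser shown in the usage text above.
parser = argparse.ArgumentParser(description='PyTorch Template')
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    help='mini-batch size (default: 32)')
parser.add_argument('-e', '--epochs', default=32, type=int,
                    help='number of total epochs (default: 32)')

# Hypothetical extra argument; "--lr" is not part of the template.
parser.add_argument('--lr', default=0.001, type=float,
                    help='learning rate (default: 0.001)')

# Parse an explicit list here so the sketch runs without a command line.
args = parser.parse_args(['--batch-size', '64', '--lr', '0.01'])
print(args.batch_size, args.epochs, args.lr)
```

Note that argparse converts dashes in option names to underscores, so `--batch-size` is read back as `args.batch_size`.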
├── base/ - abstract base classes
│   ├── base_data_loader.py - abstract base class for data loaders.
│   ├── base_model.py - abstract base class for models.
│   └── base_trainer.py - abstract base class for trainers
│
├── data_loader/ - anything about data loading goes here
│   └── data_loader.py
│
├── datasets/ - default dataset folder
│
├── logger/ - for training process logging
│   └── logger.py
│
├── model/ - models, losses, and metrics
│   ├── modules/ - submodules of your model
│   ├── saved/ - default checkpoint folder
│   ├── loss.py
│   ├── metric.py
│   └── model.py
│
├── trainer/ - trainers for your project
│   └── trainer.py
│
└── utils
     ├── utils.py
     └── ...
In most cases, you need to modify trainer/trainer.py to fit the training logic of your project.
To customize the data loader, modify data_loader/data_loader.py or add other files.
Implement your model under model/.
If you need to change the loss function or metrics, first import those functions in main.py, then modify this part:
loss = my_loss
metrics = [my_metric]
You'll see that the logging has changed during training:
⋯
Train Epoch: 1 [53920/53984 (100%)] Loss: 0.033256
{'epoch': 1, 'loss': 0.14182623870152963, 'my_metric': 0.9568761114404268, 'val_loss': 0.06394806604976841, 'val_my_metric': 0.9804478609625669}
Saving checkpoint: model/saved/Model_checkpoint_epoch01_loss_0.14183.pth.tar ...
Train Epoch: 2 [0/53984 (0%)] Loss: 0.013225
⋯
If you have multiple metrics in your project, just add them to the metrics list:
loss = my_loss
metrics = [my_metric, my_metric2]
Now the logging shows two metrics:
⋯
Train Epoch: 1 [53920/53984 (100%)] Loss: 0.003278
{'epoch': 1, 'loss': 0.13541310020907665, 'my_metric': 0.9590804682868999, 'my_metric2': 1.9181609365737997, 'val_loss': 0.05264156081223173, 'val_my_metric': 0.9837901069518716, 'val_my_metric2': 1.9675802139037433}
Saving checkpoint: model/saved/Model_checkpoint_epoch01_loss_0.13541.pth.tar ...
Train Epoch: 2 [0/53984 (0%)] Loss: 0.023072
⋯
Currently the name shown in the log is the name of the function.
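The naming behavior can be sketched as follows. This is a simplified illustration that uses plain Python lists instead of tensors, and the `evaluate` helper is hypothetical; the trainer in this template may collect metrics differently. In the sample log, `my_metric2` is twice `my_metric`, so the sketch defines it that way.

```python
def my_metric(outputs, targets):
    """Fraction of predictions equal to their target (plain lists for brevity)."""
    correct = sum(1 for o, t in zip(outputs, targets) if o == t)
    return correct / len(targets)


def my_metric2(outputs, targets):
    """Twice my_metric, mirroring the ratio visible in the sample log."""
    return 2 * my_metric(outputs, targets)


def evaluate(metrics, outputs, targets, prefix=''):
    """Hypothetical helper: build a log dict keyed by each function's __name__."""
    return {prefix + m.__name__: m(outputs, targets) for m in metrics}


metrics = [my_metric, my_metric2]
log = {'epoch': 1}
log.update(evaluate(metrics, [1, 0, 1, 1], [1, 0, 0, 1]))      # training data
log.update(evaluate(metrics, [1, 1], [1, 1], prefix='val_'))   # validation data
print(log)
```

Renaming a metric function is therefore enough to rename its entry in the log.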
If you have additional information to be logged, you can modify _train_epoch() in class Trainer. For example, say you have an additional log saved as a dictionary:
additional_log = {"x": x, "y": y}
Just merge it with log as shown below before returning:
log = {**log, **additional_log}
return log
If you have separate validation data, try implementing another data loader for validation. Otherwise, if you just want to split validation data from the training data, try passing --validation-split 0.1. In some cases you might need to modify utils/util.py.
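To illustrate what a validation split ratio means, here is a minimal sketch that divides sample indices into training and validation sets. The `split_indices` function, its `seed` parameter, and the sample count of 60000 are assumptions made for this example; the template's own splitting code may work differently.

```python
import random


def split_indices(n_samples, validation_split, seed=0):
    """Return (train_indices, valid_indices) for a ratio in [0.0, 1.0)."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError('validation_split must be in [0.0, 1.0)')
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # shuffle before splitting
    n_valid = int(n_samples * validation_split)
    # The first n_valid shuffled indices become the validation set.
    return indices[n_valid:], indices[:n_valid]


train_idx, valid_idx = split_indices(60000, 0.1)
print(len(train_idx), len(valid_idx))
```

With a ratio of 0.0 (the default), the validation set is empty and all samples are used for training.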
If you need to add a prefix to your checkpoints, modify this line in main.py:
identifier = type(model).__name__ + '_'
The prefix of the checkpoint files will change. If you need to further change the naming of checkpoints, try modifying _save_checkpoint() in class BaseTrainer.
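The sketch below shows how a changed identifier could carry through to the checkpoint file name. The format string is inferred from the file names in the sample log above; the `checkpoint_path` helper, the stand-in `Model` class, and the `mnist_` prefix are hypothetical, and the actual `_save_checkpoint()` may build the name differently.

```python
import os


def checkpoint_path(save_dir, identifier, epoch, loss):
    """Compose a checkpoint path in the format seen in the sample log."""
    filename = '{}checkpoint_epoch{:02d}_loss_{:.5f}.pth.tar'.format(
        identifier, epoch, loss)
    return os.path.join(save_dir, filename)


class Model:  # stand-in for the real model class
    pass


model = Model()
identifier = 'mnist_' + type(model).__name__ + '_'  # hypothetical prefix
path = checkpoint_path('model/saved', identifier, 1, 0.14182623870152963)
print(os.path.basename(path))
```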
Feel free to contribute any kind of function or enhancement. The coding style follows PEP8.
This project is heavily inspired by the project Tensorflow-Project-Template by Mahmoud Gemy. Be sure to star it!