Stability-AI / stable-audio-tools
- Saturday, June 8, 2024, at 00:00:02
Generative models for conditional audio generation
Training and inference code for audio generation models
The library can be installed from PyPI with:
$ pip install stable-audio-tools
To run the training scripts or inference code, you'll want to clone this repository, navigate to the root, and run:
$ pip install .
Requires PyTorch 2.0 or later for Flash Attention support.
Development for the repo is done in Python 3.8.10
A basic Gradio interface is provided to test out trained models.
For example, to create an interface for the stable-audio-open-1.0 model, once you've accepted the terms for the model on Hugging Face, you can run:
$ python3 ./run_gradio.py --pretrained-name stabilityai/stable-audio-open-1.0
The run_gradio.py script accepts the following command line arguments:
--pretrained-name: the Hugging Face repository name for a pre-trained model. Prioritizes model.safetensors over model.ckpt in the repo. Used in place of --model-config and --ckpt-path when loading pre-trained model checkpoints from Hugging Face.
--model-config: the path to the model config file for a local model.
--ckpt-path: the path to an unwrapped model checkpoint file for a local model.
--pretransform-ckpt-path: the path to a pre-trained pretransform checkpoint, used to replace the model's pretransform.
--username and --password: used together to set a login for the Gradio interface.
--model-half: if set, casts the model weights to half precision.
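If you'd rather test a model from Python than through the Gradio interface, the sketch below follows the published stable-audio-open-1.0 example and assumes the get_pretrained_model and generate_diffusion_cond helpers from this library, plus a Hugging Face login with the model terms already accepted. Treat the sampler settings as illustrative rather than prescriptive.

```python
import torch
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the model and read its sample rate / sample size from the model config
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]
model = model.to(device)

# Text and timing conditioning for a single generation
conditioning = [{
    "prompt": "128 BPM tech house drum loop",
    "seconds_start": 0,
    "seconds_total": 30
}]

# Generate audio with the conditional diffusion sampler (settings are illustrative)
output = generate_diffusion_cond(
    model,
    steps=100,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=sample_size,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device=device
)

# Collapse the batch dimension, peak-normalize, and save as 16-bit WAV
output = rearrange(output, "b d n -> d (b n)")
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, sample_rate)
```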
Before starting your training run, you'll need a model config file, as well as a dataset config file. For more information about those, refer to the Configurations section below
The training code also requires a Weights & Biases account to log the training outputs and demos. Create an account and log in with:
$ wandb login
To start a training run, run the train.py script in the repo root with:
$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name harmonai_train
The --name parameter will set the project name for your Weights & Biases run.
stable-audio-tools uses PyTorch Lightning to facilitate multi-GPU and multi-node training.
When a model is being trained, it is wrapped in a "training wrapper", which is a pl.LightningModule that contains all of the relevant objects needed only for training. That includes things like discriminators for autoencoders, EMA copies of models, and all of the optimizer states.
The checkpoint files created during training include this training wrapper, which greatly increases the size of the checkpoint file.
unwrap_model.py in the repo root will take in a wrapped model checkpoint and save a new checkpoint file including only the model itself.
That can be run from the repo root with:
$ python3 ./unwrap_model.py --model-config /path/to/model/config --ckpt-path /path/to/wrapped/ckpt --name model_unwrap
Unwrapped model checkpoints are required for inference (including the Gradio interface), for using a model as a pretransform for another model, and for starting a fresh training run from a pre-trained checkpoint.
Fine-tuning a model involves continuing a training run from a pre-trained checkpoint.
To continue a training run from a wrapped model checkpoint, you can pass in the checkpoint path to train.py with the --ckpt-path flag.
To start a fresh training run using a pre-trained unwrapped model, you can pass in the unwrapped checkpoint to train.py with the --pretrained-ckpt-path flag.
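For example (with hypothetical paths), the two routes differ only in which flag the checkpoint is passed to:
$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name harmonai_finetune --ckpt-path /path/to/wrapped/checkpoint.ckpt
$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name harmonai_finetune --pretrained-ckpt-path /path/to/unwrapped/checkpoint.ckpt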
Additional optional flags for train.py include:
--config-file: the path to the defaults.ini file in the repo root; required if running train.py from a directory other than the repo root.
--pretransform-ckpt-path: the path to an unwrapped pretransform checkpoint, used by model types such as latent diffusion models to load a pre-trained autoencoder.
--save-dir: the directory in which to save model checkpoints.
--checkpoint-every: the number of training steps between saved checkpoints.
--batch-size: the number of samples per GPU per batch during training.
--num-gpus: the number of GPUs per node to use for training.
--num-nodes: the number of GPU nodes to use for training.
--accum-batches: enables gradient accumulation and sets the number of batches to accumulate, for a larger effective batch size.
--strategy: the multi-GPU training strategy. Setting this to deepspeed will enable DeepSpeed ZeRO Stage 2. Defaults to ddp if --num-gpus > 1, else None.
--precision: the floating-point precision to use during training.
--num-workers: the number of CPU workers used by the data loader.
--seed: the random seed for training.
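As an illustration (paths and values are hypothetical), a multi-GPU run combining several of these flags might look like:
$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name harmonai_train --save-dir /path/to/checkpoints --checkpoint-every 10000 --batch-size 8 --num-gpus 4 --strategy deepspeed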
Training and inference code for stable-audio-tools is based around JSON configuration files that define model hyperparameters, training settings, and information about your training dataset.
The model config file defines all of the information needed to load a model for training or inference. It also contains the training configuration needed to fine-tune a model or train from scratch.
The following properties are defined in the top level of the model configuration:
model_type: the type of model being defined, currently one of "autoencoder", "diffusion_uncond", "diffusion_cond", "diffusion_cond_inpaint", "diffusion_autoencoder", or "lm".
sample_size: the length, in samples, of the audio provided to the model during training.
sample_rate: the sample rate, in Hz, of the audio used for training and generated during inference.
audio_channels: the number of audio channels used for training and inference.
model: the model-specific configuration, which varies based on model_type.
training: the training configuration for the model, which also varies based on model_type. Provides parameters for training as well as demos.
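Putting these together, a minimal top-level skeleton looks roughly like the sketch below; the values are illustrative, and the model and training blocks are left empty here because their contents depend on the model_type.

```json
{
    "model_type": "diffusion_cond",
    "sample_size": 65536,
    "sample_rate": 44100,
    "audio_channels": 2,
    "model": { },
    "training": { }
}
```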
For the dataset config, stable-audio-tools currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in the dataset config documentation.
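As a sketch, a local-directory dataset config might look roughly like the following; the key names follow the dataset config documentation, and the id and path values are hypothetical.

```json
{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_audio",
            "path": "/path/to/audio/dataset/"
        }
    ],
    "random_crop": true
}
```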