Commit 3cca73f

Merge remote-tracking branch 'upstream/main' into patch-1
2 parents cad30cc + 3635b23

37 files changed: +2182 −1351 lines

README.md

Lines changed: 2 additions & 2 deletions

@@ -14,9 +14,9 @@ $ pip install .
 ```

 # Requirements
-Requires PyTorch 2.0 or later for Flash Attention support
+Requires PyTorch 2.5 or later for Flash Attention and Flex Attention support

-Development for the repo is done in Python 3.8.10
+Development for the repo is done in Python 3.10

 # Interface

docs/datasets.md

Lines changed: 25 additions & 0 deletions

@@ -42,6 +42,31 @@ To load audio files and related metadata from .tar files in the WebDataset forma
 }
 ```

+## Pre Encoded Datasets
+To use pre encoded latents created with the [pre encoding script](pre_encoding.md), set the `dataset_type` property to `"pre_encoded"`, and provide the path to the directory containing the pre encoded `.npy` latent files and corresponding `.json` metadata files.
+
+You can optionally specify a `latent_crop_length`, in latent units (latent length = `audio_samples // 2048`), to crop the pre encoded latents to a length shorter than the one they were encoded at. If not specified, the full pre encoded length is used. When `random_crop` is set to true, a random crop of `latent_crop_length` latents is taken from the sequence, taking padding into account.
+
+**Note**: `random_crop` does not currently update `seconds_start`, so it will be inaccurate when used to train or fine-tune models with that condition (e.g. `stable-audio-open-1.0`), but it can be used with models that do not use `seconds_start` (e.g. `stable-audio-open-small`).
+
+### Example config
+```json
+{
+    "dataset_type": "pre_encoded",
+    "datasets": [
+        {
+            "id": "my_pre_encoded_audio",
+            "path": "/path/to/pre_encoded/output/",
+            "latent_crop_length": 512,
+            "custom_metadata_module": "/path/to/custom_metadata.py"
+        }
+    ],
+    "random_crop": true
+}
+```
+
+For information on creating pre encoded datasets, see [Pre Encoding](pre_encoding.md).
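The latent-unit arithmetic above can be sanity-checked in a couple of lines; the 2048× downsampling factor comes from the formula above, and the 44.1 kHz sample rate is an illustrative assumption:

```python
# Convert between audio samples and latent units for a 2048x-downsampling
# autoencoder, as described above. 44100 Hz is an illustrative sample rate.
def samples_to_latents(audio_samples: int, downsample: int = 2048) -> int:
    return audio_samples // downsample

def seconds_to_latents(seconds: float, sample_rate: int = 44100, downsample: int = 2048) -> int:
    return int(seconds * sample_rate) // downsample

# The default pre-encoding sample size of 1320960 samples (~30 s):
print(samples_to_latents(1320960))   # 645 latents
# A latent_crop_length of 512 corresponds to roughly 23.8 seconds:
print(seconds_to_latents(23.8))      # 512 latents
```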
 # Custom metadata
 To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info` and `audio`, and returns a dictionary.
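A minimal custom metadata module satisfying this contract might look like the following; the `relpath` key and `prompt` output are illustrative assumptions, not required fields:

```python
# custom_metadata.py - minimal example of the get_custom_metadata contract
# described above: takes `info` (a metadata dict) and `audio`, returns a dict.
def get_custom_metadata(info, audio):
    # Derive a text prompt from the file's relative path; "relpath" and
    # "prompt" are illustrative names, not a guaranteed schema.
    return {"prompt": info.get("relpath", "unknown audio")}
```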

docs/diffusion.md

Lines changed: 4 additions & 0 deletions

@@ -61,6 +61,10 @@ The `training` config in the diffusion model config file should have the followi
   - Optional, overrides `learning_rate`
 - `demo`
   - Configuration for the demos during training, including conditioning information
+- `pre_encoded`
+  - If true, indicates that the model should operate on [pre encoded latents](pre_encoding.md) instead of raw audio
+  - Required when training with [pre encoded datasets](datasets.md#pre-encoded-datasets)
+  - Optional. Default: `false`
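Conceptually, this flag just switches where the training wrapper gets its latents. A hypothetical sketch (names are illustrative; this is not the actual `stable_audio_tools` implementation):

```python
def get_training_latents(batch, pre_encoded, encode):
    """Hypothetical sketch: with pre_encoded=True the batch already holds
    latents loaded from disk; otherwise the frozen autoencoder encodes
    the raw audio on the fly."""
    if pre_encoded:
        return batch          # batch is already latents
    return encode(batch)      # frozen autoencoder encode step
```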

 ## Example config
 ```json

docs/pre_encoding.md

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
# Pre Encoding

When training models on latents from a pre-trained autoencoder, the encoder is typically frozen. Because of that, it is common to pre-encode audio to latents and store them on disk instead of computing them on the fly during training. This can improve training throughput and free up GPU memory that would otherwise be used for encoding.

## Prerequisites

To pre-encode audio to latents, you'll need a dataset config file, an autoencoder model config file, and an **unwrapped** autoencoder checkpoint file.

**Note:** You can find a copy of the unwrapped VAE checkpoint (`vae_model.ckpt`) and config (`vae_config.json`) in the `stabilityai/stable-audio-open-1.0` Hugging Face [repo](https://huggingface.co/stabilityai/stable-audio-open-1.0). This is the same VAE used in `stable-audio-open-small`.

## Run the Pre Encoding Script

To pre-encode latents from an autoencoder model, you can use `pre_encode.py`. This script loads a pre-trained autoencoder, encodes the dataset to latents (or tokens), and saves them to disk in a format that can be easily loaded during training.

The `pre_encode.py` script accepts the following command line arguments:

- `--model-config`
  - Path to model config
- `--ckpt-path`
  - Path to **unwrapped** autoencoder model checkpoint
- `--model-half`
  - If true, uses half precision for model weights
  - Optional
- `--dataset-config`
  - Path to dataset config file
  - Required
- `--output-path`
  - Path to output folder
  - Required
- `--batch-size`
  - Batch size for processing
  - Optional, defaults to 1
- `--sample-size`
  - Number of audio samples to pad/crop to for pre-encoding
  - Optional, defaults to 1320960 (~30 seconds)
- `--is-discrete`
  - If true, treats the model as discrete, saving discrete tokens instead of continuous latents
  - Optional
- `--num-nodes`
  - Number of nodes to use for distributed processing, if available
  - Optional, defaults to 1
- `--num-workers`
  - Number of dataloader workers
  - Optional, defaults to 4
- `--strategy`
  - PyTorch Lightning strategy
  - Optional, defaults to `'auto'`
- `--limit-batches`
  - Limits the number of batches processed
  - Optional
- `--shuffle`
  - If true, shuffles the dataset
  - Optional

**Note:** When pre encoding, it's recommended to set `"drop_last": false` in your dataset config to ensure the last batch is processed even if it's not full.

For example, to encode latents padded up to 30 seconds in half precision, you could run:

```bash
$ python3 ./pre_encode.py \
    --model-config /path/to/model/config.json \
    --ckpt-path /path/to/autoencoder/model.ckpt \
    --model-half \
    --dataset-config /path/to/dataset/config.json \
    --output-path /path/to/output/dir \
    --sample-size 1320960
```

When you run the above, the `--output-path` directory will contain numbered subdirectories for each GPU process used to encode the latents, along with a `details.json` file that records the settings used when the script was run.

Inside the numbered subdirectories, you will find the encoded latents as `.npy` files, along with associated `.json` metadata files.

```bash
/path/to/output/dir/
├── 0
│   ├── 0000000000000.json
│   ├── 0000000000000.npy
│   ├── 0000000000001.json
│   ├── 0000000000001.npy
│   ├── 0000000000002.json
│   ├── 0000000000002.npy
...
└── details.json
```
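A saved latent/metadata pair can be read back with NumPy and the standard library. A minimal sketch, assuming the directory layout above (the `padding_mask` key is written by `pre_encode.py`; the latent shape shown is illustrative):

```python
import json
from pathlib import Path

import numpy as np

def load_pre_encoded(stem: Path):
    """Load one latent/metadata pair produced by pre_encode.py,
    e.g. stem = Path('/path/to/output/dir/0/0000000000000')."""
    latent = np.load(stem.with_suffix(".npy"))  # (latent_channels, latent_length)
    metadata = json.loads(stem.with_suffix(".json").read_text())
    return latent, metadata
```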
## Training on Pre Encoded Latents

Once you have saved your latents to disk, you can use them to train a model by providing a dataset config file to `train.py` that points to the pre-encoded latents, with `"dataset_type"` set to `"pre_encoded"`. Under the hood, this will configure a `stable_audio_tools.data.dataset.PreEncodedDataset`. For more information on configuring pre encoded datasets, see the [Pre Encoded Datasets](datasets.md#pre-encoded-datasets) section of the datasets docs.

The dataset config file should look something like this:

```json
{
    "dataset_type": "pre_encoded",
    "datasets": [
        {
            "id": "my_audio",
            "path": "/path/to/output/dir"
        }
    ],
    "random_crop": false
}
```

In your diffusion model config, you'll also need to set `"pre_encoded": true` in the [`training` section](diffusion.md#training-configs) to tell the training wrapper to operate on pre encoded latents instead of audio.

```json
"training": {
    "pre_encoded": true,
    ...
}
```

pre_encode.py

Lines changed: 189 additions & 0 deletions

@@ -0,0 +1,189 @@
import argparse
import gc
import json
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from torch.nn import functional as F

from stable_audio_tools.data.dataset import create_dataloader_from_config
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.models.pretrained import get_pretrained_model
from stable_audio_tools.models.utils import load_ckpt_state_dict, copy_state_dict


def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, model_half=False):
    if pretrained_name is not None:
        print(f"Loading pretrained model {pretrained_name}")
        model, model_config = get_pretrained_model(pretrained_name)

    elif model_config is not None and model_ckpt_path is not None:
        print("Creating model from config")
        model = create_model_from_config(model_config)

        print(f"Loading model checkpoint from {model_ckpt_path}")
        copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))

    model.eval().requires_grad_(False)

    if model_half:
        model.to(torch.float16)

    print("Done loading model")

    return model, model_config


class PreEncodedLatentsInferenceWrapper(pl.LightningModule):
    def __init__(
        self,
        model,
        output_path,
        is_discrete=False,
        model_half=False,
        model_config=None,
        dataset_config=None,
        sample_size=1920000,
        args_dict=None
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.model = model
        self.output_path = Path(output_path)

    def prepare_data(self):
        # Runs on rank 0 only
        self.output_path.mkdir(parents=True, exist_ok=True)
        details_path = self.output_path / "details.json"
        if not details_path.exists():  # Only save if it doesn't exist
            details = {
                "model_config": self.hparams.model_config,
                "dataset_config": self.hparams.dataset_config,
                "sample_size": self.hparams.sample_size,
                "args": self.hparams.args_dict
            }
            details_path.write_text(json.dumps(details))

    def setup(self, stage=None):
        # Runs on each device
        process_dir = self.output_path / str(self.global_rank)
        process_dir.mkdir(parents=True, exist_ok=True)

    def validation_step(self, batch, batch_idx):
        audio, metadata = batch

        if audio.ndim == 4 and audio.shape[0] == 1:
            audio = audio[0]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        if self.hparams.model_half:
            audio = audio.to(torch.float16)

        with torch.no_grad():
            if not self.hparams.is_discrete:
                latents = self.model.encode(audio)
            else:
                _, info = self.model.encode(audio, return_info=True)
                latents = info[self.model.bottleneck.tokens_id]

        latents = latents.cpu().numpy()

        # Save each sample in the batch
        for i, latent in enumerate(latents):
            latent_id = f"{self.global_rank:03d}{batch_idx:06d}{i:04d}"

            # Save latent as numpy file
            latent_path = self.output_path / str(self.global_rank) / f"{latent_id}.npy"
            with open(latent_path, "wb") as f:
                np.save(f, latent)

            md = metadata[i]
            # Downsample the sample-level padding mask to latent resolution
            padding_mask = F.interpolate(
                md["padding_mask"].unsqueeze(0).unsqueeze(1).float(),
                size=latent.shape[1],
                mode="nearest"
            ).squeeze().int()
            md["padding_mask"] = padding_mask.cpu().numpy().tolist()

            # Convert remaining tensors in md to serializable types
            for k, v in md.items():
                if isinstance(v, torch.Tensor):
                    md[k] = v.cpu().numpy().tolist()

            # Save metadata to json file
            metadata_path = self.output_path / str(self.global_rank) / f"{latent_id}.json"
            with open(metadata_path, "w") as f:
                json.dump(md, f)

    def configure_optimizers(self):
        return None


def main(args):
    with open(args.model_config) as f:
        model_config = json.load(f)

    with open(args.dataset_config) as f:
        dataset_config = json.load(f)

    model, model_config = load_model(
        model_config=model_config,
        model_ckpt_path=args.ckpt_path,
        model_half=args.model_half
    )

    data_loader = create_dataloader_from_config(
        dataset_config,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        sample_rate=model_config["sample_rate"],
        sample_size=args.sample_size,
        audio_channels=model_config.get("audio_channels", 2),
        shuffle=args.shuffle
    )

    pl_module = PreEncodedLatentsInferenceWrapper(
        model=model,
        output_path=args.output_path,
        is_discrete=args.is_discrete,
        model_half=args.model_half,
        model_config=args.model_config,
        dataset_config=args.dataset_config,
        sample_size=args.sample_size,
        args_dict=vars(args)
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        num_nodes=args.num_nodes,
        strategy=args.strategy,
        precision="16-true" if args.model_half else "32",
        max_steps=args.limit_batches if args.limit_batches else -1,
        logger=False,  # Disable logging since we're just doing inference
        enable_checkpointing=False,
    )
    trainer.validate(pl_module, data_loader)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Encode audio dataset to VAE latents using PyTorch Lightning')
    parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
    parser.add_argument('--ckpt-path', type=str, help='Path to unwrapped autoencoder model checkpoint', required=False)
    parser.add_argument('--model-half', action='store_true', help='Whether to use half precision')
    parser.add_argument('--dataset-config', type=str, help='Path to dataset config file', required=True)
    parser.add_argument('--output-path', type=str, help='Path to output folder', required=True)
    parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
    parser.add_argument('--sample-size', type=int, help='Number of audio samples to pad/crop to', default=1320960)
    parser.add_argument('--is-discrete', action='store_true', help='Whether the model is discrete')
    parser.add_argument('--num-nodes', type=int, help='Number of GPU nodes', default=1)
    parser.add_argument('--num-workers', type=int, help='Number of dataloader workers', default=4)
    parser.add_argument('--strategy', type=str, help='PyTorch Lightning strategy', default='auto')
    parser.add_argument('--limit-batches', type=int, help='Limit number of batches (optional)', default=None)
    parser.add_argument('--shuffle', action='store_true', help='Shuffle dataset')
    args = parser.parse_args()
    main(args)
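The 13-digit file names produced above come from the `latent_id` format string in `validation_step` (`{global_rank:03d}{batch_idx:06d}{i:04d}`). A small helper to build and parse these ids, added here for illustration:

```python
def make_latent_id(rank, batch_idx, i):
    # Mirrors the fixed-width format used in validation_step above:
    # 3 digits of rank, 6 of batch index, 4 of in-batch sample index.
    return f"{rank:03d}{batch_idx:06d}{i:04d}"

def parse_latent_id(latent_id):
    # Split the fixed-width fields back into (rank, batch_idx, sample index).
    return int(latent_id[:3]), int(latent_id[3:9]), int(latent_id[9:])
```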

run_gradio.py

Lines changed: 3 additions & 1 deletion

@@ -12,7 +12,8 @@ def main(args):
         ckpt_path=args.ckpt_path,
         pretrained_name=args.pretrained_name,
         pretransform_ckpt_path=args.pretransform_ckpt_path,
-        model_half=args.model_half
+        model_half=args.model_half,
+        gradio_title=args.title
     )
     interface.queue()
     interface.launch(share=args.share, auth=(args.username, args.password) if args.username is not None else None)
@@ -28,5 +29,6 @@ def main(args):
     parser.add_argument('--username', type=str, help='Gradio username', required=False)
     parser.add_argument('--password', type=str, help='Gradio password', required=False)
     parser.add_argument('--model-half', action='store_true', help='Whether to use half precision', required=False, default=True)
+    parser.add_argument('--title', type=str, help='Title to display at the top of the Gradio interface', required=False)
     args = parser.parse_args()
     main(args)
