diff --git a/.gitignore b/.gitignore
index 3e6aee68..482c1707 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,10 @@
 cython_debug/
 *.ckpt
 *.wav
-wandb/*
\ No newline at end of file
+wandb/*
+
+# Dataset folders and outputs
+/outputs
+/pre_encoded*
+/rawfiles*
+*.zip
diff --git a/dataset_config.json b/dataset_config.json
new file mode 100644
index 00000000..e402ec9e
--- /dev/null
+++ b/dataset_config.json
@@ -0,0 +1,8 @@
+{
+    "dataset_type": "pre_encoded",
+    "datasets": [{
+        "id": "audio_pre_encoded",
+        "path": "pre_encoded",
+        "custom_metadata_module": "paths_md.py"
+    }]
+}
diff --git a/paths_md.py b/paths_md.py
new file mode 100644
index 00000000..5ead4a86
--- /dev/null
+++ b/paths_md.py
@@ -0,0 +1,17 @@
+import os
+import re
+
+
+def get_custom_metadata(info, audio):
+    # Get filename without extension
+    file_name = os.path.basename(info["relpath"])
+    file_name_without_extension = os.path.splitext(file_name)[0]
+
+    # Replace non-alphanumeric characters with spaces, and remove leading/trailing spaces
+    cleaned_file_name = re.sub('[^0-9a-zA-Z]+', ' ', file_name_without_extension).strip()
+    #cleaned_file_name = re.match('', cleaned_file_name).groups()[0]
+
+    # Sanity check
+    print(f'{info["relpath"]} => {cleaned_file_name}')
+
+    return {"prompt": cleaned_file_name}
diff --git a/paths_md_pre_encode.py b/paths_md_pre_encode.py
new file mode 100644
index 00000000..5ead4a86
--- /dev/null
+++ b/paths_md_pre_encode.py
@@ -0,0 +1,17 @@
+import os
+import re
+
+
+def get_custom_metadata(info, audio):
+    # Get filename without extension
+    file_name = os.path.basename(info["relpath"])
+    file_name_without_extension = os.path.splitext(file_name)[0]
+
+    # Replace non-alphanumeric characters with spaces, and remove leading/trailing spaces
+    cleaned_file_name = re.sub('[^0-9a-zA-Z]+', ' ', file_name_without_extension).strip()
+    #cleaned_file_name = re.match('', cleaned_file_name).groups()[0]
+
+    # Sanity check
+    print(f'{info["relpath"]} => {cleaned_file_name}')
+
+    return {"prompt": cleaned_file_name}
diff --git a/pe_dataset_config.json b/pe_dataset_config.json
new file mode 100644
index 00000000..9e6194ae
--- /dev/null
+++ b/pe_dataset_config.json
@@ -0,0 +1,11 @@
+{
+    "dataset_type": "audio_dir",
+    "datasets": [{
+        "id": "audio",
+        "path": "./rawfiles",
+        "custom_metadata_module": "./paths_md_pre_encode.py",
+        "drop_last": false
+    }],
+    "drop_last": false,
+    "random_crop": false
+}
diff --git a/pre-encode.bat b/pre-encode.bat
new file mode 100644
index 00000000..123de59e
--- /dev/null
+++ b/pre-encode.bat
@@ -0,0 +1,9 @@
+python ./pre_encode.py ^
+    --ckpt-path ./vae_model.ckpt ^
+    --model-config ./vae_model_config.json ^
+    --batch-size 8 ^
+    --dataset-config pe_dataset_config.json ^
+    --output-path ./pre_encoded ^
+    --model-half ^
+    --sample-size 131072
+
diff --git a/setup.py b/setup.py
index f96f3bc1..2a8b41d9 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,9 @@
     install_requires=[
         'alias-free-torch==0.0.6',
         'auraloss==0.4.0',
+        'bitsandbytes==0.47.0',
         'descript-audio-codec==1.0.0',
+        'dill==0.4.0',
         'einops',
         'einops-exts',
         'ema-pytorch==0.2.3',
@@ -23,7 +25,8 @@
         'local-attention==1.8.6',
         'pandas==2.0.2',
         'prefigure==0.0.9',
-        'pytorch_lightning==2.1.0',
+        'pytorch_lightning==2.4.0',
+        'pytorch_optimizer==3.1.2',
         'PyWavelets==1.4.1',
         'safetensors',
         'sentencepiece==0.1.99',
diff --git a/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_small_adamw8bit_cawr_base_model_config.json b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_small_adamw8bit_cawr_base_model_config.json
new file mode 100644
index 00000000..0c4abf5d
--- /dev/null
+++ b/stable_audio_tools/configs/model_configs/txt2audio/stable_audio_small_adamw8bit_cawr_base_model_config.json
@@ -0,0 +1,132 @@
+{
+    "model_type": "diffusion_cond",
+    "sample_size": 524288,
+    "sample_rate": 44100,
+    "audio_channels": 2,
+    "model": {
+        "pretransform": {
+            "type": "autoencoder",
+            "iterate_batch": true,
+            "model_half": true,
+            "chunked": true,
+            "config": {
+                "encoder": {
+                    "type": "oobleck",
+                    "requires_grad": false,
+                    "config": {
+                        "in_channels": 2,
+                        "channels": 128,
+                        "c_mults": [1, 2, 4, 8, 16],
+                        "strides": [2, 4, 4, 8, 8],
+                        "latent_dim": 128,
+                        "use_snake": true
+                    }
+                },
+                "decoder": {
+                    "type": "oobleck",
+                    "config": {
+                        "out_channels": 2,
+                        "channels": 128,
+                        "c_mults": [1, 2, 4, 8, 16],
+                        "strides": [2, 4, 4, 8, 8],
+                        "latent_dim": 64,
+                        "use_snake": true,
+                        "final_tanh": false
+                    }
+                },
+                "bottleneck": {
+                    "type": "vae"
+                },
+                "latent_dim": 64,
+                "downsampling_ratio": 2048,
+                "io_channels": 2
+            }
+        },
+        "conditioning": {
+            "configs": [
+                {
+                    "id": "prompt",
+                    "type": "t5",
+                    "config": {
+                        "t5_model_name": "google/t5gemma-b-b-ul2",
+                        "max_length": 128
+                    }
+                },
+                {
+                    "id": "seconds_total",
+                    "type": "number",
+                    "config": {
+                        "min_val": 0,
+                        "max_val": 256
+                    }
+                }
+            ],
+            "cond_dim": 768
+        },
+        "diffusion": {
+            "cross_attention_cond_ids": ["prompt", "seconds_total"],
+            "global_cond_ids": ["seconds_total"],
+            "diffusion_objective": "rectified_flow",
+            "distribution_shift_options": {
+                "min_length": 256,
+                "max_length": 4096
+            },
+            "type": "dit",
+            "config": {
+                "io_channels": 64,
+                "embed_dim": 1024,
+                "depth": 16,
+                "num_heads": 8,
+                "cond_token_dim": 768,
+                "global_cond_dim": 768,
+                "transformer_type": "continuous_transformer",
+                "attn_kwargs": {
+                    "qk_norm": "ln"
+                }
+            }
+        },
+        "io_channels": 64
+    },
+    "training": {
+        "use_ema": true,
+        "log_loss_info": false,
+        "pre_encoded": true,
+        "timestep_sampler": "trunc_logit_normal",
+        "optimizer_configs": {
+            "diffusion": {
+                "optimizer": {
+                    "type": "AdamW8bit",
+                    "config": {
+                        "lr": 1e-5,
+                        "betas": [0.9, 0.999],
+                        "eps": 1e-8,
+                        "weight_decay": 1e-2,
+                        "block_wise": true
+                    }
+                },
+                "scheduler": {
+                    "type": "CosineAnnealingWarmRestarts",
+                    "config": {
+                        "T_0": 10,
+                        "T_mult": 2
+                    }
+                }
+            }
+        },
+        "demo": {
+            "demo_every": 512,
+            "demo_steps": 100,
+            "num_demos": 7,
+            "demo_cond": [
+                {"prompt": "kick", "seconds_total": 2},
+                {"prompt": "bass", "seconds_total": 2},
+                {"prompt": "drum breaks 174 BPM", "seconds_total": 6},
+                {"prompt": "A short, beautiful piano riff in C minor", "seconds_total": 6},
+                {"prompt": "Tight Snare Drum", "seconds_total": 1},
+                {"prompt": "Glitchy bass design, I used Serum for this", "seconds_total": 4},
+                {"prompt": "Synth pluck arp with reverb and delay, 128 BPM", "seconds_total": 6}
+            ],
+            "demo_cfg_scales": [0.5, 1, 1.5, 8]
+        }
+    }
+}
diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py
index 7543ac17..7e02b6d6 100644
--- a/stable_audio_tools/data/dataset.py
+++ b/stable_audio_tools/data/dataset.py
@@ -1,3 +1,4 @@
+import dill
 import importlib
 import numpy as np
 import io
@@ -19,7 +20,10 @@
 
 from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm
 
-AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+
+AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus", "aiff", "aif")
 
 # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
 
@@ -94,7 +98,7 @@ def keyword_scandir(
 def get_audio_filenames(
     paths: list,  # directories in which to search
     keywords=None,
-    exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
+    exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus', '.aiff']
 ):
     "recursively get a list of audio filenames"
     filenames = []
@@ -178,7 +182,7 @@
             self.root_paths.append(config.path)
             self.filenames.extend(get_audio_filenames(config.path, keywords))
             if config.custom_metadata_fn is not None:
-                self.custom_metadata_fns[config.path] = config.custom_metadata_fn
+                self.custom_metadata_fns[config.path] = dill.dumps(config.custom_metadata_fn)
 
         print(f'Found {len(self.filenames)} files')
 
@@ -238,8 +242,8 @@ def __getitem__(self, idx):
             for custom_md_path in self.custom_metadata_fns.keys():
                 if custom_md_path in audio_filename:
-                    custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
-                    custom_metadata = custom_metadata_fn(info, audio)
+                    custom_metadata_fn_deserialized = dill.loads(self.custom_metadata_fns[custom_md_path])
+                    custom_metadata = custom_metadata_fn_deserialized(info, audio)
                     info.update(custom_metadata)
 
             if "__reject__" in info and info["__reject__"]:
 
@@ -282,7 +286,7 @@ def __init__(
         for config in configs:
             self.filenames.extend(get_latent_filenames(config.path, [latent_extension]))
             if config.custom_metadata_fn is not None:
-                self.custom_metadata_fns[config.path] = config.custom_metadata_fn
+                self.custom_metadata_fns[config.path] = dill.dumps(config.custom_metadata_fn)
 
         self.latent_crop_length = latent_crop_length
         self.random_crop = random_crop
@@ -339,8 +343,9 @@ def __getitem__(self, idx):
             for custom_md_path in self.custom_metadata_fns.keys():
                 if custom_md_path in latent_filename:
-                    custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
-                    custom_metadata = custom_metadata_fn(info, None)
+
+                    custom_metadata_fn_deserialized = dill.loads(self.custom_metadata_fns[custom_md_path])
+                    custom_metadata = custom_metadata_fn_deserialized(info, None)
                     info.update(custom_metadata)
 
             if "__reject__" in info and info["__reject__"]:
 
@@ -849,8 +854,14 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl
            force_channels=force_channels
        )
 
-        return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle,
-                                num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn)
+        # https://docs.pytorch.org/docs/stable/notes/randomness.html#dataloader
+        g = torch.Generator()
+        g.manual_seed(0)
+
+        #return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle,
+        #                        num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g)
+        return StatefulDataLoader(train_set, batch_size, shuffle=shuffle,
+                                num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g)
 
    elif dataset_type == "pre_encoded":
 
@@ -899,8 +910,14 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl
            latent_extension=latent_extension
        )
 
-        return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle,
-                                num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn)
+        # https://docs.pytorch.org/docs/stable/notes/randomness.html#dataloader
+        g = torch.Generator()
+        g.manual_seed(0)
+
+        #return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle,
+        #                        num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g)
+        return StatefulDataLoader(train_set, batch_size, shuffle=shuffle,
+                                num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g)
 
    elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
        wds_configs = []
diff --git a/stable_audio_tools/interface/aeiou.py b/stable_audio_tools/interface/aeiou.py
index 59d94b6e..0fb7e626 100644
--- a/stable_audio_tools/interface/aeiou.py
+++ b/stable_audio_tools/interface/aeiou.py
@@ -15,6 +15,7 @@
 
 import numpy as np
 
+
 def embeddings_table(tokens):
     from wandb import Table
     from pandas import DataFrame
@@ -33,6 +34,7 @@ def embeddings_table(tokens):
     df['LABEL'] = labels
     return Table(columns=df.columns.to_list(), data=df.values)
 
+
 def project_down(tokens,   # batched high-dimensional data with dims (b,d,n)
     proj_dims=3,           # dimensions to project to
     method='pca',          # projection method: 'pca'|'umap'
@@ -56,6 +58,7 @@ def project_down(tokens,   # batched high-dimensional data with dims (b,d,n)
     else:
         proj_data = A
     if debug: print("proj_data.shape =",proj_data.shape)
+
     return torch.reshape(proj_data, (tokens.size()[0], -1, proj_dims))  # put it in shape [batch, n, proj_dims]
 
 
@@ -148,7 +151,8 @@
     else:
         from wandb import Object3D
         return Object3D(point_cloud)
-    
+
+
 def pca_point_cloud(
     tokens,                # embeddings / latent vectors. shape = (b, d, n)
     color_scheme='batch',  # 'batch': group by sample, otherwise color sequentially
@@ -161,6 +165,7 @@
     return point_cloud(tokens, method='pca', color_scheme=color_scheme, output_type=output_type, mode=mode, size=size, line=line, **kwargs)
 
+
 def power_to_db(spec, *, amin = 1e-10):
     magnitude = np.asarray(spec)
 
 
@@ -171,17 +176,32 @@
 
     return log_spec
 
+
 def mel_spectrogram(waveform, power=2.0, sample_rate=48000, db=False, n_fft=1024, n_mels=128, debug=False):
     "calculates data array for mel spectrogram (in however many channels)"
     win_length = None
+    # https://docs.pytorch.org/audio/2.8.0/generated/torchaudio.transforms.MelSpectrogram.html?highlight=melspectrogram#torchaudio.transforms.MelSpectrogram
+    # n_fft (int, optional) – Size of FFT, creates n_fft // 2 + 1 bins.
+    # Ergo n_freqs = n_fft // 2 + 1 = 513, not 512, when n_fft=1024, as seen in this error:
+    #
+    # torchaudio\functional\functional.py:585: UserWarning: At least one mel filterbank has all zero values.
+    # The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (513) may be set too low.
     hop_length = n_fft//2  # 512
     mel_spectrogram_op = T.MelSpectrogram(
-        sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
-        hop_length=hop_length, center=True, pad_mode="reflect", power=power,
-        norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk")
+        sample_rate=sample_rate,
+        n_fft=n_fft,
+        win_length=win_length,
+        hop_length=hop_length,
+        center=True,
+        pad_mode="reflect",
+        power=power,
+        norm='slaney',
+        n_mels=n_mels,
+        mel_scale="htk")
     melspec = mel_spectrogram_op(waveform.float())
+
     if db:
         amp_to_db_op = T.AmplitudeToDB()
         melspec = amp_to_db_op(melspec)
 
@@ -189,8 +209,10 @@
         print_stats(melspec, print=print)
         print(f"torch.max(melspec) = {torch.max(melspec)}")
         print(f"melspec.shape = {melspec.shape}")
+
     return melspec
 
+
 def spectrogram_image(
     spec,
     title=None,
@@ -227,12 +249,14 @@
     #print(f"im.size = {im.size}")
     return im
 
+
 def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000, print=print, db=False, db_range=[35,120], justimage=False, log=False, figsize=(5, 4)):
     "Wrapper for calling above two routines at once, does Mel scale; Modified from PyTorch tutorial https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html"
     melspec = mel_spectrogram(waveform, power=power, db=db, sample_rate=sample_rate, debug=log)
     melspec = melspec[0]  # TODO: only left channel for now
     return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)', db_range=db_range, justimage=justimage, figsize=figsize)
 
+
 from matplotlib.ticker import AutoLocator
 def tokens_spectrogram_image(
     tokens,              # the embeddings themselves (in some diffusion codes these are called 'tokens')
@@ -241,9 +265,9 @@
     ylabel='index',      # label for y axis of plot
     cmap='coolwarm',     # colormap to use. (default used to be 'viridis')
     symmetric=True,      # make color scale symmetric about zero, i.e. +/- same extremes
-    figsize=(8, 4), # matplotlib size of the figure
+    figsize=(8, 4),      # matplotlib size of the figure
     dpi=100,             # dpi of figure
-    mark_batches=False, # separate batches with dividing lines
+    mark_batches=False,  # separate batches with dividing lines
     debug=False,         # print debugging info
     ):
     "for visualizing embeddings in a spectrogram-like way"
diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py
index 64f6e3a8..48d59da2 100644
--- a/stable_audio_tools/models/conditioners.py
+++ b/stable_audio_tools/models/conditioners.py
@@ -287,7 +287,7 @@ class T5Conditioner(Conditioner):
 
     T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
               "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
-              "google/flan-t5-xl", "google/flan-t5-xxl", "google/t5-v1_1-xl", "google/t5-v1_1-xxl"]
+              "google/flan-t5-xl", "google/flan-t5-xxl", "google/t5-v1_1-xl", "google/t5-v1_1-xxl", "google/t5gemma-b-b-ul2"]
 
     T5_MODEL_DIMS = {
         "t5-small": 512,
@@ -304,6 +304,7 @@ class T5Conditioner(Conditioner):
         "google/flan-t5-11b": 1024,
         "google/flan-t5-xl": 2048,
         "google/flan-t5-xxl": 4096,
+        "google/t5gemma-b-b-ul2": 768
     }
 
     def __init__(
@@ -317,7 +318,7 @@
         assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
         super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
 
-        from transformers import T5EncoderModel, AutoTokenizer
+        from transformers import T5EncoderModel, T5GemmaEncoderModel, AutoTokenizer
 
         self.max_length = max_length
         self.enable_grad = enable_grad
@@ -331,7 +332,11 @@
            # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
            # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
            self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
-            model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
+            if 'gemma' in t5_model_name:
+                #T5GemmaEncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
+                model = T5GemmaEncoderModel.from_pretrained(t5_model_name, is_encoder_decoder=False, torch_dtype=torch.float16).train(enable_grad).requires_grad_(enable_grad)
+            else:
+                model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
        finally:
            logging.disable(previous_level)
 
@@ -359,7 +364,7 @@ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> t
 
        self.model.eval()
 
-        with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
+        with torch.amp.autocast('cuda', dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
            embeddings = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )["last_hidden_state"]
@@ -758,4 +763,4 @@ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.An
    else:
        raise ValueError(f"Unknown conditioner type: {conditioner_type}")
 
-    return MultiConditioner(conditioners, default_keys=default_keys, pre_encoded_keys=pre_encoded_keys)
\ No newline at end of file
+    return MultiConditioner(conditioners, default_keys=default_keys, pre_encoded_keys=pre_encoded_keys)
diff --git a/stable_audio_tools/training/diffusion.py b/stable_audio_tools/training/diffusion.py
index 4583b3e4..e2d943f5 100644
--- a/stable_audio_tools/training/diffusion.py
+++ b/stable_audio_tools/training/diffusion.py
@@ -10,6 +10,7 @@
 from einops import rearrange
 from safetensors.torch import save_file
 from torch import optim
+import bitsandbytes as bnb
 from torch.nn import functional as F
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
@@ -119,7 +120,7 @@ def training_step(self, batch, batch_idx):
        noised_inputs = diffusion_input * alphas + noise * sigmas
        targets = noise * alphas - diffusion_input * sigmas
 
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast('cuda'):
            v = self.diffusion(noised_inputs, t)
 
        loss_info.update({
@@ -184,7 +185,7 @@ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
        noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
 
        try:
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast('cuda'):
                fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0)
 
                if module.diffusion.pretransform is not None:
@@ -365,7 +366,7 @@ def training_step(self, batch, batch_idx):
                self.diffusion.pretransform.to(self.device)
 
                if not self.pre_encoded:
-                    with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
+                    with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
                        self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
                        diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
 
@@ -501,7 +502,7 @@ def validation_step(self, batch, batch_idx):
 
        diffusion_input = reals
 
-        with torch.cuda.amp.autocast() and torch.no_grad():
+        with torch.amp.autocast('cuda') and torch.no_grad():
            conditioning = self.diffusion.conditioner(metadata, self.device)
 
        # TODO: decide what to do with padding masks during validation
@@ -517,7 +518,7 @@ def validation_step(self, batch, batch_idx):
                self.diffusion.pretransform.to(self.device)
 
            if not self.pre_encoded:
-                with torch.cuda.amp.autocast() and torch.no_grad():
+                with torch.amp.autocast('cuda') and torch.no_grad():
                    self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
                    diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
 
@@ -556,7 +557,7 @@ def validation_step(self, batch, batch_idx):
        # if use_padding_mask:
        #     extra_args["mask"] = padding_masks
 
-        with torch.cuda.amp.autocast() and torch.no_grad():
+        with torch.amp.autocast('cuda') and torch.no_grad():
            output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0, **extra_args)
 
            val_loss = F.mse_loss(output, targets)
@@ -654,7 +655,7 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp
        try:
            print("Getting conditioning")
 
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast('cuda'):
                conditioning = module.diffusion.conditioner(demo_cond, module.device)
 
            cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
@@ -698,7 +699,7 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp
 
                print(f"Generating demo for cfg scale {cfg_scale}")
 
-                with torch.cuda.amp.autocast():
+                with torch.amp.autocast('cuda'):
                    model = module.diffusion_ema.ema_model if module.diffusion_ema is not None else module.diffusion.model
 
                    if module.diffusion_objective == "v":
@@ -879,7 +880,7 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp
                model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
 
                print(f"Generating demo for cfg scale {cfg_scale}")
-                with torch.cuda.amp.autocast():
+                with torch.amp.autocast('cuda'):
                    if module.diffusion_objective == "v":
                        fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, dist_shift=module.diffusion.dist_shift, batch_cfg=True)
                    elif module.diffusion_objective == "rectified_flow":
@@ -1034,7 +1035,7 @@ def training_step(self, batch, batch_idx):
        noised_reals = reals * alphas + noise * sigmas
        targets = noise * alphas - reals * sigmas
 
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast('cuda'):
            v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents)
 
        loss_info.update({
@@ -1114,7 +1115,7 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe
 
        demo_reals = demo_reals.to(module.device)
 
-        with torch.no_grad() and torch.cuda.amp.autocast():
+        with torch.no_grad() and torch.amp.autocast('cuda'):
            latents = module.diffae_ema.ema_model.encode(encoder_input).float()
            fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps)
@@ -1147,7 +1148,7 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe
                audio_spectrogram_image(reals_fakes))
 
        if module.diffae_ema.ema_model.pretransform is not None:
-            with torch.no_grad() and torch.cuda.amp.autocast():
+            with torch.no_grad() and torch.amp.autocast('cuda'):
                initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input)
                first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents)
                first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)')
@@ -1163,4 +1164,4 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe
            tokens_spectrogram_image(initial_latents))
        log_image(
            trainer.logger, "first_stage_melspec_left",
-            audio_spectrogram_image(first_stage_fakes))
\ No newline at end of file
+            audio_spectrogram_image(first_stage_fakes))
diff --git a/stable_audio_tools/training/utils.py b/stable_audio_tools/training/utils.py
index 8b69cd8c..74b698c6 100644
--- a/stable_audio_tools/training/utils.py
+++ b/stable_audio_tools/training/utils.py
@@ -1,10 +1,17 @@
 from pytorch_lightning.loggers import WandbLogger, CometLogger
 from ..interface.aeiou import pca_point_cloud
+from pytorch_optimizer.lr_scheduler.chebyshev import (
+    get_chebyshev_perm_steps,
+    get_chebyshev_permutation,
+    get_chebyshev_schedule
+    )
 import wandb
 import torch
 import os
 
+import bitsandbytes as bnb
+
 
 def get_rank():
    """Get rank of current process."""
@@ -73,6 +80,8 @@ def create_optimizer_from_config(optimizer_config, parameters):
    if optimizer_type == "FusedAdam":
        from deepspeed.ops.adam import FusedAdam
        optimizer = FusedAdam(parameters, **optimizer_config["config"])
+    elif optimizer_type == "AdamW8bit":
+        optimizer = bnb.optim.AdamW8bit(parameters, **optimizer_config["config"])
    else:
        optimizer_fn = getattr(torch.optim, optimizer_type)
        optimizer = optimizer_fn(parameters, **optimizer_config["config"])
@@ -90,6 +99,8 @@ def create_scheduler_from_config(scheduler_config, optimizer):
    """
    if scheduler_config["type"] == "InverseLR":
        scheduler_fn = InverseLR
+    elif scheduler_config["type"] == "Chebyshev":
+        return get_chebyshev_schedule(optimizer=optimizer, num_epochs=scheduler_config["num_epochs"], is_warmup=scheduler_config["is_warmup"], last_epoch=scheduler_config["last_epoch"])
    else:
        scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"])
    scheduler = scheduler_fn(optimizer, **scheduler_config["config"])
@@ -126,4 +137,4 @@ def log_point_cloud(logger, key, tokens, caption=None):
        logger.experiment.log({key: point_cloud})
    elif isinstance(logger, CometLogger):
        point_cloud = pca_point_cloud(tokens, rgb_float=True, output_type="points")
-        #logger.experiment.log_points_3d(scene_name=key, points=point_cloud)
\ No newline at end of file
+        #logger.experiment.log_points_3d(scene_name=key, points=point_cloud)
diff --git a/train.bat b/train.bat
new file mode 100644
index 00000000..b493f136
--- /dev/null
+++ b/train.bat
@@ -0,0 +1,12 @@
+python train.py ^
+    --name saos1 ^
+    --pretrained-ckpt-path .\sao_small\base_model.ckpt ^
+    --model-config .\sao_small\base_model_config.json ^
+    --batch-size 8 ^
+    --num-workers 8 ^
+    --seed 1937401721 ^
+    --checkpoint-every 1000 ^
+    --dataset-config dataset_config.json ^
+    --save-dir outputs ^
+    --precision 16-mixed
+