-
Notifications
You must be signed in to change notification settings - Fork 3.4k
[TTS] Add code for training semantic codec #15524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e845094
d9e8b46
08db7b2
4a8f4c8
b2f5f4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,193 @@ | ||
| name: AudioCodec | ||
|
|
||
| max_epochs: ??? | ||
| # Adjust batch size based on GPU memory | ||
| batch_size: 16 | ||
| # When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch. | ||
| # If null, then weighted sampling is disabled. | ||
| weighted_sampling_steps_per_epoch: null | ||
|
|
||
| # Dataset metadata for each manifest | ||
| # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41 | ||
| train_ds_meta: ??? | ||
| val_ds_meta: ??? | ||
|
|
||
| log_ds_meta: ??? | ||
| log_dir: ??? | ||
|
|
||
| semantic_codec_path: ??? | ||
|
|
||
| # Modify these values based on your sample rate | ||
| sample_rate: 16000 | ||
| samples_per_frame: 640 | ||
| train_n_samples: 12800 | ||
| # The product of the down_sample_rates and up_sample_rates should match the samples_per_frame. | ||
| # For example 2 * 5 * 8 * 8 = 640. | ||
| down_sample_rates: [2, 5, 8, 8] | ||
| up_sample_rates: [8, 8, 5, 2] | ||
|
|
||
| num_codebooks: 8 | ||
| encoder_out_dim: 42 | ||
| decoder_input_dim: 48 | ||
|
|
||
| model: | ||
|
|
||
| semantic_codec_path: ${semantic_codec_path} | ||
|
|
||
| max_epochs: ${max_epochs} | ||
| steps_per_epoch: ${weighted_sampling_steps_per_epoch} | ||
|
|
||
| sample_rate: ${sample_rate} | ||
| samples_per_frame: ${samples_per_frame} | ||
|
|
||
| mel_loss_l1_scale: 10.0 | ||
| mel_loss_l2_scale: 0.0 | ||
| stft_loss_scale: 10.0 | ||
| time_domain_loss_scale: 0.0 | ||
| commit_loss_scale: 0.0 | ||
|
|
||
| # Probability of updating the discriminator during each training step | ||
| # For example, update the discriminator 1/2 times (1 update for every 2 batches) | ||
| disc_updates_per_period: 1 | ||
| disc_update_period: 2 | ||
|
|
||
| # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length] | ||
| loss_resolutions: [ | ||
| [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024] | ||
| ] | ||
| mel_loss_dims: [5, 10, 20, 40, 80, 160] | ||
| mel_loss_log_guard: 1.0 | ||
| stft_loss_log_guard: 1.0 | ||
| feature_loss_type: absolute | ||
|
|
||
| train_ds: | ||
| dataset: | ||
| _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset | ||
| dataset_meta: ${train_ds_meta} | ||
| weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} | ||
| sample_rate: ${sample_rate} | ||
| n_samples: ${train_n_samples} | ||
| min_duration: 0.4 # seconds | ||
| max_duration: null | ||
|
|
||
| dataloader_params: | ||
| batch_size: ${batch_size} | ||
| drop_last: true | ||
| num_workers: 4 | ||
|
|
||
| validation_ds: | ||
| dataset: | ||
| _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset | ||
| sample_rate: ${sample_rate} | ||
| n_samples: null | ||
| min_duration: null | ||
| max_duration: null | ||
| trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss | ||
| dataset_meta: ${val_ds_meta} | ||
|
|
||
| dataloader_params: | ||
| batch_size: 4 | ||
| num_workers: 2 | ||
|
|
||
| # Configures how audio samples are generated and saved during training. | ||
| # Remove this section to disable logging. | ||
| log_config: | ||
| log_dir: ${log_dir} | ||
| log_epochs: [10, 50] | ||
| epoch_frequency: 100 | ||
| log_tensorboard: false | ||
| log_wandb: false | ||
|
|
||
| generators: | ||
| - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator | ||
| log_audio: true | ||
| log_encoding: false | ||
| log_dequantized: false | ||
|
|
||
| dataset: | ||
| _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset | ||
| sample_rate: ${sample_rate} | ||
| n_samples: null | ||
| min_duration: null | ||
| max_duration: null | ||
| trunc_duration: 10.0 # Only log the first 10 seconds of generated audio. | ||
| dataset_meta: ${log_ds_meta} | ||
|
|
||
| dataloader_params: | ||
| batch_size: 4 | ||
| num_workers: 2 | ||
|
|
||
| audio_encoder: | ||
| _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANEncoder | ||
| down_sample_rates: ${down_sample_rates} | ||
| encoded_dim: ${encoder_out_dim} | ||
| base_channels: 48 | ||
| activation: "lrelu" | ||
|
|
||
| audio_decoder: | ||
| _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder | ||
| up_sample_rates: ${up_sample_rates} | ||
| input_dim: ${decoder_input_dim} | ||
| base_channels: 768 | ||
| activation: "half_snake" | ||
| output_activation: "clamp" | ||
|
|
||
| vector_quantizer: | ||
| _target_: nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer | ||
| num_groups: ${num_codebooks} | ||
| num_levels_per_group: [4, 4, 4, 4, 4, 4] | ||
|
|
||
| discriminator: | ||
| _target_: nemo.collections.tts.modules.audio_codec_modules.Discriminator | ||
| discriminators: | ||
| - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator | ||
| - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT | ||
| resolutions: [[512, 128, 512], [1024, 256, 1024]] | ||
| stft_bands: [[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]] | ||
|
|
||
| generator_loss: | ||
| _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss | ||
|
|
||
| discriminator_loss: | ||
| _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss | ||
|
|
||
| optim: | ||
| _target_: torch.optim.Adam | ||
| lr: 2e-4 | ||
| betas: [0.8, 0.99] | ||
|
|
||
| sched: | ||
| name: ExponentialLR | ||
| gamma: 0.998 | ||
|
|
||
| trainer: | ||
| num_nodes: 1 | ||
| devices: -1 | ||
| accelerator: gpu | ||
| strategy: ddp_find_unused_parameters_true | ||
| precision: 16 | ||
| max_epochs: ${max_epochs} | ||
| accumulate_grad_batches: 1 | ||
| enable_checkpointing: False # Provided by exp_manager | ||
| logger: false # Provided by exp_manager | ||
| log_every_n_steps: 100 | ||
| check_val_every_n_epoch: 10 | ||
| benchmark: false | ||
|
|
||
| exp_manager: | ||
| exp_dir: null | ||
| name: ${name} | ||
| create_tensorboard_logger: false | ||
| create_wandb_logger: false | ||
| wandb_logger_kwargs: | ||
| name: null | ||
| project: null | ||
| create_checkpoint_callback: true | ||
| checkpoint_callback_params: | ||
| monitor: val_loss | ||
| mode: min | ||
| save_top_k: 5 | ||
| save_best_model: true | ||
| always_save_nemo: true | ||
| resume_if_exists: false | ||
| resume_ignore_no_checkpoint: false |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| import os | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import Dict, List, Optional | ||
| from typing import Dict, List, Optional, Tuple | ||
|
|
||
| import librosa | ||
| import soundfile as sf | ||
|
|
@@ -30,6 +30,7 @@ | |
| filter_dataset_by_duration, | ||
| get_weighted_sampler, | ||
| load_audio, | ||
| resample_batch, | ||
| sample_audio, | ||
| stack_tensors, | ||
| ) | ||
|
|
@@ -56,7 +57,7 @@ class DatasetSample: | |
| audio_dir: Path | ||
|
|
||
|
|
||
| def audio_collate_fn(batch: List[dict]): | ||
| def audio_collate_fn(batch: List[dict], resample_rates: Optional[Tuple[int]] = None): | ||
| dataset_name_list = [] | ||
| audio_filepath_list = [] | ||
| audio_list = [] | ||
|
|
@@ -73,6 +74,14 @@ def audio_collate_fn(batch: List[dict]): | |
|
|
||
| batch_audio = stack_tensors(audio_list, max_lens=[audio_max_len]) | ||
|
|
||
| if resample_rates: | ||
| batch_audio, batch_audio_len = resample_batch( | ||
| audio=batch_audio, | ||
| audio_len=batch_audio_len, | ||
| input_sample_rate=resample_rates[0], | ||
| output_sample_rate=resample_rates[1], | ||
| ) | ||
|
|
||
| batch_dict = { | ||
| "dataset_names": dataset_name_list, | ||
| "audio_filepaths": audio_filepath_list, | ||
|
|
@@ -117,7 +126,8 @@ class VocoderDataset(Dataset): | |
| Args: | ||
| dataset_meta: Dict of dataset names (string) to dataset metadata. | ||
| sample_rate: Sample rate to load audio as. If the audio is stored at a different sample rate, then it will | ||
| be resampled. | ||
| be resampled using librosa. | ||
| resample_rate: Optional sample rate to resample to, using torch-based resampling. | ||
| n_samples: Optional int, if provided then n_samples samples will be randomly sampled from the full | ||
| audio file. | ||
| weighted_sampling_steps_per_epoch: Optional int, If provided, then data will be sampled (with replacement) based on | ||
|
|
@@ -135,6 +145,7 @@ def __init__( | |
| self, | ||
| dataset_meta: Dict, | ||
| sample_rate: int, | ||
| resample_rate: Optional[int] = None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you update the docstring to add this argument?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added the docstring. The functionality might be a bit confusing, because the feature actually being added is the option to resample using batched NeMo code instead of librosa.
||
| n_samples: Optional[int] = None, | ||
| weighted_sampling_steps_per_epoch: Optional[int] = None, | ||
| feature_processors: Optional[Dict[str, FeatureProcessor]] = None, | ||
|
|
@@ -146,6 +157,11 @@ def __init__( | |
| super().__init__() | ||
|
|
||
| self.sample_rate = sample_rate | ||
| if resample_rate and self.sample_rate != resample_rate: | ||
| self.resample_rates = [sample_rate, resample_rate] | ||
| else: | ||
| self.resample_rates = None | ||
|
|
||
| self.n_samples = n_samples | ||
| self.trunc_duration = trunc_duration | ||
| self.volume_norm = volume_norm | ||
|
|
@@ -221,7 +237,7 @@ def __getitem__(self, index): | |
| return example | ||
|
|
||
| def collate_fn(self, batch): | ||
| return audio_collate_fn(batch) | ||
| return audio_collate_fn(batch, resample_rates=self.resample_rates) | ||
|
|
||
|
|
||
| class TarredVocoderDataset(IterableDataset): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.