Skip to content

Commit 4ee0cfe

Browse files
committed
Add audio codec model tests
Signed-off-by: Ryan <rlangman@nvidia.com>
1 parent ca1c5dd commit 4ee0cfe

File tree

4 files changed

+407
-0
lines changed

4 files changed

+407
-0
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Training config for an acoustic audio codec (16 kHz) that distills from a
# pretrained semantic codec. Consumed via Hydra/OmegaConf; ??? values are
# required overrides supplied on the command line.
name: AudioCodec

max_epochs: ???
# Adjust batch size based on GPU memory
batch_size: 16
# When doing weighted sampling with multiple manifests, this defines how many training steps are in an epoch.
# If null, then weighted sampling is disabled.
weighted_sampling_steps_per_epoch: null

# Dataset metadata for each manifest
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/data/vocoder_dataset.py#L39-L41
train_ds_meta: ???
val_ds_meta: ???

log_ds_meta: ???
log_dir: ???

# Path to the pretrained semantic codec checkpoint (.nemo) used by the model.
semantic_codec_path: ???

# Modify these values based on your sample rate
sample_rate: 16000
samples_per_frame: 640
train_n_samples: 12800
# The product of the down_sample_rates and up_sample_rates should match the samples_per_frame.
# For example 2 * 5 * 8 * 8 = 640.
down_sample_rates: [2, 5, 8, 8]
up_sample_rates: [8, 8, 5, 2]

num_codebooks: 8
encoder_out_dim: 42
decoder_input_dim: 48

model:

  semantic_codec_path: ${semantic_codec_path}

  max_epochs: ${max_epochs}
  steps_per_epoch: ${weighted_sampling_steps_per_epoch}

  sample_rate: ${sample_rate}
  samples_per_frame: ${samples_per_frame}

  # Loss-term weights; a scale of 0.0 disables that term.
  mel_loss_l1_scale: 10.0
  mel_loss_l2_scale: 0.0
  stft_loss_scale: 10.0
  time_domain_loss_scale: 0.0
  commit_loss_scale: 0.0

  # How often the discriminator is updated during training:
  # disc_updates_per_period updates per disc_update_period batches
  # (here: 1 update for every 2 batches).
  disc_updates_per_period: 1
  disc_update_period: 2

  # All resolutions for mel reconstruction loss, ordered [num_fft, hop_length, window_length]
  loss_resolutions: [
    [32, 8, 32], [64, 16, 64], [128, 32, 128], [256, 64, 256], [512, 128, 512], [1024, 256, 1024]
  ]
  # Number of mel bands per loss resolution (parallel to loss_resolutions).
  mel_loss_dims: [5, 10, 20, 40, 80, 160]
  mel_loss_log_guard: 1.0
  stft_loss_log_guard: 1.0
  feature_loss_type: absolute

  train_ds:
    dataset:
      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
      dataset_meta: ${train_ds_meta}
      weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
      sample_rate: ${sample_rate}
      n_samples: ${train_n_samples}
      min_duration: 0.4 # seconds
      max_duration: null

    dataloader_params:
      batch_size: ${batch_size}
      drop_last: true
      num_workers: 4

  validation_ds:
    dataset:
      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
      sample_rate: ${sample_rate}
      n_samples: null
      min_duration: null
      max_duration: null
      trunc_duration: 10.0 # Only use the first 10 seconds of audio for computing validation loss
      dataset_meta: ${val_ds_meta}

    dataloader_params:
      batch_size: 4
      num_workers: 2

  # Configures how audio samples are generated and saved during training.
  # Remove this section to disable logging.
  log_config:
    log_dir: ${log_dir}
    log_epochs: [10, 50]
    epoch_frequency: 100
    log_tensorboard: false
    log_wandb: false

    generators:
      - _target_: nemo.collections.tts.parts.utils.callbacks.AudioCodecArtifactGenerator
        log_audio: true
        log_encoding: false
        log_dequantized: false

    dataset:
      _target_: nemo.collections.tts.data.vocoder_dataset.VocoderDataset
      sample_rate: ${sample_rate}
      n_samples: null
      min_duration: null
      max_duration: null
      trunc_duration: 10.0 # Only log the first 10 seconds of generated audio.
      dataset_meta: ${log_ds_meta}

    dataloader_params:
      batch_size: 4
      num_workers: 2

  audio_encoder:
    _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANEncoder
    down_sample_rates: ${down_sample_rates}
    encoded_dim: ${encoder_out_dim}
    base_channels: 48
    activation: "lrelu"

  audio_decoder:
    _target_: nemo.collections.tts.modules.audio_codec_modules.HiFiGANDecoder
    up_sample_rates: ${up_sample_rates}
    input_dim: ${decoder_input_dim}
    base_channels: 768
    activation: "half_snake"
    output_activation: "clamp"

  vector_quantizer:
    _target_: nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer
    num_groups: ${num_codebooks}
    num_levels_per_group: [4, 4, 4, 4, 4, 4]

  discriminator:
    _target_: nemo.collections.tts.modules.audio_codec_modules.Discriminator
    discriminators:
      - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiPeriodDiscriminator
      - _target_: nemo.collections.tts.modules.audio_codec_modules.MultiResolutionDiscriminatorSTFT
        resolutions: [[512, 128, 512], [1024, 256, 1024]]
        stft_bands: [[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]]

  generator_loss:
    _target_: nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss

  discriminator_loss:
    _target_: nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss

  optim:
    _target_: torch.optim.Adam
    lr: 2e-4
    betas: [0.8, 0.99]

    sched:
      name: ExponentialLR
      gamma: 0.998

trainer:
  num_nodes: 1
  devices: -1
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true
  precision: 16
  max_epochs: ${max_epochs}
  accumulate_grad_batches: 1
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  log_every_n_steps: 100
  check_val_every_n_epoch: 10
  benchmark: false

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: false
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
  create_checkpoint_callback: true
  checkpoint_callback_params:
    monitor: val_loss
    mode: min
    save_top_k: 5
    save_best_model: true
    always_save_nemo: true
  resume_if_exists: false
  resume_ignore_no_checkpoint: false
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import pytest
17+
import torch
18+
from omegaconf import DictConfig
19+
20+
from nemo.collections.tts.models import AudioCodecModel
21+
22+
23+
def create_codec_config():
    """Build a minimal DictConfig sufficient to instantiate an AudioCodecModel for tests.

    Dimensions are mutually consistent: the quantizer uses 8 groups with 5 levels
    per group, which appears chosen to match the encoder output dim and decoder
    input dim of 40 (8 * 5 = 40).
    """
    encoder_cfg = DictConfig(
        {
            'cls': 'nemo.collections.tts.modules.audio_codec_modules.MultiResolutionSTFTEncoder',
            'params': {
                'out_dim': 40,
                'resolutions': [[960, 240, 960], [1920, 480, 1920]],
                'resolution_filter_list': [256, 512],
            },
        }
    )
    decoder_cfg = DictConfig(
        {
            'cls': 'nemo.collections.tts.modules.audio_codec_modules.ResNetDecoder',
            'params': {
                'input_dim': 40,
                'input_filters': 512,
                'n_hidden_layers': 6,
                'hidden_filters': 512,
                'pre_up_sample_rates': [],
                'pre_up_sample_filters': [],
                'resblock_up_sample_rates': [10, 8, 6],
                'resblock_up_sample_filters': [256, 128, 32],
            },
        }
    )
    quantizer_cfg = DictConfig(
        {
            'cls': 'nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer',
            'params': {
                'num_groups': 8,
                'num_levels_per_group': [4, 4, 4, 4, 4],
            },
        }
    )
    generator_loss_cfg = DictConfig(
        {'cls': 'nemo.collections.tts.losses.audio_codec_loss.GeneratorSquaredLoss'}
    )
    discriminator_loss_cfg = DictConfig(
        {'cls': 'nemo.collections.tts.losses.audio_codec_loss.DiscriminatorSquaredLoss'}
    )

    return DictConfig(
        {
            'sample_rate': 24000,
            'samples_per_frame': 480,
            'loss_resolutions': [[960, 240, 960], [1920, 480, 1920]],
            'mel_loss_dims': [160, 320],
            'commit_loss_scale': 0.0,
            'audio_encoder': encoder_cfg,
            'audio_decoder': decoder_cfg,
            'vector_quantizer': quantizer_cfg,
            'generator_loss': generator_loss_cfg,
            'discriminator_loss': discriminator_loss_cfg,
        }
    )
74+
75+
76+
@pytest.fixture()
def codec_model():
    """Pytest fixture: an AudioCodecModel built from the default test config."""
    return AudioCodecModel(cfg=create_codec_config())
81+
82+
83+
@pytest.fixture()
def acoustic_codec_model():
    """Pytest fixture: an AudioCodecModel paired with a small semantic codec.

    The semantic sub-codec is shrunk to a single quantizer group with 5-dim
    latents; the acoustic encoder output is reduced to 35 so that the two
    latent streams together match the 40-dim decoder input (35 + 5 = 40).
    """
    # Build the (small) semantic codec config first.
    semantic_cfg = create_codec_config()
    semantic_cfg.vector_quantizer.params.num_groups = 1
    semantic_cfg.audio_encoder.params.out_dim = 5
    semantic_cfg.audio_decoder.params.input_dim = 5

    # The acoustic codec embeds the semantic codec config.
    acoustic_cfg = create_codec_config()
    acoustic_cfg.semantic_codec = semantic_cfg
    acoustic_cfg.audio_encoder.params.out_dim = 35

    return AudioCodecModel(cfg=acoustic_cfg)
96+
97+
98+
class TestAudioCodecModel:
    """Unit tests for AudioCodecModel forward, encode, and decode shape contracts.

    The four tests previously duplicated batch construction and the shape
    assertions verbatim; the shared logic now lives in private helpers so each
    test is a one-line call with its own batch size.
    """

    def _make_audio_batch(self, batch_size):
        # Random audio with a fixed padded length and random valid lengths per example.
        audio = torch.randn(size=(batch_size, 20000))
        audio_len = torch.randint(size=[batch_size], low=10000, high=20000)
        return audio, audio_len

    def _check_forward(self, model, batch_size):
        # Full codec round-trip via forward(); output must be [batch, max_len].
        audio, audio_len = self._make_audio_batch(batch_size)
        output_audio, output_audio_len = model.forward(
            audio=audio, audio_len=audio_len, sample_rate=model.sample_rate
        )
        assert output_audio.shape[0] == batch_size
        assert output_audio.shape[1] == output_audio_len.max()

    def _check_encode_and_decode(self, model, batch_size):
        # Encode to tokens (expected [batch, codebooks, max_token_len]), then
        # decode back to audio and validate both shapes.
        audio, audio_len = self._make_audio_batch(batch_size)

        tokens, tokens_len = model.encode(audio=audio, audio_len=audio_len, sample_rate=model.sample_rate)
        assert tokens.shape[0] == batch_size
        assert tokens.shape[2] == tokens_len.max()

        output_audio, output_audio_len = model.decode(tokens=tokens, tokens_len=tokens_len)
        assert output_audio.shape[0] == batch_size
        assert output_audio.shape[1] == output_audio_len.max()

    @pytest.mark.unit
    def test_forward(self, codec_model):
        self._check_forward(codec_model, batch_size=2)

    @pytest.mark.unit
    def test_forward_with_acoustic_codec(self, acoustic_codec_model):
        self._check_forward(acoustic_codec_model, batch_size=3)

    @pytest.mark.unit
    def test_encode_and_decode(self, codec_model):
        self._check_encode_and_decode(codec_model, batch_size=4)

    @pytest.mark.unit
    def test_encode_and_decode_with_acoustic_codec(self, acoustic_codec_model):
        self._check_encode_and_decode(acoustic_codec_model, batch_size=5)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# CI smoke/coverage run: train the 16 kHz acoustic codec for 1 short epoch on
# the AN4 test data, collecting coverage for the nemo package.
# NOTE(review): TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 presumably allows loading
# the pretrained semantic codec checkpoint with torch.load(weights_only=False)
# semantics — confirm against the NeMo checkpoint loader.
# Hydra override syntax: "+key=value" adds a key not in the config,
# "~key" deletes an existing key (here, disables epoch-based val scheduling
# so limit_val_batches takes effect every epoch).
TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/audio_codec.py \
    --config-name acoustic_codec_16000.yaml \
    semantic_codec_path="/home/TestData/tts/TestSemanticCodec.nemo" \
    +train_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_train_context_v1.json" \
    +train_ds_meta.an4.audio_dir="/" \
    +train_ds_meta.an4.sample_weight=1.0 \
    +val_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val_context_v1.json" \
    +val_ds_meta.an4.audio_dir="/" \
    +log_ds_meta.an4.manifest_path="/home/TestData/an4_dataset/an4_val_context_v1.json" \
    +log_ds_meta.an4.audio_dir="/" \
    log_dir="/tmp/audio_codec_training_output" \
    max_epochs=1 \
    batch_size=4 \
    weighted_sampling_steps_per_epoch=10 \
    +trainer.limit_val_batches=1 \
    trainer.devices="[0]" \
    trainer.strategy=auto \
    model.train_ds.dataloader_params.num_workers=0 \
    model.validation_ds.dataloader_params.num_workers=0 \
    ~trainer.check_val_every_n_epoch

0 commit comments

Comments
 (0)