VoiceChat EA STT training reproducible features #15558
ankitapasad wants to merge 2 commits into NVIDIA-NeMo:main
Conversation
- Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Ankita Pasad <apasad@nvidia.com>
- …ization, clean-up token ID init, and corresponding tests. Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Ankita Pasad <apasad@nvidia.com>
    import os

    import pytest
    import torch

Check notice: Code scanning / CodeQL: Unused import (test)
    assert (target_tokens == eos).sum().item() == 0, "skip_eos=True should not place any EOS"

    # Now collate source tokens, passing in the target channel for EOS placement
    source_tokens, source_token_lens = collate_token_channel(

Check notice: Code scanning / CodeQL: Unused local variable (test)
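The "Unused local variable" notice refers to `source_token_lens` being bound but never used afterwards. A minimal sketch of the usual fix, hedged: `collate_stub` below is a stand-in for the real `collate_token_channel`, whose full signature is not shown in this snippet.

```python
def collate_stub(tokens_batch, skip_eos=True):
    # Stand-in for collate_token_channel: returns (tokens, lengths).
    lens = [len(t) for t in tokens_batch]
    return tokens_batch, lens

# Binding the unused second return value to `_` silences the CodeQL
# notice while making the "intentionally discarded" intent explicit.
source_tokens, _ = collate_stub([[1, 2, 3], [4, 5]], skip_eos=True)
```

Alternatively, the test could assert on the lengths as well, which both uses the variable and strengthens the check.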
    skip_eos=True,
    )

    source_tokens, source_token_lens = collate_token_channel(

Check notice: Code scanning / CodeQL: Unused local variable (test)
    from nemo.collections.common.tokenizers import AutoTokenizer
    from nemo.collections.speechlm2.data.duplex_stt_dataset import DuplexSTTDataset
    from nemo.collections.speechlm2.data.utils import get_pad_id

Check notice: Code scanning / CodeQL: Unused import (test)
    train_batch = train_ds[cuts]
    val_batch = val_ds[cuts]

    train_targets = train_batch["audio_data"]["target_tokens"]

Check notice: Code scanning / CodeQL: Unused local variable (test)
    # Force aligner should be created but never called during validation
    val_ds.force_aligner = MagicMock()
    val_ds[cuts]

Check notice: Code scanning / CodeQL: Statement has no effect (test)
    # Mock the force aligner to avoid loading wav2vec2
    train_ds.force_aligner = MagicMock()
    train_ds.force_aligner.batch_force_align_user_audio.side_effect = lambda cuts, **kwargs: cuts
    train_ds[cuts]

Check notice: Code scanning / CodeQL: Statement has no effect (test)
    - is_mcq_cut_train / is_mcq_cut_val / is_asr_cut
    """

    import pytest

Check notice: Code scanning / CodeQL: Unused import (test)
    assert tokenizer.bos is not None, "BOS support in the tokenizer is required."
    assert tokenizer.eos is not None, "EOS support in the tokenizer is required."

    user_bos_token = '^'

Review comment: I use the same BOS and EOS for the user and agent channels. I feel that is cleaner, and I verified that it does not impact model performance. I see you want to exactly match EA; let's make these configurable, so one can set ^ and
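Following the suggestion above to make the channel markers configurable, a minimal sketch; the `ChannelTokenConfig` name and its default values are assumptions for illustration, not the PR's actual API:

```python
from dataclasses import dataclass

@dataclass
class ChannelTokenConfig:
    # Defaults share one BOS/EOS pair across both channels (the reviewer's
    # preference); an EA-style setup can override the per-channel markers.
    user_bos_token: str = "<bos>"
    user_eos_token: str = "<eos>"
    agent_bos_token: str = "<bos>"
    agent_eos_token: str = "<eos>"

# EA-style override for the user channel only:
cfg = ChannelTokenConfig(user_bos_token="^")
```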
    Prompt selection priority:
    1. Per-cut custom prompt (cut.custom['system_prompt'])
    2. MCQ training cut -> THINK prompt for think-cuts, NOTHINK prompt for others
    3. MCQ validation cut (when add_mcq_prompt=True) -> THINK prompt

Review comment: Can you also add support for a custom prompt? We can then easily evaluate the different demo setups we have used.
Review comment: A high-level question: can you also share a training script/wandb run so we can make sure the metrics look roughly right? I think additional effort may be needed to catch up to the EA ckpt, but it is better to check at intermediate steps as well.
    tokenizer: TokenizerSpec,
    train_dataset: torch.utils.data.Dataset = None,
    val_dataset: torch.utils.data.Dataset = None,
    dataset: torch.utils.data.Dataset = None,

Review comment: That's too many datasets. It's OK to remove the `dataset` parameter and property, and update the code across the collection to use `train_dataset` instead.
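A sketch of the simplification suggested above. This is hypothetical: the class name and surrounding constructor arguments are placeholders, and `torch.utils.data.Dataset` is replaced by plain duck typing so the sketch stays self-contained.

```python
from typing import Any, Optional

class DataModuleSketch:
    """Keeps only train/val datasets; the redundant `dataset` parameter is gone.

    Callers that previously passed `dataset=...` would pass
    `train_dataset=...` instead.
    """

    def __init__(
        self,
        tokenizer: Any,
        train_dataset: Optional[Any] = None,
        val_dataset: Optional[Any] = None,
    ):
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

dm = DataModuleSketch(tokenizer="tok", train_dataset=[1, 2, 3])
```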
    @@ -11,9 +11,26 @@
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

Review comment: Move the changes in this file to a new file, nemo/collections/speechlm2/data/mcq.py, and document its purpose and top-level entry point, with expected usage. Let's make this re-usable across models/projects.
    @@ -11,6 +11,7 @@
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    import json

Review comment: I skipped the review of this file; I don't understand the logic being added here from a quick skim.
Something for your consideration: this file has 1k+ lines of complex data-preparation logic that isn't well documented. Users trying to train or finetune VoiceChat may have a hard time understanding how to prepare data and which options to use. Try to build this documentation, with examples showing the expected input and output for each of these steps (and for the entire pipeline).
What does this PR do?
Adds the following features to the dataset class to support VoiceChat EA STT training and fine-tuning.
Internal-access document with training, inference recipes, and notes on parity.
Collection: speechlm2
Usage
# Add a code snippet demonstrating how to use this
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.