From 82daf0db4b6da1326080df0799764f68912a9d61 Mon Sep 17 00:00:00 2001 From: arushid Date: Sun, 22 Mar 2026 17:34:42 +0530 Subject: [PATCH] TDT fix Signed-off-by: arushid --- nemo/collections/asr/parts/preprocessing/features.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index ec0fa8f6f74d..6e359ef05a3a 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -81,7 +81,9 @@ def normalize_batch(x, seq_len, normalize_type): x_std = x_std.masked_fill(x_std.isnan(), 0.0) # edge case: only 1 frame in denominator # make sure x_std is not zero x_std += CONSTANT - return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std + normalized = (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2) + normalized.masked_fill_(~valid_mask.unsqueeze(1), 0.0) + return normalized, x_mean, x_std elif normalize_type == "all_features": x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)