Skip to content

Commit 2bf223a

Browse files
committed
TDT fix
Signed-off-by: arushid <arushid@nvidia.com>
1 parent 29f3884 commit 2bf223a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

nemo/collections/asr/parts/preprocessing/features.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def normalize_batch(x, seq_len, normalize_type):
8181
x_std = x_std.masked_fill(x_std.isnan(), 0.0) # edge case: only 1 frame in denominator
8282
# make sure x_std is not zero
8383
x_std += CONSTANT
84-
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
84+
normalized = (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
85+
normalized.masked_fill_(~valid_mask.unsqueeze(1), 0.0)
86+
return normalized, x_mean, x_std
8587
elif normalize_type == "all_features":
8688
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
8789
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)

0 commit comments

Comments
 (0)