We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c42e444 commit 82daf0dCopy full SHA for 82daf0d
1 file changed
nemo/collections/asr/parts/preprocessing/features.py
@@ -81,7 +81,9 @@ def normalize_batch(x, seq_len, normalize_type):
81
x_std = x_std.masked_fill(x_std.isnan(), 0.0) # edge case: only 1 frame in denominator
82
# make sure x_std is not zero
83
x_std += CONSTANT
84
- return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
+ normalized = (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
85
+ normalized.masked_fill_(~valid_mask.unsqueeze(1), 0.0)
86
+ return normalized, x_mean, x_std
87
elif normalize_type == "all_features":
88
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
89
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
0 commit comments