11"""Wav2Mel for processing audio data."""
22
33import torch
4+ import torch .nn as nn
45from torchaudio .sox_effects import apply_effects_tensor
56from torchaudio .transforms import MelSpectrogram
67
78
8- class Wav2Mel (torch . nn .Module ):
9+ class Wav2Mel (nn .Module ):
910 """Transform audio file into mel spectrogram tensors."""
1011
1112 def __init__ (
1213 self ,
13- sample_rate : float = 16000 ,
14+ sample_rate : int = 16000 ,
1415 norm_db : float = - 3.0 ,
1516 sil_threshold : float = 1.0 ,
1617 sil_duration : float = 0.1 ,
@@ -41,7 +42,7 @@ def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
4142 return mel_tensor
4243
4344
44- class SoxEffects (torch . nn .Module ):
45+ class SoxEffects (nn .Module ):
4546 """Transform waveform tensors."""
4647
4748 def __init__ (
@@ -72,12 +73,12 @@ def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
7273 return wav_tensor
7374
7475
75- class LogMelspectrogram (torch . nn .Module ):
76+ class LogMelspectrogram (nn .Module ):
7677 """Transform waveform tensors into log mel spectrogram tensors."""
7778
7879 def __init__ (
7980 self ,
80- sample_rate : float ,
81+ sample_rate : int ,
8182 fft_window_ms : float ,
8283 fft_hop_ms : float ,
8384 f_min : float ,
@@ -94,4 +95,4 @@ def __init__(
9495
9596 def forward (self , wav_tensor : torch .Tensor ) -> torch .Tensor :
9697 mel_tensor = self .melspectrogram (wav_tensor ).squeeze (0 ).T # (time, n_mels)
97- return torch .log (mel_tensor . squeeze ( 0 ) + 1e-9 )
98+ return torch .log (torch . clamp ( mel_tensor , min = 1e-9 ) )
0 commit comments