Skip to content

Commit 9ed8f7d

Browse files
committed
refactor wav2mel module
1 parent c14058e commit 9ed8f7d

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

data/wav2mel.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,17 @@
11
"""Wav2Mel for processing audio data."""
22

33
import torch
4+
import torch.nn as nn
45
from torchaudio.sox_effects import apply_effects_tensor
56
from torchaudio.transforms import MelSpectrogram
67

78

8-
class Wav2Mel(torch.nn.Module):
9+
class Wav2Mel(nn.Module):
910
"""Transform audio file into mel spectrogram tensors."""
1011

1112
def __init__(
1213
self,
13-
sample_rate: float = 16000,
14+
sample_rate: int = 16000,
1415
norm_db: float = -3.0,
1516
sil_threshold: float = 1.0,
1617
sil_duration: float = 0.1,
@@ -41,7 +42,7 @@ def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
4142
return mel_tensor
4243

4344

44-
class SoxEffects(torch.nn.Module):
45+
class SoxEffects(nn.Module):
4546
"""Transform waveform tensors."""
4647

4748
def __init__(
@@ -72,12 +73,12 @@ def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor:
7273
return wav_tensor
7374

7475

75-
class LogMelspectrogram(torch.nn.Module):
76+
class LogMelspectrogram(nn.Module):
7677
"""Transform waveform tensors into log mel spectrogram tensors."""
7778

7879
def __init__(
7980
self,
80-
sample_rate: float,
81+
sample_rate: int,
8182
fft_window_ms: float,
8283
fft_hop_ms: float,
8384
f_min: float,
@@ -94,4 +95,4 @@ def __init__(
9495

9596
def forward(self, wav_tensor: torch.Tensor) -> torch.Tensor:
9697
mel_tensor = self.melspectrogram(wav_tensor).squeeze(0).T # (time, n_mels)
97-
return torch.log(mel_tensor.squeeze(0) + 1e-9)
98+
return torch.log(torch.clamp(mel_tensor, min=1e-9))

0 commit comments

Comments
 (0)