@@ -82,7 +82,7 @@ def forward(self, x):
8282 _ , C , _ = x .shape
8383 if self .padding :
8484 x = F .pad (x , (self .pad_left , self .pad_right ), mode = self .padding_mode )
85- return F .conv1d (x , self .filter .expand (C , - 1 , - 1 ), stride = self .stride , groups = C )
85+ return F .conv1d (x , comfy . model_management . cast_to ( self .filter .expand (C , - 1 , - 1 ), dtype = x . dtype , device = x . device ), stride = self .stride , groups = C )
8686
8787
8888class UpSample1d (nn .Module ):
@@ -191,7 +191,7 @@ def __init__(
191191 self .eps = 1e-9
192192
193193 def forward (self , x ):
194- a = self .alpha .unsqueeze (0 ).unsqueeze (- 1 )
194+ a = comfy . model_management . cast_to ( self .alpha .unsqueeze (0 ).unsqueeze (- 1 ), dtype = x . dtype , device = x . device )
195195 if self .alpha_logscale :
196196 a = torch .exp (a )
197197 return x + (1.0 / (a + self .eps )) * torch .sin (x * a ).pow (2 )
@@ -218,8 +218,8 @@ def __init__(
218218 self .eps = 1e-9
219219
220220 def forward (self , x ):
221- a = self .alpha .unsqueeze (0 ).unsqueeze (- 1 )
222- b = self .beta .unsqueeze (0 ).unsqueeze (- 1 )
221+ a = comfy . model_management . cast_to ( self .alpha .unsqueeze (0 ).unsqueeze (- 1 ), dtype = x . dtype , device = x . device )
222+ b = comfy . model_management . cast_to ( self .beta .unsqueeze (0 ).unsqueeze (- 1 ), dtype = x . dtype , device = x . device )
223223 if self .alpha_logscale :
224224 a = torch .exp (a )
225225 b = torch .exp (b )
@@ -597,7 +597,7 @@ def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
597597 y = y .unsqueeze (1 ) # (B, 1, T)
598598 left_pad = max (0 , self .win_length - self .hop_length ) # causal: left-only
599599 y = F .pad (y , (left_pad , 0 ))
600- spec = F .conv1d (y , self .forward_basis , stride = self .hop_length , padding = 0 )
600+ spec = F .conv1d (y , comfy . model_management . cast_to ( self .forward_basis , dtype = y . dtype , device = y . device ) , stride = self .hop_length , padding = 0 )
601601 n_freqs = spec .shape [1 ] // 2
602602 real , imag = spec [:, :n_freqs ], spec [:, n_freqs :]
603603 magnitude = torch .sqrt (real ** 2 + imag ** 2 )
@@ -648,7 +648,7 @@ def mel_spectrogram(
648648 """
649649 magnitude , phase = self .stft_fn (y )
650650 energy = torch .norm (magnitude , dim = 1 )
651- mel = torch .matmul (self . mel_basis . to ( magnitude .dtype ), magnitude )
651+ mel = torch .matmul (comfy . model_management . cast_to ( self . mel_basis , dtype = magnitude .dtype , device = y . device ), magnitude )
652652 log_mel = torch .log (torch .clamp (mel , min = 1e-5 ))
653653 return log_mel , magnitude , phase , energy
654654
0 commit comments