Skip to content

Commit 17b43c2

Browse files
LTX audio vae novram fixes. (#12796)
1 parent 8befce5 commit 17b43c2

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

comfy/ldm/lightricks/vocoders/vocoder.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def forward(self, x):
8282
_, C, _ = x.shape
8383
if self.padding:
8484
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
85-
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
85+
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
8686

8787

8888
class UpSample1d(nn.Module):
@@ -191,7 +191,7 @@ def __init__(
191191
self.eps = 1e-9
192192

193193
def forward(self, x):
194-
a = self.alpha.unsqueeze(0).unsqueeze(-1)
194+
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
195195
if self.alpha_logscale:
196196
a = torch.exp(a)
197197
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
@@ -218,8 +218,8 @@ def __init__(
218218
self.eps = 1e-9
219219

220220
def forward(self, x):
221-
a = self.alpha.unsqueeze(0).unsqueeze(-1)
222-
b = self.beta.unsqueeze(0).unsqueeze(-1)
221+
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
222+
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
223223
if self.alpha_logscale:
224224
a = torch.exp(a)
225225
b = torch.exp(b)
@@ -597,7 +597,7 @@ def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
597597
y = y.unsqueeze(1) # (B, 1, T)
598598
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
599599
y = F.pad(y, (left_pad, 0))
600-
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
600+
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
601601
n_freqs = spec.shape[1] // 2
602602
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
603603
magnitude = torch.sqrt(real ** 2 + imag ** 2)
@@ -648,7 +648,7 @@ def mel_spectrogram(
648648
"""
649649
magnitude, phase = self.stft_fn(y)
650650
energy = torch.norm(magnitude, dim=1)
651-
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
651+
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
652652
log_mel = torch.log(torch.clamp(mel, min=1e-5))
653653
return log_mel, magnitude, phase, energy
654654

0 commit comments

Comments
 (0)