Skip to content

Commit e90f77d

Browse files
committed
remove debug
1 parent f2d8574 commit e90f77d

1 file changed

Lines changed: 0 additions & 14 deletions

File tree

src/maxdiffusion/models/ltx2/vocoder_ltx2.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,8 @@ def kaiser_window(n: int, beta: float) -> Array:
3333
alpha = (n - 1) / 2.0
3434
time = jnp.arange(n)
3535
term = beta * jnp.sqrt(1 - ((time - alpha) / alpha) ** 2)
36-
jax.debug.print("kaiser_window term - min: {min}, max: {max}", min=term.min(), max=term.max())
37-
3836
i0_term = jss.i0(term)
3937
i0_beta = jss.i0(beta)
40-
jax.debug.print("kaiser_window i0_term - min: {min}, max: {max}", min=i0_term.min(), max=i0_term.max())
41-
jax.debug.print("kaiser_window i0_beta: {val}", val=i0_beta)
42-
4338
res = i0_term / i0_beta
4439
return res
4540

@@ -49,19 +44,14 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
4944
half_size = kernel_size // 2
5045
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
5146

52-
print(f"kaiser_sinc_filter1d amplitude: {amplitude}")
53-
5447
if amplitude > 50.0:
5548
beta = 0.1102 * (amplitude - 8.7)
5649
elif amplitude >= 21.0:
5750
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
5851
else:
5952
beta = 0.0
6053

61-
print(f"kaiser_sinc_filter1d beta: {beta}")
62-
6354
window = kaiser_window(kernel_size, beta)
64-
jax.debug.print("kaiser_sinc_filter1d window - min: {min}, max: {max}", min=window.min(), max=window.max())
6555

6656
even = kernel_size % 2 == 0
6757
time = jnp.arange(-half_size, half_size) + 0.5 if even else jnp.arange(kernel_size) - half_size
@@ -75,10 +65,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
7565
jnp.ones_like(time),
7666
jnp.sin(math.pi * time) / math.pi / time,
7767
)
78-
jax.debug.print("kaiser_sinc_filter1d sinc - min: {min}, max: {max}", min=sinc.min(), max=sinc.max())
79-
8068
filter = 2 * cutoff * window * sinc
81-
jax.debug.print("kaiser_sinc_filter1d before norm - min: {min}, max: {max}, sum: {sum}", min=filter.min(), max=filter.max(), sum=filter.sum())
8269
filter = filter / filter.sum()
8370
return filter
8471

@@ -121,7 +108,6 @@ def __call__(self, x: Array) -> Array:
121108
dimension_numbers=('NLC', 'LIO', 'NLC'),
122109
feature_group_count=num_channels,
123110
)
124-
jax.debug.print("DownSample1d after conv - min: {min}, max: {max}", min=x_filtered.min(), max=x_filtered.max())
125111
return x_filtered
126112

127113

0 commit comments

Comments
 (0)