@@ -33,13 +33,8 @@ def kaiser_window(n: int, beta: float) -> Array:
3333 alpha = (n - 1 ) / 2.0
3434 time = jnp .arange (n )
3535 term = beta * jnp .sqrt (1 - ((time - alpha ) / alpha ) ** 2 )
36- jax .debug .print ("kaiser_window term - min: {min}, max: {max}" , min = term .min (), max = term .max ())
37-
3836 i0_term = jss .i0 (term )
3937 i0_beta = jss .i0 (beta )
40- jax .debug .print ("kaiser_window i0_term - min: {min}, max: {max}" , min = i0_term .min (), max = i0_term .max ())
41- jax .debug .print ("kaiser_window i0_beta: {val}" , val = i0_beta )
42-
4338 res = i0_term / i0_beta
4439 return res
4540
@@ -49,19 +44,14 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
4944 half_size = kernel_size // 2
5045 amplitude = 2.285 * (half_size - 1 ) * math .pi * delta_f + 7.95
5146
52- print (f"kaiser_sinc_filter1d amplitude: { amplitude } " )
53-
5447 if amplitude > 50.0 :
5548 beta = 0.1102 * (amplitude - 8.7 )
5649 elif amplitude >= 21.0 :
5750 beta = 0.5842 * (amplitude - 21 ) ** 0.4 + 0.07886 * (amplitude - 21.0 )
5851 else :
5952 beta = 0.0
6053
61- print (f"kaiser_sinc_filter1d beta: { beta } " )
62-
6354 window = kaiser_window (kernel_size , beta )
64- jax .debug .print ("kaiser_sinc_filter1d window - min: {min}, max: {max}" , min = window .min (), max = window .max ())
6555
6656 even = kernel_size % 2 == 0
6757 time = jnp .arange (- half_size , half_size ) + 0.5 if even else jnp .arange (kernel_size ) - half_size
@@ -75,10 +65,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
7565 jnp .ones_like (time ),
7666 jnp .sin (math .pi * time ) / math .pi / time ,
7767 )
78- jax .debug .print ("kaiser_sinc_filter1d sinc - min: {min}, max: {max}" , min = sinc .min (), max = sinc .max ())
79-
8068 filter = 2 * cutoff * window * sinc
81- jax .debug .print ("kaiser_sinc_filter1d before norm - min: {min}, max: {max}, sum: {sum}" , min = filter .min (), max = filter .max (), sum = filter .sum ())
8269 filter = filter / filter .sum ()
8370 return filter
8471
@@ -121,7 +108,6 @@ def __call__(self, x: Array) -> Array:
121108 dimension_numbers = ('NLC' , 'LIO' , 'NLC' ),
122109 feature_group_count = num_channels ,
123110 )
124- jax .debug .print ("DownSample1d after conv - min: {min}, max: {max}" , min = x_filtered .min (), max = x_filtered .max ())
125111 return x_filtered
126112
127113
0 commit comments