@@ -678,54 +678,35 @@ def __init__(
678678 window_type = "hann" ,
679679 )
680680
@nnx.jit
def __call__(self, mel_spec: Array) -> Array:
    """Synthesize a waveform from ``mel_spec`` and add a bandwidth-extension residual.

    Pipeline, as implemented below:

    1. ``self.vocoder`` turns ``mel_spec`` into a base signal.
    2. The base signal is zero-padded along its sample axis to a multiple
       of ``self.hop_length`` and passed, one channel at a time, through
       ``self.mel_stft``.
    3. ``self.bwe_generator`` predicts a residual from that log-mel, while
       ``self.resampler`` produces the skip path from the padded signal.
    4. The residual is length-matched to the skip path, the two are summed,
       clipped to ``[-1, 1]``, trimmed and transposed back.

    Args:
        mel_spec: Input handed straight to ``self.vocoder``.
            NOTE(review): layout is not visible here; the vocoder output is
            assumed to be (batch, channels, samples) -- confirm.

    Returns:
        The clipped waveform with its last two axes swapped back, trimmed
        along the sample axis to
        ``num_samples * output_sampling_rate // input_sampling_rate``.
    """
    # No host-side ``print`` in this body: under jit it executes only while
    # tracing and would format tracers instead of concrete values.
    x = self.vocoder(mel_spec)
    x = jnp.transpose(x, (0, 2, 1))
    batch_size, num_samples, num_channels = x.shape

    # Shapes are static Python ints during tracing, so this branch is
    # resolved at trace time rather than on device.
    remainder = num_samples % self.hop_length
    if remainder != 0:
        # Default ``jnp.pad`` mode is constant zero padding.
        x = jnp.pad(x, ((0, 0), (0, self.hop_length - remainder), (0, 0)))

    # Fold the channel axis into the leading axis so that ``mel_stft``
    # receives single-channel signals with a trailing axis of size 1.
    x_flattened = x.transpose(0, 2, 1).reshape(-1, x.shape[1], 1)
    log_mel, _, _, _ = self.mel_stft(x_flattened)

    # Unfold the leading axis back into (batch, channel, ...), keeping the
    # last axis of the STFT output untouched.
    log_mel = log_mel.reshape(batch_size, num_channels, -1, log_mel.shape[-1])
    residual = self.bwe_generator(log_mel, time_last=False)
    skip = self.resampler(x)

    residual = jnp.transpose(residual, (0, 2, 1))

    # Align the residual with the skip path along axis 1: extend by
    # repeating the last value, or truncate the surplus.
    if residual.shape[1] < skip.shape[1]:
        residual = jnp.pad(
            residual,
            ((0, 0), (0, skip.shape[1] - residual.shape[1]), (0, 0)),
            mode='edge',
        )
    elif residual.shape[1] > skip.shape[1]:
        residual = residual[:, :skip.shape[1], :]

    raw_waveform = residual + skip
    waveform = jnp.clip(raw_waveform, -1, 1)

    # Trim using the pre-padding sample count, scaled by the rate ratio, so
    # the samples introduced by the hop-length padding are dropped.
    output_samples = num_samples * self.output_sampling_rate // self.input_sampling_rate
    waveform = waveform[:, :output_samples, :]
    waveform = jnp.transpose(waveform, (0, 2, 1))

    return waveform
0 commit comments