5252logger = logging .getLogger (__name__ )
5353
5454
55- # =============================================================================
56- # Whisper Encoder Wrapper
57- # =============================================================================
58-
59-
6055class WhisperEncoderExportable (nn .Module ):
6156 """
6257 Wrapper around Whisper's encoder for export.
@@ -72,11 +67,6 @@ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
7267 return self .encoder (input_features = input_features ).last_hidden_state
7368
7469
75- # =============================================================================
76- # Whisper Decoder Self-Attention with KV Cache
77- # =============================================================================
78-
79-
8070class WhisperSelfAttentionWithCache (nn .Module ):
8171 """
8272 Whisper self-attention layer with static KV cache.
@@ -147,11 +137,6 @@ def forward(
147137 return self .out_proj (attn_out )
148138
149139
150- # =============================================================================
151- # Whisper Cross-Attention (no cache update - K/V pre-computed)
152- # =============================================================================
153-
154-
155140class WhisperCrossAttention (nn .Module ):
156141 """
157142 Whisper cross-attention layer.
@@ -192,11 +177,6 @@ def forward(
192177 return self .out_proj (attn_out )
193178
194179
195- # =============================================================================
196- # Whisper Decoder Layer Wrapper
197- # =============================================================================
198-
199-
200180class WhisperDecoderLayerWithCache (nn .Module ):
201181 """
202182 Wrapper for a single Whisper decoder layer with KV cache.
@@ -254,11 +234,6 @@ def forward(
254234 return hidden_states
255235
256236
257- # =============================================================================
258- # Whisper Decoder Wrapper
259- # =============================================================================
260-
261-
262237class WhisperDecoderWithCache (nn .Module ):
263238 """
264239 Whisper decoder wrapper with static KV cache.
@@ -335,11 +310,6 @@ def forward(
335310 return logits
336311
337312
338- # =============================================================================
339- # Cross-KV Projection Module
340- # =============================================================================
341-
342-
343313class WhisperCrossKVProjection (nn .Module ):
344314 """
345315 Compute cross-attention K/V projections from encoder hidden states.
@@ -393,11 +363,6 @@ def forward(
393363 return tuple (k_list ), tuple (v_list )
394364
395365
396- # =============================================================================
397- # Export Functions
398- # =============================================================================
399-
400-
401366def export_whisper_to_mlx (
402367 model_id : str ,
403368 output_dir : str ,
@@ -528,9 +493,6 @@ def export_whisper_to_mlx(
528493 logger .error ("TorchAO not installed. Run: pip install torchao" )
529494 raise
530495
531- # =========================================================================
532- # Export Encoder
533- # =========================================================================
534496 logger .info ("Exporting encoder..." )
535497
536498 with torch .no_grad ():
@@ -541,9 +503,6 @@ def export_whisper_to_mlx(
541503
542504 _save_to_pte (encoder_ep , os .path .join (output_dir , "encoder.pte" ), "encoder" )
543505
544- # =========================================================================
545- # Export Cross-KV Projection
546- # =========================================================================
547506 logger .info ("Exporting cross-KV projection..." )
548507
549508 with torch .no_grad ():
@@ -561,9 +520,6 @@ def export_whisper_to_mlx(
561520
562521 _save_to_pte (cross_kv_ep , os .path .join (output_dir , "cross_kv.pte" ), "cross_kv" )
563522
564- # =========================================================================
565- # Export Decoder
566- # =========================================================================
567523 logger .info ("Exporting decoder..." )
568524
569525 # Example inputs for decoder
0 commit comments