1 | 1 | """Export nvidia/parakeet-tdt-0.6b-v3 components to ExecuTorch."""
2 | 2 |
3 | 3 | import argparse
4 | | -import logging
5 | 4 | import os
6 | 5 | import shutil
7 | 6 | import tarfile
20 | 19 | from executorch.exir.passes import MemoryPlanningPass
21 | 20 | from torch.export import Dim, export
22 | 21 |
23 | | -logger = logging.getLogger(__name__)
24 | | -
25 | 22 |
26 | 23 | def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor:
27 | 24 |     """Load audio file and resample to target sample rate."""
@@ -442,7 +439,6 @@ def export_all( |
442 | 439 |         strict=False,
443 | 440 |     )
444 | 441 |
445 | | -
446 | 442 |     sample_rate = model.preprocessor._cfg.sample_rate
447 | 443 |     window_stride = float(model.preprocessor._cfg.window_stride)
448 | 444 |     encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 8))
@@ -564,20 +560,13 @@ def _create_cuda_partitioners(programs, is_windows=False): |
564 | 560 |
565 | 561 |
566 | 562 | def _create_mlx_partitioners(programs):
567 | | -    """Create MLX partitioners for all programs except preprocessor."""
| 563 | +    """Create MLX partitioners for all programs."""
568 | 564 |     from executorch.backends.apple.mlx.partitioner import MLXPartitioner
569 | 565 |
570 | 566 |     print("\nLowering to ExecuTorch with MLX...")
571 | 567 |
572 | 568 |     partitioner = {}
573 | 569 |     for key in programs.keys():
574 | | -        # if key == "preprocessor":
575 | | -        #     # Skip preprocessor - FFT ops are not supported by MLX and fall back
576 | | -        #     # to portable pocketfft implementation. There is a bug in pocketfft
577 | | -        #     # that causes SIGABRT ("pointer being freed was not allocated") in
578 | | -        #     # release builds but not debug builds.
579 | | -        #     partitioner[key] = []
580 | | -        # else:
581 | 570 |         partitioner[key] = [MLXPartitioner()]
582 | 571 |
583 | 572 |     return partitioner, programs
@@ -621,38 +610,6 @@ def lower_to_executorch(programs, metadata=None, backend="portable"): |
621 | 610 |     )
622 | 611 |
623 | 612 |
624 | | -def apply_quantization(model, quantize: str) -> None:
625 | | -    """Apply quantization to the model using TorchAO.
626 | | -
627 | | -    Args:
628 | | -        model: The model to quantize
629 | | -        quantize: Quantization method ("int4" or "int8")
630 | | -    """
631 | | -    try:
632 | | -        from torchao.quantization.granularity import PerGroup
633 | | -        from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
634 | | -    except ImportError:
635 | | -        logger.error("TorchAO not installed. Run: pip install torchao")
636 | | -        raise
637 | | -
638 | | -    logger.info(f"Applying {quantize} quantization to linear layers...")
639 | | -
640 | | -    if quantize == "int4":
641 | | -        quantize_(
642 | | -            model,
643 | | -            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(128)),
644 | | -            lambda m, fqn: isinstance(m, torch.nn.Linear),
645 | | -        )
646 | | -    elif quantize == "int8":
647 | | -        quantize_(
648 | | -            model,
649 | | -            IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerGroup(128)),
650 | | -            lambda m, fqn: isinstance(m, torch.nn.Linear),
651 | | -        )
652 | | -    else:
653 | | -        logger.warning(f"Unknown quantization method: {quantize}")
654 | | -
655 | | -
656 | 613 | def main():
657 | 614 |
658 | 615 |     parser = argparse.ArgumentParser()
@@ -729,13 +686,6 @@ def main(): |
729 | 686 |         help="Group size for embedding quantization (default: 0 = per-axis)",
730 | 687 |     )
731 | 688 |
732 | | -    parser.add_argument(
733 | | -        "--quantize",
734 | | -        type=str,
735 | | -        choices=["int4", "int8"],
736 | | -        default=None,
737 | | -        help="Quantization method for linear layers (requires torchao)",
738 | | -    )
739 | 689 |     args = parser.parse_args()
740 | 690 |
741 | 691 |     # Validate dtype
@@ -764,10 +714,6 @@ def main(): |
764 | 714 |         print("Converting model to float16...")
765 | 715 |         model = model.to(torch.float16)
766 | 716 |
767 | | -    # Apply quantization if requested
768 | | -    if args.quantize:
769 | | -        apply_quantization(model, args.quantize)
770 | | -
771 | 717 |     print("\nExporting components...")
772 | 718 |     export_dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float
773 | 719 |     programs, metadata = export_all(