Skip to content

Commit f296724

Browse files
committed
up
1 parent fdd9ea0 commit f296724

10 files changed

Lines changed: 81 additions & 88 deletions

File tree

CMakePresets.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,8 @@
289289
"inherits": ["common"],
290290
"cacheVariables": {
291291
"EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/mlx.cmake",
292-
"CMAKE_OSX_DEPLOYMENT_TARGET": "14.0"
292+
"CMAKE_OSX_DEPLOYMENT_TARGET": "14.0",
293+
"CMAKE_CXX_FLAGS": "-DABSL_USES_STD_STRING_VIEW"
293294
},
294295
"condition": {
295296
"lhs": "${hostSystemName}",

backends/apple/mlx/examples/whisper/run_whisper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def run_whisper_inference( # noqa: C901
256256
cache_position = cache_position + decoder_input_ids.shape[1]
257257

258258
# Generation loop
259-
for step in range(max_new_tokens):
259+
for _step in range(max_new_tokens):
260260
current_pos = cache_position.item()
261261

262262
# Check for forced token at this position

backends/apple/mlx/examples/whisper/run_whisper_hf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def run_inference( # noqa: C901
222222
cache_position = cache_position + decoder_input_ids.shape[1]
223223

224224
# Generation loop
225-
for step in range(max_new_tokens):
225+
for _step in range(max_new_tokens):
226226
current_pos = cache_position.item()
227227

228228
# Check for forced token at this position

backends/apple/mlx/ops.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,20 +1277,18 @@ def _unsqueeze_handler(P: MLXProgramBuilder, n: Node) -> Slot:
12771277
@REGISTRY.register(
12781278
target=[torch.ops.aten.squeeze.dims, torch.ops.aten.squeeze_copy.dims]
12791279
)
1280-
def _squeeze_handler(P: MLXProgramBuilder, n: Node) -> Slot:
1280+
def _squeeze_dims_handler(P: MLXProgramBuilder, n: Node) -> Slot:
12811281
"""Handle squeeze operation for specific dimensions.
12821282
12831283
Removes dimensions of size 1 from the tensor at specified positions.
1284-
If dims is empty, removes all dimensions of size 1.
12851284
"""
12861285
args = P.args(n)
12871286
require_args(args, 2, 2, "aten.squeeze.dims")
12881287
require_kwargs(P.kwargs(n), set(), "aten.squeeze.dims")
12891288
x, dims = args
12901289
out = P.make_or_get_slot(n)
12911290

1292-
# dims is typically a list of ints
1293-
dims_list = list(dims) if dims is not None else []
1291+
dims_list = list(dims) if dims is not None else None
12941292

12951293
P.emit(
12961294
SqueezeNode(
@@ -1302,6 +1300,30 @@ def _squeeze_handler(P: MLXProgramBuilder, n: Node) -> Slot:
13021300
return out
13031301

13041302

1303+
@REGISTRY.register(
1304+
target=[torch.ops.aten.squeeze.default, torch.ops.aten.squeeze_copy.default]
1305+
)
1306+
def _squeeze_default_handler(P: MLXProgramBuilder, n: Node) -> Slot:
1307+
"""Handle squeeze operation without specified dimensions.
1308+
1309+
Removes all dimensions of size 1 from the tensor.
1310+
"""
1311+
args = P.args(n)
1312+
require_args(args, 1, 1, "aten.squeeze.default")
1313+
require_kwargs(P.kwargs(n), set(), "aten.squeeze.default")
1314+
(x,) = args
1315+
out = P.make_or_get_slot(n)
1316+
1317+
P.emit(
1318+
SqueezeNode(
1319+
x=P.slot_to_tid(x),
1320+
out=P.slot_to_tid(out),
1321+
dims=None,
1322+
)
1323+
)
1324+
return out
1325+
1326+
13051327
@REGISTRY.register(target=[torch.ops.aten.cat.default])
13061328
def _cat_handler(P: MLXProgramBuilder, n: Node) -> Slot:
13071329
"""Handle concatenation of a list of tensors.

backends/apple/mlx/program_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def get_aten_target(target):
9393
torch.ops.aten.unsqueeze_copy.default: torch.ops.aten.unsqueeze.default,
9494
torch.ops.aten.squeeze_copy.dim: torch.ops.aten.squeeze.dim,
9595
torch.ops.aten.squeeze_copy.dims: torch.ops.aten.squeeze.dims,
96+
torch.ops.aten.squeeze_copy.default: torch.ops.aten.squeeze.default,
9697
torch.ops.aten.expand_copy.default: torch.ops.aten.expand.default,
9798
torch.ops.aten.alias_copy.default: torch.ops.aten.alias.default,
9899
}

backends/apple/mlx/runtime/MLXInterpreter.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,15 +464,24 @@ inline void exec_tanh(const TanhNode& n, ExecutionState& st, StreamOrDevice s) {
464464
inline void
465465
exec_squeeze(const SqueezeNode& n, ExecutionState& st, StreamOrDevice s) {
466466
const auto& x = st.const_tensor_ref(n.x);
467-
auto dims_fb = n.dims;
467+
const auto& dims_fb = n.dims;
468468

469469
if (dims_fb.size() > 0) {
470-
// Squeeze specific dimensions
470+
// Squeeze specific dimensions, filtering out non-size-1 dims to match
471+
// PyTorch semantics where squeeze on a non-size-1 dim is a no-op.
471472
std::vector<int> dims;
472473
for (auto d : dims_fb) {
473-
dims.push_back(d);
474+
int axis = d < 0 ? d + static_cast<int>(x.ndim()) : d;
475+
if (axis >= 0 && axis < static_cast<int>(x.ndim()) &&
476+
x.shape(axis) == 1) {
477+
dims.push_back(d);
478+
}
479+
}
480+
if (dims.size() > 0) {
481+
st.set_tensor(n.out, squeeze(x, dims, s));
482+
} else {
483+
st.set_tensor(n.out, x);
474484
}
475-
st.set_tensor(n.out, squeeze(x, dims, s));
476485
} else {
477486
// Squeeze all dimensions of size 1
478487
st.set_tensor(n.out, squeeze(x, s));

backends/apple/mlx/test/test_ops.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -475,12 +475,12 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
475475
class SqueezeModel(nn.Module):
476476
"""Model that squeezes a tensor at specified dimensions."""
477477

478-
def __init__(self, dims: Tuple[int, ...]):
478+
def __init__(self, dims: Optional[Tuple[int, ...]] = None):
479479
super().__init__()
480480
self.dims = dims
481481

482482
def forward(self, x: torch.Tensor) -> torch.Tensor:
483-
if len(self.dims) == 0:
483+
if self.dims is None:
484484
return torch.squeeze(x)
485485
else:
486486
return torch.squeeze(x, dim=self.dims)
@@ -495,12 +495,19 @@ class SqueezeTest(OpTestCase):
495495
atol = 1e-5
496496

497497
def __init__(
498-
self, shape: Tuple[int, ...] = (1, 3, 1, 4), dims: Tuple[int, ...] = (0, 2)
498+
self,
499+
shape: Tuple[int, ...] = (1, 3, 1, 4),
500+
dims: Optional[Tuple[int, ...]] = (0, 2),
499501
):
500502
self.shape = shape
501503
self.dims = dims
502504
shape_str = "x".join(str(s) for s in shape)
503-
dims_str = "_".join(str(d) for d in dims) if dims else "all"
505+
if dims is None:
506+
dims_str = "all"
507+
elif len(dims) == 0:
508+
dims_str = "empty"
509+
else:
510+
dims_str = "_".join(str(d) for d in dims)
504511
self.name = f"squeeze_{shape_str}_dims{dims_str}"
505512

506513
@classmethod
@@ -511,6 +518,10 @@ def get_test_configs(cls) -> List["SqueezeTest"]:
511518
cls(shape=(3, 1, 4), dims=(1,)),
512519
cls(shape=(1, 1, 8), dims=(0, 1)),
513520
cls(shape=(2, 1, 3, 1), dims=(1, 3)),
521+
# Squeeze all singleton dims (no dims specified)
522+
cls(shape=(1, 3, 1, 4), dims=None),
523+
# Dims include non-size-1 axes (should be no-op for those axes)
524+
cls(shape=(1, 1, 1, 8198), dims=(0, 1, 2, 3)),
514525
]
515526

516527
def create_inputs(self) -> Tuple[torch.Tensor, ...]:

examples/models/parakeet/README.md

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav
2727
| `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) |
2828
| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `metal`, `mlx`, `cuda`, `cuda-windows` (default: `xnnpack`) |
2929
| `--dtype` | Data type: `fp32`, `bf16`, `fp16` (default: `fp32`). Metal backend supports `fp32` and `bf16` only (no `fp16`). |
30-
| `--quantize` | Quantization mode: `int4` for int4 weight-only quantization via TorchAO (default: none) |
3130
| `--audio` | Path to audio file for transcription test |
3231

3332
**Note:** The preprocessor is always lowered with the portable backend regardless of the `--backend` setting.
@@ -135,18 +134,20 @@ This generates:
135134

136135
### MLX Export (macOS)
137136

138-
Export with MLX backend:
137+
Export with MLX backend (bf16, int4 quantized, group size 128):
139138
```bash
140-
python export_parakeet_tdt.py --backend mlx --output-dir ./parakeet_mlx
141-
```
142-
143-
Export with int4 quantization (reduces model size ~4x):
144-
```bash
145-
python export_parakeet_tdt.py --backend mlx --quantize int4 --output-dir ./parakeet_mlx_int4
139+
python export_parakeet_tdt.py \
140+
--backend mlx \
141+
--dtype bf16 \
142+
--qlinear_encoder 4w \
143+
--qlinear_encoder_group_size 128 \
144+
--qlinear 4w \
145+
--qlinear_group_size 128 \
146+
--output-dir ./parakeet_mlx_4w
146147
```
147148

148149
This generates:
149-
- `parakeet_tdt.pte` - The compiled model with MLX delegate
150+
- `model.pte` - The compiled model with MLX delegate (~470 MB)
150151
- `tokenizer.model` - SentencePiece tokenizer
151152

152153
## C++ Runner
@@ -172,23 +173,18 @@ Then build the parakeet runner:
172173
cd examples/models/parakeet
173174

174175
# CPU/XNNPACK build
175-
make parakeet-cpu
176+
cmake --workflow --preset parakeet-cpu
176177

177178
# Metal build (macOS)
178-
make parakeet-metal
179+
cmake --workflow --preset parakeet-metal
179180

180181
# CUDA build (Linux)
181-
make parakeet-cuda
182+
cmake --workflow --preset parakeet-cuda
182183

183184
# MLX build (macOS)
184-
make parakeet-mlx
185+
cmake --workflow --preset parakeet-mlx
185186
```
186187

187-
Available presets:
188-
- `parakeet-cpu` - CPU-only build
189-
- `parakeet-cuda` - CUDA acceleration (Linux/Windows)
190-
- `parakeet-metal` - Metal acceleration (macOS)
191-
- `parakeet-mlx` - MLX acceleration (macOS)
192188
### Running
193189

194190
From the executorch root directory:
@@ -212,6 +208,12 @@ DYLD_LIBRARY_PATH=/usr/lib ./cmake-out/examples/models/parakeet/parakeet_runner
212208
--data_path examples/models/parakeet/parakeet_cuda/aoti_cuda_blob.ptd \
213209
--audio_path /path/to/audio.wav \
214210
--tokenizer_path examples/models/parakeet/parakeet_cuda/tokenizer.model
211+
212+
# MLX
213+
./cmake-out/examples/models/parakeet/parakeet_runner \
214+
--model_path examples/models/parakeet/parakeet_mlx_4w/model.pte \
215+
--audio_path /path/to/audio.wav \
216+
--tokenizer_path examples/models/parakeet/parakeet_mlx_4w/tokenizer.model
215217
```
216218

217219
### Runner Arguments

examples/models/parakeet/export_parakeet_tdt.py

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Export nvidia/parakeet-tdt-0.6b-v3 components to ExecuTorch."""
22

33
import argparse
4-
import logging
54
import os
65
import shutil
76
import tarfile
@@ -20,8 +19,6 @@
2019
from executorch.exir.passes import MemoryPlanningPass
2120
from torch.export import Dim, export
2221

23-
logger = logging.getLogger(__name__)
24-
2522

2623
def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor:
2724
"""Load audio file and resample to target sample rate."""
@@ -442,7 +439,6 @@ def export_all(
442439
strict=False,
443440
)
444441

445-
446442
sample_rate = model.preprocessor._cfg.sample_rate
447443
window_stride = float(model.preprocessor._cfg.window_stride)
448444
encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 8))
@@ -564,20 +560,13 @@ def _create_cuda_partitioners(programs, is_windows=False):
564560

565561

566562
def _create_mlx_partitioners(programs):
567-
"""Create MLX partitioners for all programs except preprocessor."""
563+
"""Create MLX partitioners for all programs."""
568564
from executorch.backends.apple.mlx.partitioner import MLXPartitioner
569565

570566
print("\nLowering to ExecuTorch with MLX...")
571567

572568
partitioner = {}
573569
for key in programs.keys():
574-
# if key == "preprocessor":
575-
# # Skip preprocessor - FFT ops are not supported by MLX and fall back
576-
# # to portable pocketfft implementation. There is a bug in pocketfft
577-
# # that causes SIGABRT ("pointer being freed was not allocated") in
578-
# # release builds but not debug builds.
579-
# partitioner[key] = []
580-
# else:
581570
partitioner[key] = [MLXPartitioner()]
582571

583572
return partitioner, programs
@@ -621,38 +610,6 @@ def lower_to_executorch(programs, metadata=None, backend="portable"):
621610
)
622611

623612

624-
def apply_quantization(model, quantize: str) -> None:
625-
"""Apply quantization to the model using TorchAO.
626-
627-
Args:
628-
model: The model to quantize
629-
quantize: Quantization method ("int4" or "int8")
630-
"""
631-
try:
632-
from torchao.quantization.granularity import PerGroup
633-
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
634-
except ImportError:
635-
logger.error("TorchAO not installed. Run: pip install torchao")
636-
raise
637-
638-
logger.info(f"Applying {quantize} quantization to linear layers...")
639-
640-
if quantize == "int4":
641-
quantize_(
642-
model,
643-
IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(128)),
644-
lambda m, fqn: isinstance(m, torch.nn.Linear),
645-
)
646-
elif quantize == "int8":
647-
quantize_(
648-
model,
649-
IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerGroup(128)),
650-
lambda m, fqn: isinstance(m, torch.nn.Linear),
651-
)
652-
else:
653-
logger.warning(f"Unknown quantization method: {quantize}")
654-
655-
656613
def main():
657614

658615
parser = argparse.ArgumentParser()
@@ -729,13 +686,6 @@ def main():
729686
help="Group size for embedding quantization (default: 0 = per-axis)",
730687
)
731688

732-
parser.add_argument(
733-
"--quantize",
734-
type=str,
735-
choices=["int4", "int8"],
736-
default=None,
737-
help="Quantization method for linear layers (requires torchao)",
738-
)
739689
args = parser.parse_args()
740690

741691
# Validate dtype
@@ -764,10 +714,6 @@ def main():
764714
print("Converting model to float16...")
765715
model = model.to(torch.float16)
766716

767-
# Apply quantization if requested
768-
if args.quantize:
769-
apply_quantization(model, args.quantize)
770-
771717
print("\nExporting components...")
772718
export_dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float
773719
programs, metadata = export_all(

examples/models/parakeet/quantize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def linear_filter(m, fqn):
109109
config = IntxWeightOnlyConfig(
110110
weight_dtype=torch.int4,
111111
granularity=granularity,
112+
intx_choose_qparams_algorithm="hqq_scale_only",
112113
)
113114
elif qlinear_config == "8w":
114115
config = IntxWeightOnlyConfig(

0 commit comments

Comments
 (0)