Skip to content

Commit 4632a83

Browse files
Update
[ghstack-poisoned]
2 parents 59f88db + 98d2f81 commit 4632a83

3 files changed

Lines changed: 139 additions & 99 deletions

File tree

backends/apple/metal/ops/gather_qmv.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,37 @@ def gather_qmv(
4242
return y
4343

4444

45+
def _quantize_int4_affine(
46+
w: Tensor, group_size: int
47+
) -> tuple[Tensor, Tensor, Tensor]:
48+
"""Quantize float weights to packed INT4 using MLX affine format.
49+
50+
Args:
51+
w: [..., K] float weight tensor (last dim is quantized).
52+
group_size: Number of elements per quantization group.
53+
54+
Returns:
55+
(packed, scales, biases) where:
56+
- packed: [..., K//2] uint8, two INT4 values per byte.
57+
- scales: [..., K//group_size] per-group scales.
58+
- biases: [..., K//group_size] per-group biases (zero points).
59+
60+
The affine mapping is: dequantized = raw_uint4 * scale + bias,
61+
where raw_uint4 is in [0, 15].
62+
"""
63+
*leading, K = w.shape
64+
w_groups = w.reshape(*leading, K // group_size, group_size)
65+
g_min = w_groups.amin(dim=-1)
66+
g_max = w_groups.amax(dim=-1)
67+
scales = ((g_max - g_min) / 15.0).clamp(min=1e-8)
68+
biases = g_min
69+
w_int = (
70+
(w_groups - biases.unsqueeze(-1)) / scales.unsqueeze(-1)
71+
).round().clamp(0, 15).to(torch.uint8).reshape(*leading, K)
72+
packed = w_int[..., 0::2] | (w_int[..., 1::2] << 4)
73+
return packed, scales, biases
74+
75+
4576
def _dequantize_int4_affine(
4677
w_packed: Tensor, scales: Tensor, biases: Tensor, K: int, group_size: int
4778
) -> Tensor:

backends/apple/metal/runtime/ops/op_topk.mm

Lines changed: 97 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
// Top-k operator using MPSGraph.
1010
// Used by MoE routing (torch.topk in SparseMoE.forward).
11+
// Note: sorted parameter is accepted but MPSGraph always returns sorted results.
1112

1213
#include <executorch/backends/apple/metal/runtime/ops/common.h>
1314

@@ -40,6 +41,9 @@ AOTITorchError aoti_torch_mps_topk(
4041
return Error::Internal;
4142
}
4243

44+
void* values_ptr = nullptr;
45+
void* indices_ptr = nullptr;
46+
4347
try {
4448
@autoreleasepool {
4549
auto* self_tensor = reinterpret_cast<Tensor*>(self);
@@ -55,7 +59,7 @@ AOTITorchError aoti_torch_mps_topk(
5559

5660
int64_t dim_size = self_tensor->sizes()[dim];
5761
if (k > dim_size) {
58-
ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld\n", k, dim_size);
62+
ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld", k, dim_size);
5963
return Error::InvalidArgument;
6064
}
6165

@@ -96,18 +100,20 @@ AOTITorchError aoti_torch_mps_topk(
96100
size_t values_bytes = num_elements * element_size;
97101
size_t indices_bytes = num_elements * sizeof(int32_t);
98102

99-
void* values_ptr = nullptr;
100-
void* indices_ptr = nullptr;
101103
allocate_mtl_buffer(&values_ptr, values_bytes);
102104
allocate_mtl_buffer(&indices_ptr, indices_bytes);
103105

104-
// Build MPSGraph
105106
// Convert input shape to NSArray<NSNumber*>
106107
NSMutableArray<NSNumber*>* input_shape = [NSMutableArray arrayWithCapacity:ndim];
107108
for (int64_t i = 0; i < ndim; i++) {
108109
[input_shape addObject:@(self_tensor->sizes()[i])];
109110
}
110111

112+
NSMutableArray<NSNumber*>* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim];
113+
for (int64_t i = 0; i < ndim; i++) {
114+
[out_ns_shape addObject:@(out_sizes[i])];
115+
}
116+
111117
// Check graph cache
112118
GraphCacheKey cache_key;
113119
cache_key.op_name = "topk";
@@ -120,101 +126,103 @@ AOTITorchError aoti_torch_mps_topk(
120126
cache_key.dtype = dtype;
121127
cache_key.transpose_flag = false;
122128

129+
stream->endKernelCoalescing();
130+
131+
id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "topk", "self");
132+
id<MTLBuffer> values_buffer = ptr_to_mtl_buffer[values_ptr];
133+
id<MTLBuffer> indices_buffer = ptr_to_mtl_buffer[indices_ptr];
134+
123135
auto cache_it = graph_cache.find(cache_key);
124136
if (cache_it != graph_cache.end()) {
125137
cache_stats.hits++;
138+
cache_stats.logStats();
126139
auto& cached = cache_it->second;
127140

128-
id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "topk", "self");
129-
id<MTLBuffer> values_buffer = ptr_to_mtl_buffer[values_ptr];
130-
id<MTLBuffer> indices_buffer = ptr_to_mtl_buffer[indices_ptr];
141+
MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype];
142+
MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype];
143+
MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32];
131144

132145
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
133-
cached.input1: [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype],
146+
cached.input1: selfData,
134147
};
135-
136-
NSMutableArray<NSNumber*>* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim];
137-
for (int64_t i = 0; i < ndim; i++) {
138-
[out_ns_shape addObject:@(out_sizes[i])];
139-
}
140-
141148
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
142-
cached.output: [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype],
143-
cached.input2: [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32],
149+
cached.output: valuesData,
150+
cached.input2: indicesData,
144151
};
145152

146-
stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT);
153+
@try {
154+
stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT);
155+
} @catch (NSException* e) {
156+
ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s",
157+
e.name.UTF8String, e.reason.UTF8String);
158+
throw std::runtime_error(std::string("MPSGraph topk failed: ") + e.reason.UTF8String);
159+
}
160+
161+
[selfData release];
162+
[valuesData release];
163+
[indicesData release];
147164
} else {
148165
cache_stats.misses++;
166+
cache_stats.logStats();
149167
ET_LOG(Debug, "aoti_torch_mps_topk: cache miss, building graph");
150168

151169
@try {
152-
MPSGraph* graph = [[MPSGraph alloc] init];
153-
MPSGraphTensor* input = [graph placeholderWithShape:input_shape
154-
dataType:mps_dtype
155-
name:@"self"];
156-
157-
// MPSGraph topK: returns (values, indices) along the last dimension.
158-
// If dim != -1, we need to transpose dim to last, topk, then transpose back.
159-
MPSGraphTensor* work = input;
160-
bool need_transpose = (dim != ndim - 1);
161-
162-
if (need_transpose) {
163-
work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil];
164-
}
165-
166-
// MPSGraph topKWithTensor returns along the last axis
167-
NSArray<MPSGraphTensor*>* topk_results;
168-
if (largest) {
169-
topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil];
170-
} else {
171-
// For smallest: negate, topk, negate back
172-
MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil];
173-
topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil];
174-
topk_results = @[
175-
[graph negativeWithTensor:topk_results[0] name:nil],
176-
topk_results[1]
177-
];
178-
}
179-
180-
MPSGraphTensor* values_out = topk_results[0];
181-
MPSGraphTensor* indices_out = topk_results[1];
182-
183-
if (need_transpose) {
184-
values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil];
185-
indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil];
186-
}
187-
188-
// Cache the graph
189-
CachedGraph cached_graph;
190-
cached_graph.graph = graph;
191-
cached_graph.input1 = input;
192-
cached_graph.input2 = indices_out; // reuse input2 slot for indices output
193-
cached_graph.output = values_out;
194-
graph_cache[cache_key] = cached_graph;
195-
196-
// Execute
197-
id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "topk", "self");
198-
id<MTLBuffer> values_buffer = ptr_to_mtl_buffer[values_ptr];
199-
id<MTLBuffer> indices_buffer = ptr_to_mtl_buffer[indices_ptr];
200-
201-
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
202-
input: [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype],
203-
};
204-
205-
NSMutableArray<NSNumber*>* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim];
206-
for (int64_t i = 0; i < ndim; i++) {
207-
[out_ns_shape addObject:@(out_sizes[i])];
208-
}
209-
210-
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
211-
values_out: [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype],
212-
indices_out: [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32],
213-
};
214-
215-
ET_LOG(Debug, "aoti_torch_mps_topk: executing MPSGraph");
216-
stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT);
217-
ET_LOG(Debug, "aoti_torch_mps_topk: MPSGraph done");
170+
MPSGraph* graph = [[MPSGraph alloc] init];
171+
MPSGraphTensor* input = [graph placeholderWithShape:input_shape
172+
dataType:mps_dtype
173+
name:@"self"];
174+
175+
MPSGraphTensor* work = input;
176+
bool need_transpose = (dim != ndim - 1);
177+
178+
if (need_transpose) {
179+
work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil];
180+
}
181+
182+
NSArray<MPSGraphTensor*>* topk_results;
183+
if (largest) {
184+
topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil];
185+
} else {
186+
MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil];
187+
topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil];
188+
topk_results = @[
189+
[graph negativeWithTensor:topk_results[0] name:nil],
190+
topk_results[1]
191+
];
192+
}
193+
194+
MPSGraphTensor* values_out = topk_results[0];
195+
MPSGraphTensor* indices_out = topk_results[1];
196+
197+
if (need_transpose) {
198+
values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil];
199+
indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil];
200+
}
201+
202+
CachedGraph cached_graph;
203+
cached_graph.graph = graph;
204+
cached_graph.input1 = input;
205+
cached_graph.input2 = indices_out;
206+
cached_graph.output = values_out;
207+
graph_cache[cache_key] = cached_graph;
208+
209+
MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype];
210+
MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype];
211+
MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32];
212+
213+
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
214+
input: selfData,
215+
};
216+
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
217+
values_out: valuesData,
218+
indices_out: indicesData,
219+
};
220+
221+
stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT);
222+
223+
[selfData release];
224+
[valuesData release];
225+
[indicesData release];
218226
} @catch (NSException* e) {
219227
ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s",
220228
e.name.UTF8String, e.reason.UTF8String);
@@ -223,7 +231,6 @@ AOTITorchError aoti_torch_mps_topk(
223231
}
224232

225233
// Create output tensor handles
226-
// Values tensor
227234
AOTITensorHandle values_handle = nullptr;
228235
aoti_torch_create_tensor_from_blob_v2(
229236
values_ptr, ndim, out_sizes.data(), out_strides.data(),
@@ -235,29 +242,25 @@ AOTITorchError aoti_torch_mps_topk(
235242
aoti_torch_mps_free(indices_ptr);
236243
return Error::Internal;
237244
}
238-
ET_LOG(Debug, "aoti_torch_mps_topk: values tensor created");
239245

240-
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
241246
memory_to_n_tensor[values_ptr] = 1;
242247

243248
// Indices tensor — MPSGraph outputs int32, AOTInductor expects int64.
244-
// Allocate a new int64 buffer and convert.
245249
size_t indices_i64_bytes = num_elements * sizeof(int64_t);
246250
void* indices_i64_ptr = nullptr;
247251
allocate_mtl_buffer(&indices_i64_ptr, indices_i64_bytes);
248252

249253
// Copy int32 → int64 on CPU (small tensor, fast)
254+
stream->synchronize(SyncType::COMMIT_AND_WAIT);
250255
{
251-
auto* stream_sync = getCurrentMetalStream();
252-
stream_sync->synchronize(SyncType::COMMIT_AND_WAIT);
253-
254256
int32_t* src = reinterpret_cast<int32_t*>(indices_ptr);
255257
int64_t* dst = reinterpret_cast<int64_t*>(indices_i64_ptr);
256258
for (size_t i = 0; i < num_elements; i++) {
257259
dst[i] = static_cast<int64_t>(src[i]);
258260
}
259261
}
260262
aoti_torch_mps_free(indices_ptr);
263+
indices_ptr = nullptr;
261264

262265
int32_t indices_dtype = static_cast<int32_t>(exec_aten::ScalarType::Long);
263266
std::vector<int64_t> indices_strides(ndim);
@@ -281,17 +284,19 @@ AOTITorchError aoti_torch_mps_topk(
281284
*ret0 = values_handle;
282285
*ret1 = indices_handle;
283286

284-
ET_LOG(Debug, "aoti_torch_mps_topk: Completed successfully");
285-
286287
} // @autoreleasepool
287288

288289
return Error::Ok;
289290

290291
} catch (const std::exception& e) {
291292
ET_LOG(Error, "aoti_torch_mps_topk exception: %s", e.what());
293+
if (values_ptr) aoti_torch_mps_free(values_ptr);
294+
if (indices_ptr) aoti_torch_mps_free(indices_ptr);
292295
return Error::Internal;
293296
} catch (...) {
294297
ET_LOG(Error, "aoti_torch_mps_topk: unknown exception");
298+
if (values_ptr) aoti_torch_mps_free(values_ptr);
299+
if (indices_ptr) aoti_torch_mps_free(indices_ptr);
295300
return Error::Internal;
296301
}
297302
}

backends/apple/metal/tests/test_modules.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -702,13 +702,17 @@ class GatherQMV(nn.Module):
702702

703703
def __init__(self):
704704
super().__init__()
705+
from executorch.backends.apple.metal.ops.gather_qmv import (
706+
_quantize_int4_affine,
707+
)
708+
705709
E, N, K, gs = 4, 64, 128, 32
706710
torch.manual_seed(0)
707-
self.register_buffer(
708-
"w", torch.randint(0, 255, (E, N, K // 2), dtype=torch.uint8)
709-
)
710-
self.register_buffer("scales", torch.randn(E, N, K // gs))
711-
self.register_buffer("biases", torch.randn(E, N, K // gs))
711+
w_float = torch.randn(E, N, K)
712+
packed, scales, biases = _quantize_int4_affine(w_float, gs)
713+
self.register_buffer("w", packed)
714+
self.register_buffer("scales", scales)
715+
self.register_buffer("biases", biases)
712716
self.group_size = gs
713717
self.num_experts = E
714718

@@ -733,8 +737,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
733737
"description": "Expert-indexed quantized matmul for MoE (metal::gather_qmv)",
734738
"atol_float32": 5e-2,
735739
"rtol_float32": 5e-2,
736-
"atol_bfloat16": 5.0,
737-
"rtol_bfloat16": 2e-1,
740+
"atol_bfloat16": 1e-1,
741+
"rtol_bfloat16": 1e-1,
738742
}
739743

740744

0 commit comments

Comments
 (0)