Skip to content

Commit 98d2f81

Browse files
Update
[ghstack-poisoned]
1 parent 1fbb94f commit 98d2f81

1 file changed

Lines changed: 97 additions & 92 deletions

File tree

backends/apple/metal/runtime/ops/op_topk.mm

Lines changed: 97 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
// Top-k operator using MPSGraph.
1010
// Used by MoE routing (torch.topk in SparseMoE.forward).
11+
// Note: the `sorted` parameter is accepted, but MPSGraph always returns sorted results.
1112

1213
#include <executorch/backends/apple/metal/runtime/ops/common.h>
1314

@@ -40,6 +41,9 @@ AOTITorchError aoti_torch_mps_topk(
4041
return Error::Internal;
4142
}
4243

44+
void* values_ptr = nullptr;
45+
void* indices_ptr = nullptr;
46+
4347
try {
4448
@autoreleasepool {
4549
auto* self_tensor = reinterpret_cast<Tensor*>(self);
@@ -55,7 +59,7 @@ AOTITorchError aoti_torch_mps_topk(
5559

5660
int64_t dim_size = self_tensor->sizes()[dim];
5761
if (k > dim_size) {
58-
ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld\n", k, dim_size);
62+
ET_LOG(Error, "aoti_torch_mps_topk: k=%lld > dim_size=%lld", k, dim_size);
5963
return Error::InvalidArgument;
6064
}
6165

@@ -96,18 +100,20 @@ AOTITorchError aoti_torch_mps_topk(
96100
size_t values_bytes = num_elements * element_size;
97101
size_t indices_bytes = num_elements * sizeof(int32_t);
98102

99-
void* values_ptr = nullptr;
100-
void* indices_ptr = nullptr;
101103
allocate_mtl_buffer(&values_ptr, values_bytes);
102104
allocate_mtl_buffer(&indices_ptr, indices_bytes);
103105

104-
// Build MPSGraph
105106
// Convert input shape to NSArray<NSNumber*>
106107
NSMutableArray<NSNumber*>* input_shape = [NSMutableArray arrayWithCapacity:ndim];
107108
for (int64_t i = 0; i < ndim; i++) {
108109
[input_shape addObject:@(self_tensor->sizes()[i])];
109110
}
110111

112+
NSMutableArray<NSNumber*>* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim];
113+
for (int64_t i = 0; i < ndim; i++) {
114+
[out_ns_shape addObject:@(out_sizes[i])];
115+
}
116+
111117
// Check graph cache
112118
GraphCacheKey cache_key;
113119
cache_key.op_name = "topk";
@@ -120,101 +126,103 @@ AOTITorchError aoti_torch_mps_topk(
120126
cache_key.dtype = dtype;
121127
cache_key.transpose_flag = false;
122128

129+
stream->endKernelCoalescing();
130+
131+
id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "topk", "self");
132+
id<MTLBuffer> values_buffer = ptr_to_mtl_buffer[values_ptr];
133+
id<MTLBuffer> indices_buffer = ptr_to_mtl_buffer[indices_ptr];
134+
123135
auto cache_it = graph_cache.find(cache_key);
124136
if (cache_it != graph_cache.end()) {
125137
cache_stats.hits++;
138+
cache_stats.logStats();
126139
auto& cached = cache_it->second;
127140

128-
id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "topk", "self");
129-
id<MTLBuffer> values_buffer = ptr_to_mtl_buffer[values_ptr];
130-
id<MTLBuffer> indices_buffer = ptr_to_mtl_buffer[indices_ptr];
141+
MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype];
142+
MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype];
143+
MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32];
131144

132145
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
133-
cached.input1: [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype],
146+
cached.input1: selfData,
134147
};
135-
136-
NSMutableArray<NSNumber*>* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim];
137-
for (int64_t i = 0; i < ndim; i++) {
138-
[out_ns_shape addObject:@(out_sizes[i])];
139-
}
140-
141148
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
142-
cached.output: [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype],
143-
cached.input2: [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32],
149+
cached.output: valuesData,
150+
cached.input2: indicesData,
144151
};
145152

146-
stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT);
153+
@try {
154+
stream->executeMPSGraph(cached.graph, feeds, results, SyncType::COMMIT);
155+
} @catch (NSException* e) {
156+
ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s",
157+
e.name.UTF8String, e.reason.UTF8String);
158+
throw std::runtime_error(std::string("MPSGraph topk failed: ") + e.reason.UTF8String);
159+
}
160+
161+
[selfData release];
162+
[valuesData release];
163+
[indicesData release];
147164
} else {
148165
cache_stats.misses++;
166+
cache_stats.logStats();
149167
ET_LOG(Debug, "aoti_torch_mps_topk: cache miss, building graph");
150168

151169
@try {
152-
MPSGraph* graph = [[MPSGraph alloc] init];
153-
MPSGraphTensor* input = [graph placeholderWithShape:input_shape
154-
dataType:mps_dtype
155-
name:@"self"];
156-
157-
// MPSGraph topK: returns (values, indices) along the last dimension.
158-
// If dim != -1, we need to transpose dim to last, topk, then transpose back.
159-
MPSGraphTensor* work = input;
160-
bool need_transpose = (dim != ndim - 1);
161-
162-
if (need_transpose) {
163-
work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil];
164-
}
165-
166-
// MPSGraph topKWithTensor returns along the last axis
167-
NSArray<MPSGraphTensor*>* topk_results;
168-
if (largest) {
169-
topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil];
170-
} else {
171-
// For smallest: negate, topk, negate back
172-
MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil];
173-
topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil];
174-
topk_results = @[
175-
[graph negativeWithTensor:topk_results[0] name:nil],
176-
topk_results[1]
177-
];
178-
}
179-
180-
MPSGraphTensor* values_out = topk_results[0];
181-
MPSGraphTensor* indices_out = topk_results[1];
182-
183-
if (need_transpose) {
184-
values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil];
185-
indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil];
186-
}
187-
188-
// Cache the graph
189-
CachedGraph cached_graph;
190-
cached_graph.graph = graph;
191-
cached_graph.input1 = input;
192-
cached_graph.input2 = indices_out; // reuse input2 slot for indices output
193-
cached_graph.output = values_out;
194-
graph_cache[cache_key] = cached_graph;
195-
196-
// Execute
197-
id<MTLBuffer> self_buffer = get_mtl_buffer(self_tensor, "topk", "self");
198-
id<MTLBuffer> values_buffer = ptr_to_mtl_buffer[values_ptr];
199-
id<MTLBuffer> indices_buffer = ptr_to_mtl_buffer[indices_ptr];
200-
201-
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
202-
input: [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype],
203-
};
204-
205-
NSMutableArray<NSNumber*>* out_ns_shape = [NSMutableArray arrayWithCapacity:ndim];
206-
for (int64_t i = 0; i < ndim; i++) {
207-
[out_ns_shape addObject:@(out_sizes[i])];
208-
}
209-
210-
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
211-
values_out: [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype],
212-
indices_out: [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32],
213-
};
214-
215-
ET_LOG(Debug, "aoti_torch_mps_topk: executing MPSGraph");
216-
stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT);
217-
ET_LOG(Debug, "aoti_torch_mps_topk: MPSGraph done");
170+
MPSGraph* graph = [[MPSGraph alloc] init];
171+
MPSGraphTensor* input = [graph placeholderWithShape:input_shape
172+
dataType:mps_dtype
173+
name:@"self"];
174+
175+
MPSGraphTensor* work = input;
176+
bool need_transpose = (dim != ndim - 1);
177+
178+
if (need_transpose) {
179+
work = [graph transposeTensor:work dimension:dim withDimension:ndim - 1 name:nil];
180+
}
181+
182+
NSArray<MPSGraphTensor*>* topk_results;
183+
if (largest) {
184+
topk_results = [graph topKWithSourceTensor:work k:(NSUInteger)k name:nil];
185+
} else {
186+
MPSGraphTensor* neg = [graph negativeWithTensor:work name:nil];
187+
topk_results = [graph topKWithSourceTensor:neg k:(NSUInteger)k name:nil];
188+
topk_results = @[
189+
[graph negativeWithTensor:topk_results[0] name:nil],
190+
topk_results[1]
191+
];
192+
}
193+
194+
MPSGraphTensor* values_out = topk_results[0];
195+
MPSGraphTensor* indices_out = topk_results[1];
196+
197+
if (need_transpose) {
198+
values_out = [graph transposeTensor:values_out dimension:dim withDimension:ndim - 1 name:nil];
199+
indices_out = [graph transposeTensor:indices_out dimension:dim withDimension:ndim - 1 name:nil];
200+
}
201+
202+
CachedGraph cached_graph;
203+
cached_graph.graph = graph;
204+
cached_graph.input1 = input;
205+
cached_graph.input2 = indices_out;
206+
cached_graph.output = values_out;
207+
graph_cache[cache_key] = cached_graph;
208+
209+
MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:input_shape dataType:mps_dtype];
210+
MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:values_buffer shape:out_ns_shape dataType:mps_dtype];
211+
MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc] initWithMTLBuffer:indices_buffer shape:out_ns_shape dataType:MPSDataTypeInt32];
212+
213+
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
214+
input: selfData,
215+
};
216+
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
217+
values_out: valuesData,
218+
indices_out: indicesData,
219+
};
220+
221+
stream->executeMPSGraph(graph, feeds, results, SyncType::COMMIT);
222+
223+
[selfData release];
224+
[valuesData release];
225+
[indicesData release];
218226
} @catch (NSException* e) {
219227
ET_LOG(Error, "aoti_torch_mps_topk: ObjC exception: %s - %s",
220228
e.name.UTF8String, e.reason.UTF8String);
@@ -223,7 +231,6 @@ AOTITorchError aoti_torch_mps_topk(
223231
}
224232

225233
// Create output tensor handles
226-
// Values tensor
227234
AOTITensorHandle values_handle = nullptr;
228235
aoti_torch_create_tensor_from_blob_v2(
229236
values_ptr, ndim, out_sizes.data(), out_strides.data(),
@@ -235,29 +242,25 @@ AOTITorchError aoti_torch_mps_topk(
235242
aoti_torch_mps_free(indices_ptr);
236243
return Error::Internal;
237244
}
238-
ET_LOG(Debug, "aoti_torch_mps_topk: values tensor created");
239245

240-
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
241246
memory_to_n_tensor[values_ptr] = 1;
242247

243248
// Indices tensor — MPSGraph outputs int32, AOTInductor expects int64.
244-
// Allocate a new int64 buffer and convert.
245249
size_t indices_i64_bytes = num_elements * sizeof(int64_t);
246250
void* indices_i64_ptr = nullptr;
247251
allocate_mtl_buffer(&indices_i64_ptr, indices_i64_bytes);
248252

249253
// Copy int32 → int64 on CPU (small tensor, fast)
254+
stream->synchronize(SyncType::COMMIT_AND_WAIT);
250255
{
251-
auto* stream_sync = getCurrentMetalStream();
252-
stream_sync->synchronize(SyncType::COMMIT_AND_WAIT);
253-
254256
int32_t* src = reinterpret_cast<int32_t*>(indices_ptr);
255257
int64_t* dst = reinterpret_cast<int64_t*>(indices_i64_ptr);
256258
for (size_t i = 0; i < num_elements; i++) {
257259
dst[i] = static_cast<int64_t>(src[i]);
258260
}
259261
}
260262
aoti_torch_mps_free(indices_ptr);
263+
indices_ptr = nullptr;
261264

262265
int32_t indices_dtype = static_cast<int32_t>(exec_aten::ScalarType::Long);
263266
std::vector<int64_t> indices_strides(ndim);
@@ -281,17 +284,19 @@ AOTITorchError aoti_torch_mps_topk(
281284
*ret0 = values_handle;
282285
*ret1 = indices_handle;
283286

284-
ET_LOG(Debug, "aoti_torch_mps_topk: Completed successfully");
285-
286287
} // @autoreleasepool
287288

288289
return Error::Ok;
289290

290291
} catch (const std::exception& e) {
291292
ET_LOG(Error, "aoti_torch_mps_topk exception: %s", e.what());
293+
if (values_ptr) aoti_torch_mps_free(values_ptr);
294+
if (indices_ptr) aoti_torch_mps_free(indices_ptr);
292295
return Error::Internal;
293296
} catch (...) {
294297
ET_LOG(Error, "aoti_torch_mps_topk: unknown exception");
298+
if (values_ptr) aoti_torch_mps_free(values_ptr);
299+
if (indices_ptr) aoti_torch_mps_free(indices_ptr);
295300
return Error::Internal;
296301
}
297302
}

0 commit comments

Comments
 (0)