88
99// Top-k operator using MPSGraph.
1010// Used by MoE routing (torch.topk in SparseMoE.forward).
11+ // Note: the `sorted` parameter is accepted, but MPSGraph always returns sorted results regardless of its value.
1112
1213#include < executorch/backends/apple/metal/runtime/ops/common.h>
1314
@@ -40,6 +41,9 @@ AOTITorchError aoti_torch_mps_topk(
4041 return Error::Internal;
4142 }
4243
44+ void * values_ptr = nullptr ;
45+ void * indices_ptr = nullptr ;
46+
4347 try {
4448 @autoreleasepool {
4549 auto * self_tensor = reinterpret_cast <Tensor*>(self);
@@ -55,7 +59,7 @@ AOTITorchError aoti_torch_mps_topk(
5559
5660 int64_t dim_size = self_tensor->sizes ()[dim];
5761 if (k > dim_size) {
58- ET_LOG (Error, " aoti_torch_mps_topk: k=%lld > dim_size=%lld\n " , k, dim_size);
62+ ET_LOG (Error, " aoti_torch_mps_topk: k=%lld > dim_size=%lld" , k, dim_size);
5963 return Error::InvalidArgument;
6064 }
6165
@@ -96,18 +100,20 @@ AOTITorchError aoti_torch_mps_topk(
96100 size_t values_bytes = num_elements * element_size;
97101 size_t indices_bytes = num_elements * sizeof (int32_t );
98102
99- void * values_ptr = nullptr ;
100- void * indices_ptr = nullptr ;
101103 allocate_mtl_buffer (&values_ptr, values_bytes);
102104 allocate_mtl_buffer (&indices_ptr, indices_bytes);
103105
104- // Build MPSGraph
105106 // Convert input shape to NSArray<NSNumber*>
106107 NSMutableArray <NSNumber *>* input_shape = [NSMutableArray arrayWithCapacity: ndim];
107108 for (int64_t i = 0 ; i < ndim; i++) {
108109 [input_shape addObject: @(self_tensor->sizes ()[i])];
109110 }
110111
112+ NSMutableArray <NSNumber *>* out_ns_shape = [NSMutableArray arrayWithCapacity: ndim];
113+ for (int64_t i = 0 ; i < ndim; i++) {
114+ [out_ns_shape addObject: @(out_sizes[i])];
115+ }
116+
111117 // Check graph cache
112118 GraphCacheKey cache_key;
113119 cache_key.op_name = " topk" ;
@@ -120,101 +126,103 @@ AOTITorchError aoti_torch_mps_topk(
120126 cache_key.dtype = dtype;
121127 cache_key.transpose_flag = false ;
122128
129+ stream->endKernelCoalescing ();
130+
131+ id <MTLBuffer > self_buffer = get_mtl_buffer (self_tensor, " topk" , " self" );
132+ id <MTLBuffer > values_buffer = ptr_to_mtl_buffer[values_ptr];
133+ id <MTLBuffer > indices_buffer = ptr_to_mtl_buffer[indices_ptr];
134+
123135 auto cache_it = graph_cache.find (cache_key);
124136 if (cache_it != graph_cache.end ()) {
125137 cache_stats.hits ++;
138+ cache_stats.logStats ();
126139 auto & cached = cache_it->second ;
127140
128- id < MTLBuffer > self_buffer = get_mtl_buffer (self_tensor, " topk " , " self " ) ;
129- id < MTLBuffer > values_buffer = ptr_to_mtl_buffer[values_ptr ];
130- id < MTLBuffer > indices_buffer = ptr_to_mtl_buffer[indices_ptr ];
141+ MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc ] initWithMTLBuffer: self_buffer shape: input_shape dataType: mps_dtype] ;
142+ MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc ] initWithMTLBuffer: values_buffer shape: out_ns_shape dataType: mps_dtype ];
143+ MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc ] initWithMTLBuffer: indices_buffer shape: out_ns_shape dataType: MPSDataTypeInt32 ];
131144
132145 NSDictionary <MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
133- cached.input1 : [[MPSGraphTensorData alloc ] initWithMTLBuffer: self_buffer shape: input_shape dataType: mps_dtype] ,
146+ cached.input1 : selfData ,
134147 };
135-
136- NSMutableArray <NSNumber *>* out_ns_shape = [NSMutableArray arrayWithCapacity: ndim];
137- for (int64_t i = 0 ; i < ndim; i++) {
138- [out_ns_shape addObject: @(out_sizes[i])];
139- }
140-
141148 NSDictionary <MPSGraphTensor*, MPSGraphTensorData*>* results = @{
142- cached.output : [[MPSGraphTensorData alloc ] initWithMTLBuffer: values_buffer shape: out_ns_shape dataType: mps_dtype] ,
143- cached.input2 : [[MPSGraphTensorData alloc ] initWithMTLBuffer: indices_buffer shape: out_ns_shape dataType: MPSDataTypeInt32] ,
149+ cached.output : valuesData ,
150+ cached.input2 : indicesData ,
144151 };
145152
146- stream->executeMPSGraph (cached.graph , feeds, results, SyncType::COMMIT );
153+ @try {
154+ stream->executeMPSGraph (cached.graph , feeds, results, SyncType::COMMIT );
155+ } @catch (NSException * e) {
156+ ET_LOG (Error, " aoti_torch_mps_topk: ObjC exception: %s - %s" ,
157+ e.name .UTF8String , e.reason .UTF8String );
158+ throw std::runtime_error (std::string (" MPSGraph topk failed: " ) + e.reason .UTF8String );
159+ }
160+
161+ [selfData release ];
162+ [valuesData release ];
163+ [indicesData release ];
147164 } else {
148165 cache_stats.misses ++;
166+ cache_stats.logStats ();
149167 ET_LOG (Debug, " aoti_torch_mps_topk: cache miss, building graph" );
150168
151169 @try {
152- MPSGraph* graph = [[MPSGraph alloc ] init ];
153- MPSGraphTensor* input = [graph placeholderWithShape: input_shape
154- dataType: mps_dtype
155- name: @" self" ];
156-
157- // MPSGraph topK: returns (values, indices) along the last dimension.
158- // If dim != -1, we need to transpose dim to last, topk, then transpose back.
159- MPSGraphTensor* work = input;
160- bool need_transpose = (dim != ndim - 1 );
161-
162- if (need_transpose) {
163- work = [graph transposeTensor: work dimension: dim withDimension: ndim - 1 name: nil ];
164- }
165-
166- // MPSGraph topKWithTensor returns along the last axis
167- NSArray <MPSGraphTensor*>* topk_results;
168- if (largest) {
169- topk_results = [graph topKWithSourceTensor: work k: (NSUInteger )k name: nil ];
170- } else {
171- // For smallest: negate, topk, negate back
172- MPSGraphTensor* neg = [graph negativeWithTensor: work name: nil ];
173- topk_results = [graph topKWithSourceTensor: neg k: (NSUInteger )k name: nil ];
174- topk_results = @[
175- [graph negativeWithTensor: topk_results[0 ] name: nil ],
176- topk_results[1 ]
177- ];
178- }
179-
180- MPSGraphTensor* values_out = topk_results[0 ];
181- MPSGraphTensor* indices_out = topk_results[1 ];
182-
183- if (need_transpose) {
184- values_out = [graph transposeTensor: values_out dimension: dim withDimension: ndim - 1 name: nil ];
185- indices_out = [graph transposeTensor: indices_out dimension: dim withDimension: ndim - 1 name: nil ];
186- }
187-
188- // Cache the graph
189- CachedGraph cached_graph;
190- cached_graph.graph = graph;
191- cached_graph.input1 = input;
192- cached_graph.input2 = indices_out; // reuse input2 slot for indices output
193- cached_graph.output = values_out;
194- graph_cache[cache_key] = cached_graph;
195-
196- // Execute
197- id <MTLBuffer > self_buffer = get_mtl_buffer (self_tensor, " topk" , " self" );
198- id <MTLBuffer > values_buffer = ptr_to_mtl_buffer[values_ptr];
199- id <MTLBuffer > indices_buffer = ptr_to_mtl_buffer[indices_ptr];
200-
201- NSDictionary <MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
202- input: [[MPSGraphTensorData alloc ] initWithMTLBuffer: self_buffer shape: input_shape dataType: mps_dtype],
203- };
204-
205- NSMutableArray <NSNumber *>* out_ns_shape = [NSMutableArray arrayWithCapacity: ndim];
206- for (int64_t i = 0 ; i < ndim; i++) {
207- [out_ns_shape addObject: @(out_sizes[i])];
208- }
209-
210- NSDictionary <MPSGraphTensor*, MPSGraphTensorData*>* results = @{
211- values_out: [[MPSGraphTensorData alloc ] initWithMTLBuffer: values_buffer shape: out_ns_shape dataType: mps_dtype],
212- indices_out: [[MPSGraphTensorData alloc ] initWithMTLBuffer: indices_buffer shape: out_ns_shape dataType: MPSDataTypeInt32],
213- };
214-
215- ET_LOG (Debug, " aoti_torch_mps_topk: executing MPSGraph" );
216- stream->executeMPSGraph (graph, feeds, results, SyncType::COMMIT );
217- ET_LOG (Debug, " aoti_torch_mps_topk: MPSGraph done" );
170+ MPSGraph* graph = [[MPSGraph alloc ] init ];
171+ MPSGraphTensor* input = [graph placeholderWithShape: input_shape
172+ dataType: mps_dtype
173+ name: @" self" ];
174+
175+ MPSGraphTensor* work = input;
176+ bool need_transpose = (dim != ndim - 1 );
177+
178+ if (need_transpose) {
179+ work = [graph transposeTensor: work dimension: dim withDimension: ndim - 1 name: nil ];
180+ }
181+
182+ NSArray <MPSGraphTensor*>* topk_results;
183+ if (largest) {
184+ topk_results = [graph topKWithSourceTensor: work k: (NSUInteger )k name: nil ];
185+ } else {
186+ MPSGraphTensor* neg = [graph negativeWithTensor: work name: nil ];
187+ topk_results = [graph topKWithSourceTensor: neg k: (NSUInteger )k name: nil ];
188+ topk_results = @[
189+ [graph negativeWithTensor: topk_results[0 ] name: nil ],
190+ topk_results[1 ]
191+ ];
192+ }
193+
194+ MPSGraphTensor* values_out = topk_results[0 ];
195+ MPSGraphTensor* indices_out = topk_results[1 ];
196+
197+ if (need_transpose) {
198+ values_out = [graph transposeTensor: values_out dimension: dim withDimension: ndim - 1 name: nil ];
199+ indices_out = [graph transposeTensor: indices_out dimension: dim withDimension: ndim - 1 name: nil ];
200+ }
201+
202+ CachedGraph cached_graph;
203+ cached_graph.graph = graph;
204+ cached_graph.input1 = input;
205+ cached_graph.input2 = indices_out;
206+ cached_graph.output = values_out;
207+ graph_cache[cache_key] = cached_graph;
208+
209+ MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc ] initWithMTLBuffer: self_buffer shape: input_shape dataType: mps_dtype];
210+ MPSGraphTensorData* valuesData = [[MPSGraphTensorData alloc ] initWithMTLBuffer: values_buffer shape: out_ns_shape dataType: mps_dtype];
211+ MPSGraphTensorData* indicesData = [[MPSGraphTensorData alloc ] initWithMTLBuffer: indices_buffer shape: out_ns_shape dataType: MPSDataTypeInt32];
212+
213+ NSDictionary <MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
214+ input: selfData,
215+ };
216+ NSDictionary <MPSGraphTensor*, MPSGraphTensorData*>* results = @{
217+ values_out: valuesData,
218+ indices_out: indicesData,
219+ };
220+
221+ stream->executeMPSGraph (graph, feeds, results, SyncType::COMMIT );
222+
223+ [selfData release ];
224+ [valuesData release ];
225+ [indicesData release ];
218226 } @catch (NSException * e) {
219227 ET_LOG (Error, " aoti_torch_mps_topk: ObjC exception: %s - %s" ,
220228 e.name .UTF8String , e.reason .UTF8String );
@@ -223,7 +231,6 @@ AOTITorchError aoti_torch_mps_topk(
223231 }
224232
225233 // Create output tensor handles
226- // Values tensor
227234 AOTITensorHandle values_handle = nullptr ;
228235 aoti_torch_create_tensor_from_blob_v2 (
229236 values_ptr, ndim, out_sizes.data (), out_strides.data (),
@@ -235,29 +242,25 @@ AOTITorchError aoti_torch_mps_topk(
235242 aoti_torch_mps_free (indices_ptr);
236243 return Error::Internal;
237244 }
238- ET_LOG (Debug, " aoti_torch_mps_topk: values tensor created" );
239245
240- extern std::unordered_map<void *, int32_t > memory_to_n_tensor;
241246 memory_to_n_tensor[values_ptr] = 1 ;
242247
243248 // Indices tensor — MPSGraph outputs int32, AOTInductor expects int64.
244- // Allocate a new int64 buffer and convert.
245249 size_t indices_i64_bytes = num_elements * sizeof (int64_t );
246250 void * indices_i64_ptr = nullptr ;
247251 allocate_mtl_buffer (&indices_i64_ptr, indices_i64_bytes);
248252
249253 // Copy int32 → int64 on CPU (small tensor, fast)
254+ stream->synchronize (SyncType::COMMIT_AND_WAIT );
250255 {
251- auto * stream_sync = getCurrentMetalStream ();
252- stream_sync->synchronize (SyncType::COMMIT_AND_WAIT );
253-
254256 int32_t * src = reinterpret_cast <int32_t *>(indices_ptr);
255257 int64_t * dst = reinterpret_cast <int64_t *>(indices_i64_ptr);
256258 for (size_t i = 0 ; i < num_elements; i++) {
257259 dst[i] = static_cast <int64_t >(src[i]);
258260 }
259261 }
260262 aoti_torch_mps_free (indices_ptr);
263+ indices_ptr = nullptr ;
261264
262265 int32_t indices_dtype = static_cast <int32_t >(exec_aten::ScalarType::Long);
263266 std::vector<int64_t > indices_strides (ndim);
@@ -281,17 +284,19 @@ AOTITorchError aoti_torch_mps_topk(
281284 *ret0 = values_handle;
282285 *ret1 = indices_handle;
283286
284- ET_LOG (Debug, " aoti_torch_mps_topk: Completed successfully" );
285-
286287 } // @autoreleasepool
287288
288289 return Error::Ok;
289290
290291 } catch (const std::exception& e) {
291292 ET_LOG (Error, " aoti_torch_mps_topk exception: %s" , e.what ());
293+ if (values_ptr) aoti_torch_mps_free (values_ptr);
294+ if (indices_ptr) aoti_torch_mps_free (indices_ptr);
292295 return Error::Internal;
293296 } catch (...) {
294297 ET_LOG (Error, " aoti_torch_mps_topk: unknown exception" );
298+ if (values_ptr) aoti_torch_mps_free (values_ptr);
299+ if (indices_ptr) aoti_torch_mps_free (indices_ptr);
295300 return Error::Internal;
296301 }
297302}
0 commit comments