Skip to content

Commit 9e393dc

Browse files
committed
add topk
1 parent 7873b39 commit 9e393dc

3 files changed

Lines changed: 29 additions & 1 deletion

File tree

mlx/onnx/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .op_pool import MaxPool, AveragePool
1414
from .op_conv import Conv
1515
from .op_slice import Slice
16+
from .op_topk import TopK
1617

1718
# Reference Docs: https://onnx.ai/onnx/operators/
1819

mlx/onnx/ops/op_topk.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import mlx.core as mx
2+
3+
def TopK(x: mx.array, k: mx.array, axis=-1, largest=1, sorted=1):
4+
if isinstance(k, mx.array):
5+
k = k.item()
6+
if x.ndim == 2 and axis == 1:
7+
sample = mx.arange(x.shape[0])[:, None]
8+
if largest == 0:
9+
sorted_indices = mx.argpartition(x, kth=k - 1, axis=axis)
10+
sorted_indices = sorted_indices[:, :k]
11+
sorted_indices = sorted_indices[sample, mx.argsort(x[sample, sorted_indices])]
12+
else:
13+
sorted_indices = mx.argpartition(-x, kth=k-1, axis=axis)
14+
sorted_indices = sorted_indices[:, :k]
15+
sorted_indices = sorted_indices[sample, mx.argsort(-x[sample, sorted_indices])]
16+
sorted_distances = x[sample, sorted_indices]
17+
return (sorted_distances, sorted_indices.astype(mx.int64))
18+
19+
if largest == 0:
20+
sorted_indices = mx.argsort(x, axis=axis)
21+
sorted_values = mx.sort(x, axis=axis)
22+
else:
23+
sorted_indices = mx.argsort(-x, axis=axis)
24+
sorted_values = -mx.sort(-x, axis=axis)
25+
ark = mx.arange(k)
26+
topk_sorted_indices = mx.take(sorted_indices, ark, axis=axis)
27+
topk_sorted_values = mx.take(sorted_values, ark, axis=axis)
28+
return topk_sorted_values, topk_sorted_indices.astype(mx.int64)

tests/test_onnx.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def supports_device(cls, device: str) -> bool:
6565
btest.exclude("test_convtranspose_*")
6666

6767
btest.exclude("test_PReLU_*")
68-
btest.exclude("test_topk*")
6968

7069
# TODO: Implement dilations / col format
7170
btest.exclude("test_averagepool_2d_dilations_cpu")

0 commit comments

Comments
 (0)