|
| 1 | +import mlx.core as mx |
| 2 | + |
| 3 | +def TopK(x: mx.array, k: mx.array, axis=-1, largest=1, sorted=1): |
| 4 | + if isinstance(k, mx.array): |
| 5 | + k = k.item() |
| 6 | + if x.ndim == 2 and axis == 1: |
| 7 | + sample = mx.arange(x.shape[0])[:, None] |
| 8 | + if largest == 0: |
| 9 | + sorted_indices = mx.argpartition(x, kth=k - 1, axis=axis) |
| 10 | + sorted_indices = sorted_indices[:, :k] |
| 11 | + sorted_indices = sorted_indices[sample, mx.argsort(x[sample, sorted_indices])] |
| 12 | + else: |
| 13 | + sorted_indices = mx.argpartition(-x, kth=k-1, axis=axis) |
| 14 | + sorted_indices = sorted_indices[:, :k] |
| 15 | + sorted_indices = sorted_indices[sample, mx.argsort(-x[sample, sorted_indices])] |
| 16 | + sorted_distances = x[sample, sorted_indices] |
| 17 | + return (sorted_distances, sorted_indices.astype(mx.int64)) |
| 18 | + |
| 19 | + if largest == 0: |
| 20 | + sorted_indices = mx.argsort(x, axis=axis) |
| 21 | + sorted_values = mx.sort(x, axis=axis) |
| 22 | + else: |
| 23 | + sorted_indices = mx.argsort(-x, axis=axis) |
| 24 | + sorted_values = -mx.sort(-x, axis=axis) |
| 25 | + ark = mx.arange(k) |
| 26 | + topk_sorted_indices = mx.take(sorted_indices, ark, axis=axis) |
| 27 | + topk_sorted_values = mx.take(sorted_values, ark, axis=axis) |
| 28 | + return topk_sorted_values, topk_sorted_indices.astype(mx.int64) |
0 commit comments