-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathemlx_nif.cpp
More file actions
1466 lines (1250 loc) · 45.2 KB
/
emlx_nif.cpp
File metadata and controls
1466 lines (1250 loc) · 45.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "erl_nif.h"
#include "mlx/mlx.h"
#include "nx_nif_utils.hpp"
#include <algorithm>
#include <atomic>
#include <cstring>
#include <functional>
#include <iostream>
#include <map>
#include <new>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
using namespace mlx::core;
// Nx type name -> MLX dtype. Used by string2dtype / dtype2string.
std::map<const std::string, const mlx::core::Dtype> dtypes = {
    {"bool", mlx::core::bool_},
    {"uint8", mlx::core::uint8},
    {"uint16", mlx::core::uint16},
    {"uint32", mlx::core::uint32},
    {"uint64", mlx::core::uint64},
    {"int8", mlx::core::int8},
    {"int16", mlx::core::int16},
    {"int32", mlx::core::int32},
    {"int64", mlx::core::int64},
    {"float16", mlx::core::float16},
    {"float32", mlx::core::float32},
    {"bfloat16", mlx::core::bfloat16},
    {"complex64", mlx::core::complex64}};

// Nx type name -> element size in bytes, derived from `dtypes` so the two
// tables always share the same key set. `dtypes` is defined earlier in this
// translation unit, so it is initialised before this table.
std::map<const std::string, const uint8_t> dtype_sizes = [] {
  std::map<const std::string, const uint8_t> sizes_by_name;
  for (const auto &[name, dtype] : dtypes) {
    sizes_by_name.emplace(name, dtype.size());
  }
  return sizes_by_name;
}();
// Maps an Nx type name (e.g. "float32") to its MLX dtype.
// Throws std::runtime_error for names not present in `dtypes`.
inline mlx::core::Dtype string2dtype(const std::string &atom) {
  const auto found = dtypes.find(atom);
  if (found == dtypes.end()) {
    throw std::runtime_error("Unknown dtype: " + atom);
  }
  return found->second;
}
// Reverse lookup of `dtypes`: returns a pointer to the Nx type name for
// `dtype`, or nullptr if the dtype has no entry. The pointer refers to the
// key stored in the global map.
inline const std::string *dtype2string(const mlx::core::Dtype dtype) {
  for (auto it = dtypes.begin(); it != dtypes.end(); ++it) {
    if (it->second == dtype)
      return &it->first;
  }
  return nullptr;
}
// Maps "cpu" / "gpu" to device index 0 of that type.
// Throws std::runtime_error for any other name.
inline const mlx::core::Device string2device(const std::string &atom) {
  using DeviceKind = mlx::core::Device::DeviceType;
  if (atom == "cpu")
    return mlx::core::Device(DeviceKind::cpu, 0);
  if (atom == "gpu")
    return mlx::core::Device(DeviceKind::gpu, 0);
  throw std::runtime_error("Unknown device: " + atom);
}
// MLX 0.31+ uses Shape = SmallVector<int> and Strides = SmallVector<long long>
// which no longer accept implicit construction from std::vector.
// These helpers perform the explicit element-wise copy.
static inline mlx::core::Shape to_shape(const std::vector<int> &v) {
  mlx::core::Shape dims(v.begin(), v.end());
  return dims;
}
static inline mlx::core::Strides to_strides(const std::vector<int64_t> &v) {
  mlx::core::Strides steps(v.begin(), v.end());
  return steps;
}
// Class to manage the refcount of MLX tensors
class TensorP {
public:
TensorP(ErlNifEnv *env, const ERL_NIF_TERM arg) : ptr(nullptr) {
// setup
if (!enif_get_resource(env, arg, resource_object<mlx::core::array>::type,
(void **)&ptr)) {
err = nx::nif::error(env, "Unable to get tensor param in NIF");
return;
}
refcount = (std::atomic<int> *)(ptr + 1);
deleted = (std::atomic_flag *)(refcount + 1);
if (refcount->load() == 0) {
// already deallocated
ptr = nullptr;
err = nx::nif::error(env, "Tensor has been deallocated");
return;
}
if (is_valid()) {
// increase reference count
++(*refcount);
}
}
~TensorP() {
if (is_valid()) {
// decrease reference count
if (refcount->fetch_sub(1) == 0) {
ptr->~array(); // Call MLX tensor destructor
}
}
}
bool deallocate() {
if (is_valid() && atomic_flag_test_and_set(deleted) == false) {
--(*refcount);
return true;
} else {
return false;
}
}
mlx::core::array *data() const { return ptr; }
// Raw ERTS resource pointer for use with enif_make_resource_binary.
void *resource_ptr() const { return static_cast<void *>(ptr); }
bool is_valid() const { return ptr != nullptr; }
ERL_NIF_TERM error() { return err; }
private:
mlx::core::array *ptr;
std::atomic<int> *refcount;
std::atomic_flag *deleted;
ERL_NIF_TERM err;
};
// CATCH(): catch clauses appended after a `try { ... }` inside a NIF body.
// std::exception messages are returned as {:error, "<what> in NIF.<fn>/<argc>"};
// anything else becomes {:error, "Unknown error occurred"}. Relies on `env`
// and `argc` from the enclosing NIF signature.
#define CATCH()                                                                \
  catch (const std::exception &ex) {                                           \
    std::ostringstream what_msg;                                               \
    what_msg << ex.what() << " in NIF." << __func__ << "/" << argc;            \
    return nx::nif::error(env, what_msg.str().c_str());                        \
  }                                                                            \
  catch (...) {                                                                \
    return nx::nif::error(env, "Unknown error occurred");                      \
  }

// TENSOR(A): evaluates the array expression A inside a try block, wraps the
// result in a new tensor resource, and returns {:ok, resource}. Exceptions
// thrown while evaluating A are turned into error tuples by CATCH().
#define TENSOR(A)                                                              \
  try {                                                                        \
    const ERL_NIF_TERM emlx_tensor_term_ = create_tensor_resource(env, A);     \
    return nx::nif::ok(env, emlx_tensor_term_);                                \
  }                                                                            \
  CATCH()
// Moves `tensor` into a freshly allocated ERTS resource and returns its term.
// The resource holds the array followed by a std::atomic<int> refcount
// (initialised to 1) and a std::atomic_flag "deleted" marker. TensorP reads
// this same layout. Returns badarg if the resource cannot be allocated.
ERL_NIF_TERM
create_tensor_resource(ErlNifEnv *env, mlx::core::array tensor) {
  constexpr size_t resource_size = sizeof(mlx::core::array) +
                                   sizeof(std::atomic<int>) +
                                   sizeof(std::atomic_flag);
  void *raw = enif_alloc_resource(resource_object<mlx::core::array>::type,
                                  resource_size);
  if (raw == nullptr)
    return enif_make_badarg(env);

  auto *array_slot = new (raw) mlx::core::array(std::move(tensor));
  auto *refcount = new (array_slot + 1) std::atomic<int>(1);
  new (refcount + 1) std::atomic_flag();

  // The returned term keeps the resource alive; drop our allocation ref.
  const ERL_NIF_TERM term = enif_make_resource(env, raw);
  enif_release_resource(raw);
  return term;
}
// Moves `function` into a freshly allocated ERTS resource and returns its
// term. The layout mirrors create_tensor_resource:
//   [emlx::function][std::atomic<int> refcount = 1][std::atomic_flag]
// Returns badarg if the resource cannot be allocated.
ERL_NIF_TERM create_function_resource(ErlNifEnv *env, emlx::function function) {
  // Size the allocation from the type that is actually placement-new'd below.
  // The refcount is placed at `function_ptr + 1`, i.e. sizeof(emlx::function)
  // bytes in, so any other size could leave the tail out of bounds.
  constexpr size_t resource_size = sizeof(emlx::function) +
                                   sizeof(std::atomic<int>) +
                                   sizeof(std::atomic_flag);
  auto *function_ptr = static_cast<emlx::function *>(
      enif_alloc_resource(resource_object<emlx::function>::type, resource_size));
  if (function_ptr == nullptr) {
    return enif_make_badarg(env);
  }
  new (function_ptr) emlx::function(std::move(function));
  auto *refcount = new (function_ptr + 1) std::atomic<int>(1);
  new (refcount + 1) std::atomic_flag();
  // The returned term keeps the resource alive; drop our allocation ref.
  const ERL_NIF_TERM ret = enif_make_resource(env, function_ptr);
  enif_release_resource(function_ptr);
  return ret;
}
// Declares a NIF entry point with the standard erl_nif signature. Bodies read
// their arguments through `env`, `argc` and `argv`.
#define NIF(NAME)                                                              \
  ERL_NIF_TERM NAME(ErlNifEnv *env, int argc, const ERL_NIF_TERM *argv)

// Declares `VAR` of type TYPE and fills it from argv[ARGN] via GET (defined
// elsewhere, not visible here).
#define PARAM(ARGN, TYPE, VAR)                                                 \
  TYPE VAR;                                                                    \
  GET(ARGN, VAR)

// Resolves argv[ARGN] as a tensor. Declares the guard `VAR##_tp` (alive for
// the rest of the NIF, keeping the refcount raised) and the raw pointer `VAR`.
// Returns the guard's error term if the tensor cannot be used.
#define TENSOR_PARAM(ARGN, VAR)                                                \
  TensorP VAR##_tp(env, argv[ARGN]);                                           \
  if (!VAR##_tp.is_valid()) {                                                  \
    return VAR##_tp.error();                                                   \
  }                                                                            \
  mlx::core::array *VAR = VAR##_tp.data();

// Declares `NAME` of type TYPE and decodes the list argv[ARGN] into it,
// returning {:error, ...} if decoding fails.
#define LIST_PARAM(ARGN, TYPE, NAME)                                           \
  TYPE NAME;                                                                   \
  if (!nx::nif::get_list(env, argv[ARGN], NAME)) {                             \
    return nx::nif::error(env, "Unable to get " #NAME " list param.");         \
  }
// deallocate(tensor): marks the tensor deleted. Returns :ok on the first call.
// Returns :already_deallocated if the tensor was already marked, or if the
// argument could not be resolved as a live tensor.
NIF(deallocate) {
  TensorP handle(env, argv[0]);
  return handle.deallocate() ? nx::nif::ok(env)
                             : nx::nif::atom(env, "already_deallocated");
}
// scalar_type(tensor): returns {:ok, type_atom} (e.g. :float32) for the
// tensor's dtype, or an error when the dtype has no entry in `dtypes`.
NIF(scalar_type) {
  TENSOR_PARAM(0, t);
  const std::string *nx_name = dtype2string(t->dtype());
  if (nx_name == nullptr) {
    return nx::nif::error(env, "Could not determine tensor type.");
  }
  return nx::nif::ok(env, enif_make_atom(env, nx_name->c_str()));
}
// shape(tensor): returns {:ok, tuple} with one integer per dimension.
NIF(shape) {
  TENSOR_PARAM(0, t);
  const auto &dims = t->shape();
  std::vector<ERL_NIF_TERM> extents;
  extents.reserve(dims.size());
  for (const auto extent : dims) {
    extents.push_back(nx::nif::make(env, static_cast<int64_t>(extent)));
  }
  return nx::nif::ok(
      env, enif_make_tuple_from_array(env, extents.data(), extents.size()));
}
// ones(shape, type, device): tensor of the given shape/dtype filled with 1.
NIF(ones) {
  SHAPE_PARAM(0, shape);
  TYPE_PARAM(1, type);
  DEVICE_PARAM(2, device);
  try {
    auto filled = mlx::core::ones(to_shape(shape), type, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(filled)));
  }
  CATCH()
}

// zeros(shape, type, device): tensor of the given shape/dtype filled with 0.
NIF(zeros) {
  SHAPE_PARAM(0, shape);
  TYPE_PARAM(1, type);
  DEVICE_PARAM(2, device);
  try {
    auto filled = mlx::core::zeros(to_shape(shape), type, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(filled)));
  }
  CATCH()
}
// reshape(tensor, shape, device): mlx::core::reshape wrapper.
NIF(reshape) {
  TENSOR_PARAM(0, t);
  SHAPE_PARAM(1, shape);
  DEVICE_PARAM(2, device);
  try {
    auto reshaped = mlx::core::reshape(*t, to_shape(shape), device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(reshaped)));
  }
  CATCH()
}

// astype(tensor, type, device): converts the tensor to the given dtype.
NIF(astype) {
  TENSOR_PARAM(0, t);
  TYPE_PARAM(1, type);
  DEVICE_PARAM(2, device);
  try {
    auto converted = mlx::core::astype(*t, type, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(converted)));
  }
  CATCH()
}
// to_blob(tensor [, limit]): returns {:ok, binary} aliasing the tensor's bytes
// in row-major order.
//
// argv[0]: tensor resource.
// argv[1] (optional): element limit. At most that many elements' worth of
//   bytes are returned. A limit larger than the tensor is clamped to the
//   tensor's size; a negative limit is rejected with an error.
//
// MLX exceptions raised while materialising a non-contiguous tensor are
// turned into {:error, reason} by CATCH() instead of escaping the NIF.
NIF(to_blob) {
  TENSOR_PARAM(0, t);
  size_t byte_size = t->nbytes();
  if (argc == 2) {
    PARAM(1, int, param_limit);
    if (param_limit < 0) {
      return nx::nif::error(env, "Element limit must be non-negative");
    }
    // Never alias more than nbytes(): a larger size would expose memory past
    // the end of the MLX buffer through the returned binary.
    byte_size = std::min(byte_size,
                         static_cast<size_t>(param_limit) * t->itemsize());
  }
  try {
    ERL_NIF_TERM resource_bin;
    if (t->flags().row_contiguous) {
      // Zero-copy: alias the MLX buffer via the existing tensor resource.
      // Invariant: lib/emlx.ex calls eval(tensor) before this NIF, so
      // data<void>() is guaranteed non-null and stable (MLX arrays are
      // immutable once materialised). enif_make_resource_binary keeps the
      // resource alive until the binary is GC'd, decoupling the binary
      // lifetime from Elixir GC of the tensor term.
      resource_bin = enif_make_resource_binary(env, t_tp.resource_ptr(),
                                               t->data<void>(), byte_size);
    } else {
      // Non-contiguous: materialise a fresh row-major copy, wrap it in a
      // minimal ERTS resource, and alias that buffer zero-copy.
      // The resource holds only sizeof(mlx::core::array), with no TensorP
      // refcount/deleted-flag tail, because it is never exposed to TensorP or
      // the Elixir side. Only the binary holds a reference, and
      // default_dtor<array> (~array()) is the sole destructor path.
      auto ct = mlx::core::contiguous(*t);
      mlx::core::eval(ct);
      auto *ct_ptr = static_cast<mlx::core::array *>(enif_alloc_resource(
          resource_object<mlx::core::array>::type, sizeof(mlx::core::array)));
      if (!ct_ptr)
        return enif_make_badarg(env);
      new (ct_ptr) mlx::core::array(std::move(ct));
      resource_bin = enif_make_resource_binary(env, ct_ptr,
                                               ct_ptr->data<void>(), byte_size);
      enif_release_resource(ct_ptr);
    }
    return nx::nif::ok(env, resource_bin);
  }
  CATCH()
}
// Number of elements in a tensor of the given shape (1 for an empty/scalar
// shape). The accumulator is seeded with uint64_t{1}. Seeding std::accumulate
// with the int literal 1 made the running product an int, which overflowed
// (undefined behaviour) once the count exceeded INT_MAX.
uint64_t elem_count(std::vector<int> shape) {
  return std::accumulate(shape.begin(), shape.end(), uint64_t{1},
                         std::multiplies<>{});
}
// from_blob(binary, shape, type): copies the binary into a new MLX buffer and
// returns {:ok, tensor} with the given shape and dtype. Returns an error when
// the binary holds fewer than elem_count(shape) elements of that dtype.
NIF(from_blob) {
  BINARY_PARAM(0, blob);
  SHAPE_PARAM(1, shape);
  TYPE_PARAM(2, type);
  // DEVICE_PARAM(3, device);
  // Use the Dtype's own element size. dtype_sizes[type_atom] uses map
  // operator[], which would insert (and then divide by) 0 for a name missing
  // from the table.
  if (blob.size / type.size() < elem_count(shape))
    return nx::nif::error(env,
                          "Binary size is too small for the requested shape");
  try {
    // Allocate MLX buffer and copy data from blob
    size_t byte_size = blob.size;
    allocator::Buffer mlx_buf = allocator::malloc(byte_size);
    void *buf_ptr = mlx_buf.raw_ptr();
    // Copy binary data to MLX buffer
    std::memcpy(buf_ptr, blob.data, byte_size);
    // The array takes ownership of the buffer and frees it with this deleter.
    auto deleter = [](allocator::Buffer buf) { allocator::free(buf); };
    // Create MLX array from the buffer
    TENSOR(mlx::core::array(mlx_buf, to_shape(shape), type, deleter));
  } catch (const std::exception &e) {
    return nx::nif::error(env, e.what());
  } catch (...) {
    return nx::nif::error(env,
                          "Unknown error creating tensor from binary data");
  }
}
// scalar_tensor(scalar, type): rank-0 tensor holding `scalar` as `type`.
// SCALAR_PARAM provides both `scalar` and `complex_scalar`; `is_complex`
// selects which one is used.
NIF(scalar_tensor) {
  SCALAR_PARAM(0, scalar, is_complex);
  TYPE_PARAM(1, type);
  // DEVICE_PARAM(2, device);
  try {
    auto value = is_complex ? mlx::core::array(complex_scalar, type)
                            : mlx::core::array(scalar, type);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(value)));
  }
  CATCH()
}
// full(scalar, shape, type, device): tensor of `shape` filled with `scalar`,
// using the complex variant when `is_complex` is set.
NIF(full) {
  SCALAR_PARAM(0, scalar, is_complex);
  SHAPE_PARAM(1, shape);
  TYPE_PARAM(2, type);
  DEVICE_PARAM(3, device);
  try {
    const auto dims = to_shape(shape);
    auto filled = is_complex
                      ? mlx::core::full(dims, complex_scalar, type, device)
                      : mlx::core::full(dims, scalar, type, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(filled)));
  }
  CATCH()
}
// arange(start, stop, step, integer?, device): integer overload when
// `integer` is true, otherwise the double overload with the same bounds.
NIF(arange) {
  PARAM(0, int, start);
  PARAM(1, int, stop);
  PARAM(2, int, step);
  PARAM(3, bool, integer);
  DEVICE_PARAM(4, device);
  try {
    auto range = integer
                     ? mlx::core::arange(start, stop, step, device)
                     : mlx::core::arange(static_cast<double>(start),
                                         static_cast<double>(stop),
                                         static_cast<double>(step), device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(range)));
  }
  CATCH()
}
// eye(m, n, type, device): m x n tensor with ones on the main diagonal.
NIF(eye) {
  PARAM(0, int, m);
  PARAM(1, int, n);
  TYPE_PARAM(2, type);
  DEVICE_PARAM(3, device);
  constexpr int main_diagonal = 0;
  try {
    auto identity = mlx::core::eye(m, n, main_diagonal, type, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(identity)));
  }
  CATCH()
}
// broadcast_to(tensor, shape, device): broadcasts the tensor to `shape`.
// The MLX call is kept inside TENSOR so that exceptions (e.g. incompatible
// shapes) become {:error, reason}. Previously the call ran before TENSOR's
// try block and any exception escaped the NIF.
NIF(broadcast_to) {
  TENSOR_PARAM(0, t);
  SHAPE_PARAM(1, shape);
  DEVICE_PARAM(2, device);
  TENSOR(mlx::core::broadcast_to(*t, to_shape(shape), device));
}
// tensordot(a, b, axes_a, axes_b, device): contracts a and b over the paired
// axes lists.
NIF(tensordot) {
  TENSOR_PARAM(0, a);
  TENSOR_PARAM(1, b);
  LIST_PARAM(2, std::vector<int>, axes1);
  LIST_PARAM(3, std::vector<int>, axes2);
  DEVICE_PARAM(4, device);
  try {
    auto contracted = mlx::core::tensordot(*a, *b, axes1, axes2, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(contracted)));
  }
  CATCH()
}

// einsum(a, b, spec, device): two-operand einsum with the given spec string.
NIF(einsum) {
  TENSOR_PARAM(0, a);
  TENSOR_PARAM(1, b);
  std::string spec;
  if (!nx::nif::get(env, argv[2], spec))
    return nx::nif::error(env, "Unable to get spec_string param.");
  DEVICE_PARAM(3, device);
  try {
    const std::vector<mlx::core::array> operands{*a, *b};
    auto summed = mlx::core::einsum(spec, operands, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(summed)));
  }
  CATCH()
}
// tri_inv(tensor, upper?, device): mlx::core::linalg::tri_inv wrapper.
NIF(tri_inv) {
  TENSOR_PARAM(0, tensor);
  PARAM(1, bool, upper);
  DEVICE_PARAM(2, device);
  try {
    auto inverse = mlx::core::linalg::tri_inv(*tensor, upper, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(inverse)));
  }
  CATCH()
}
// linalg_lu(tensor, device): returns {:ok, list} of the arrays produced by
// mlx::core::linalg::lu.
NIF(linalg_lu) {
  TENSOR_PARAM(0, tensor);
  DEVICE_PARAM(1, device);
  try {
    auto factors = mlx::core::linalg::lu(*tensor, device);
    const ERL_NIF_TERM factor_list = nx::nif::make_list(env, factors);
    return nx::nif::ok(env, factor_list);
  }
  CATCH()
}
// linalg_qr(tensor, device): returns {:ok, {q, r}}.
NIF(linalg_qr) {
  TENSOR_PARAM(0, tensor);
  DEVICE_PARAM(1, device);
  try {
    auto [q, r] = mlx::core::linalg::qr(*tensor, device);
    const ERL_NIF_TERM q_term = create_tensor_resource(env, q);
    const ERL_NIF_TERM r_term = create_tensor_resource(env, r);
    return nx::nif::ok(env, enif_make_tuple2(env, q_term, r_term));
  }
  CATCH()
}
// linalg_svd(tensor, compute_uv?, device): returns {:ok, list} of the arrays
// produced by mlx::core::linalg::svd.
NIF(linalg_svd) {
  TENSOR_PARAM(0, tensor);
  PARAM(1, bool, compute_uv);
  DEVICE_PARAM(2, device);
  try {
    auto pieces = mlx::core::linalg::svd(*tensor, compute_uv, device);
    const ERL_NIF_TERM piece_list = nx::nif::make_list(env, pieces);
    return nx::nif::ok(env, piece_list);
  }
  CATCH()
}
// linalg_cholesky(tensor, upper?, device): mlx::core::linalg::cholesky wrapper.
NIF(linalg_cholesky) {
  TENSOR_PARAM(0, tensor);
  PARAM(1, bool, upper);
  DEVICE_PARAM(2, device);
  try {
    auto factor = mlx::core::linalg::cholesky(*tensor, upper, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(factor)));
  }
  CATCH()
}
// linalg_eigh(tensor, uplo, device): returns {:ok, {eigenvalues, eigenvectors}}.
NIF(linalg_eigh) {
  TENSOR_PARAM(0, tensor);
  ATOM_PARAM(1, uplo);
  DEVICE_PARAM(2, device);
  try {
    auto [eigenvalues, eigenvectors] =
        mlx::core::linalg::eigh(*tensor, uplo, device);
    const ERL_NIF_TERM values_term = create_tensor_resource(env, eigenvalues);
    const ERL_NIF_TERM vectors_term = create_tensor_resource(env, eigenvectors);
    return nx::nif::ok(env, enif_make_tuple2(env, values_term, vectors_term));
  }
  CATCH()
}
// linalg_inv(tensor, device): matrix inverse.
NIF(linalg_inv) {
  TENSOR_PARAM(0, tensor);
  DEVICE_PARAM(1, device);
  try {
    auto inverse = mlx::core::linalg::inv(*tensor, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(inverse)));
  }
  CATCH()
}

// linalg_pinv(tensor, device): Moore-Penrose pseudo-inverse.
NIF(linalg_pinv) {
  TENSOR_PARAM(0, tensor);
  DEVICE_PARAM(1, device);
  try {
    auto pseudo = mlx::core::linalg::pinv(*tensor, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(pseudo)));
  }
  CATCH()
}

// linalg_solve(a, b, device): solves a x = b.
NIF(linalg_solve) {
  TENSOR_PARAM(0, tensorA);
  TENSOR_PARAM(1, tensorB);
  DEVICE_PARAM(2, device);
  try {
    auto solution = mlx::core::linalg::solve(*tensorA, *tensorB, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(solution)));
  }
  CATCH()
}

// linalg_solve_triangular(a, b, upper?, device): triangular solve of a x = b.
NIF(linalg_solve_triangular) {
  TENSOR_PARAM(0, tensorA);
  TENSOR_PARAM(1, tensorB);
  PARAM(2, bool, upper);
  DEVICE_PARAM(3, device);
  try {
    auto solution =
        mlx::core::linalg::solve_triangular(*tensorA, *tensorB, upper, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(solution)));
  }
  CATCH()
}
// conv_general(input, kernel, strides, pad_lo, pad_hi, kernel_dilation,
// input_dilation, groups, device): mlx::core::conv_general with flip = false.
NIF(conv_general) {
  TENSOR_PARAM(0, tensor_input);
  TENSOR_PARAM(1, tensor_kernel);
  LIST_PARAM(2, std::vector<int>, strides);
  LIST_PARAM(3, std::vector<int>, padding_low);
  LIST_PARAM(4, std::vector<int>, padding_high);
  LIST_PARAM(5, std::vector<int>, kernel_dilation);
  LIST_PARAM(6, std::vector<int>, input_dilation);
  PARAM(7, int, feature_group_count);
  DEVICE_PARAM(8, device);
  constexpr bool flip_kernel = false;
  try {
    auto convolved = mlx::core::conv_general(
        *tensor_input, *tensor_kernel, strides, padding_low, padding_high,
        kernel_dilation, input_dilation, feature_group_count, flip_kernel,
        device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(convolved)));
  }
  CATCH()
}
// transpose(tensor, axes, device): permutes dimensions by `axes`.
NIF(transpose) {
  TENSOR_PARAM(0, t);
  LIST_PARAM(1, std::vector<int>, axes);
  DEVICE_PARAM(2, device);
  try {
    auto permuted = mlx::core::transpose(*t, axes, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(permuted)));
  }
  CATCH()
}
// pad(tensor, axes, low, high, pad_value, device): constant-pads `axes` by
// the given low/high amounts using the scalar tensor `pad_value`.
// (Removes the stray `;` that followed the function body, which is an extra
// namespace-scope semicolon flagged by -Wpedantic.)
NIF(pad) {
  TENSOR_PARAM(0, t);
  LIST_PARAM(1, std::vector<int>, axes);
  LIST_PARAM(2, std::vector<int>, low_pad_size);
  LIST_PARAM(3, std::vector<int>, high_pad_size);
  TENSOR_PARAM(4, pad_value);
  DEVICE_PARAM(5, device);
  TENSOR(mlx::core::pad(*t, axes, to_shape(low_pad_size),
                        to_shape(high_pad_size), *pad_value, "constant",
                        device));
}
// sort(tensor, axis, device): sorts values along `axis`.
NIF(sort) {
  TENSOR_PARAM(0, t);
  PARAM(1, int, axis);
  DEVICE_PARAM(2, device);
  try {
    auto ordered = mlx::core::sort(*t, axis, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(ordered)));
  }
  CATCH()
}

// argsort(tensor, axis, device): indices that would sort along `axis`.
NIF(argsort) {
  TENSOR_PARAM(0, t);
  PARAM(1, int, axis);
  DEVICE_PARAM(2, device);
  try {
    auto order = mlx::core::argsort(*t, axis, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(order)));
  }
  CATCH()
}
// eval(tensor): forces evaluation of the tensor's lazy graph and returns :ok.
// Evaluation runs the pending MLX computation and can throw. Exceptions are
// now returned as {:error, reason} by CATCH() instead of escaping the NIF.
NIF(eval) {
  TENSOR_PARAM(0, t);
  try {
    mlx::core::eval(*t);
    return nx::nif::ok(env);
  }
  CATCH()
}
// stack(tensors, axis, device): stacks the tensors along a new `axis`.
NIF(stack) {
  LIST_PARAM(0, std::vector<mlx::core::array>, arrays);
  PARAM(1, int, axis);
  DEVICE_PARAM(2, device);
  try {
    auto stacked = mlx::core::stack(arrays, axis, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(stacked)));
  }
  CATCH()
}

// where(condition, x, y, device): element-wise select between x and y.
NIF(where) {
  TENSOR_PARAM(0, condition);
  TENSOR_PARAM(1, x);
  TENSOR_PARAM(2, y);
  DEVICE_PARAM(3, device);
  try {
    auto selected = mlx::core::where(*condition, *x, *y, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(selected)));
  }
  CATCH()
}

// concatenate(tensors, axis, device): joins the tensors along `axis`.
NIF(concatenate) {
  LIST_PARAM(0, std::vector<mlx::core::array>, arrays);
  PARAM(1, int, axis);
  DEVICE_PARAM(2, device);
  try {
    auto joined = mlx::core::concatenate(arrays, axis, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(joined)));
  }
  CATCH()
}
// take_along_axis(tensor, indices, axis, device): gathers along `axis` using
// an index tensor shaped like the output.
NIF(take_along_axis) {
  TENSOR_PARAM(0, t);
  TENSOR_PARAM(1, indices);
  PARAM(2, int, axis);
  DEVICE_PARAM(3, device);
  try {
    auto picked = mlx::core::take_along_axis(*t, *indices, axis, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(picked)));
  }
  CATCH()
}

// take(tensor, indices, axis, device): selects entries of `axis` by index.
NIF(take) {
  TENSOR_PARAM(0, t);
  TENSOR_PARAM(1, indices);
  PARAM(2, int, axis);
  DEVICE_PARAM(3, device);
  try {
    auto picked = mlx::core::take(*t, *indices, axis, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(picked)));
  }
  CATCH()
}
// gather(tensor, indices, axes, slice_sizes, device): mlx::core::gather.
NIF(gather) {
  TENSOR_PARAM(0, t);
  LIST_PARAM(1, std::vector<mlx::core::array>, indices);
  LIST_PARAM(2, std::vector<int>, axes);
  LIST_PARAM(3, std::vector<int>, slice_sizes);
  DEVICE_PARAM(4, device);
  try {
    auto gathered =
        mlx::core::gather(*t, indices, axes, to_shape(slice_sizes), device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(gathered)));
  }
  CATCH()
}

// scatter_add(tensor, indices, updates, axes, device): adds `updates` at the
// indexed positions.
NIF(scatter_add) {
  TENSOR_PARAM(0, t);
  LIST_PARAM(1, std::vector<mlx::core::array>, indices);
  TENSOR_PARAM(2, tensor_updates);
  LIST_PARAM(3, std::vector<int>, axes);
  DEVICE_PARAM(4, device);
  try {
    auto updated =
        mlx::core::scatter_add(*t, indices, *tensor_updates, axes, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(updated)));
  }
  CATCH()
}

// scatter(tensor, indices, updates, axes, device): writes `updates` at the
// indexed positions.
NIF(scatter) {
  TENSOR_PARAM(0, t);
  LIST_PARAM(1, std::vector<mlx::core::array>, indices);
  TENSOR_PARAM(2, tensor_updates);
  LIST_PARAM(3, std::vector<int>, axes);
  DEVICE_PARAM(4, device);
  try {
    auto updated =
        mlx::core::scatter(*t, indices, *tensor_updates, axes, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(updated)));
  }
  CATCH()
}
/* Reduction Ops */
// REDUCTION_AXES_OP2(OP, NATIVE_OP) defines NIF `OP(tensor, axes, keep_dims,
// device)` that calls mlx::core::NATIVE_OP. An empty `axes` list is expanded
// to every dimension 0..ndim-1, i.e. a full reduction.
#define REDUCTION_AXES_OP(OP) REDUCTION_AXES_OP2(OP, OP)
#define REDUCTION_AXES_OP2(OP, NATIVE_OP)                                     \
  NIF(OP) {                                                                    \
    TENSOR_PARAM(0, tensor);                                                   \
    LIST_PARAM(1, std::vector<int>, axes);                                     \
    PARAM(2, bool, keep_dims);                                                 \
    DEVICE_PARAM(3, device);                                                   \
    if (axes.empty()) {                                                        \
      axes.resize(tensor->ndim());                                             \
      std::iota(axes.begin(), axes.end(), 0);                                  \
    }                                                                          \
    TENSOR(mlx::core::NATIVE_OP(*tensor, axes, keep_dims, device));            \
  }
// REDUCTION_AXIS_OP2(OP, NATIVE_OP) defines NIF `OP` with two arities:
//   argc == 3: (tensor, keep_dims, device)       -> reduce over all elements
//   otherwise: (tensor, axis, keep_dims, device) -> reduce along `axis`
#define REDUCTION_AXIS_OP(OP) REDUCTION_AXIS_OP2(OP, OP)
#define REDUCTION_AXIS_OP2(OP, NATIVE_OP)                                     \
  NIF(OP) {                                                                    \
    TENSOR_PARAM(0, tensor);                                                   \
    if (argc != 3) {                                                           \
      PARAM(1, int, axis);                                                     \
      PARAM(2, bool, keep_dims);                                               \
      DEVICE_PARAM(3, device);                                                 \
      TENSOR(mlx::core::NATIVE_OP(*tensor, axis, keep_dims, device));          \
    }                                                                          \
    PARAM(1, bool, keep_dims);                                                 \
    DEVICE_PARAM(2, device);                                                   \
    TENSOR(mlx::core::NATIVE_OP(*tensor, keep_dims, device));                  \
  }
// REDUCTION_AXIS_REVERSIBLE_OP2(OP, NATIVE_OP) defines NIF `OP(tensor, axis,
// keep_dims, device)` forwarding to mlx::core::NATIVE_OP.
#define REDUCTION_AXIS_REVERSIBLE_OP(OP) REDUCTION_AXIS_REVERSIBLE_OP2(OP, OP)
#define REDUCTION_AXIS_REVERSIBLE_OP2(OP, NATIVE_OP)                          \
  NIF(OP) {                                                                    \
    TENSOR_PARAM(0, tensor);                                                   \
    PARAM(1, int, axis);                                                       \
    PARAM(2, bool, keep_dims);                                                 \
    DEVICE_PARAM(3, device);                                                   \
    try {                                                                      \
      auto reduced = mlx::core::NATIVE_OP(*tensor, axis, keep_dims, device);  \
      return nx::nif::ok(env, create_tensor_resource(env, std::move(reduced))); \
    }                                                                          \
    CATCH()                                                                    \
  }
// Axes-list reductions: NIFs take (tensor, axes, keep_dims, device).
REDUCTION_AXES_OP(all)
REDUCTION_AXES_OP(any)
REDUCTION_AXES_OP(sum)
// Exported as `product` but calls mlx::core::prod.
REDUCTION_AXES_OP2(product, prod)
// Index reductions: dispatch on argc between (tensor, keep_dims, device) and
// (tensor, axis, keep_dims, device).
REDUCTION_AXIS_OP(argmax)
REDUCTION_AXIS_OP(argmin)
// CUMULATIVE_OP(OP, NATIVE_OP) defines NIF `OP(tensor, axis, reverse,
// inclusive, device)` that forwards to the MLX scan mlx::core::NATIVE_OP.
// The macro is local to these four definitions and undefined afterwards.
#define CUMULATIVE_OP(OP, NATIVE_OP)                                          \
  NIF(OP) {                                                                    \
    TENSOR_PARAM(0, tensor);                                                   \
    PARAM(1, int, axis);                                                       \
    PARAM(2, bool, reverse);                                                   \
    PARAM(3, bool, inclusive);                                                 \
    DEVICE_PARAM(4, device);                                                   \
    TENSOR(mlx::core::NATIVE_OP(*tensor, axis, reverse, inclusive, device));   \
  }
CUMULATIVE_OP(cumulative_sum, cumsum)
CUMULATIVE_OP(cumulative_product, cumprod)
CUMULATIVE_OP(cumulative_max, cummax)
CUMULATIVE_OP(cumulative_min, cummin)
#undef CUMULATIVE_OP
/* Unary Ops */
// UNARY_OP2(OP, NATIVE_OP) defines NIF `OP(tensor, device)` returning
// mlx::core::NATIVE_OP applied to the tensor. UNARY_OP(OP) uses one name for
// both.
#define UNARY_OP(OP) UNARY_OP2(OP, OP)
#define UNARY_OP2(OP, NATIVE_OP)                                               \
  NIF(OP) {                                                                    \
    TENSOR_PARAM(0, tensor);                                                   \
    DEVICE_PARAM(1, device);                                                   \
    try {                                                                      \
      auto unary_result = mlx::core::NATIVE_OP(*tensor, device);              \
      return nx::nif::ok(env,                                                  \
                         create_tensor_resource(env, std::move(unary_result))); \
    }                                                                          \
    CATCH()                                                                    \
  }
/* Binary Ops */
// BINARY_OP2(OP, NATIVE_OP) defines NIF `OP(a, b, device)` returning
// mlx::core::NATIVE_OP(a, b). BINARY_OP(OP) uses one name for both.
#define BINARY_OP(OP) BINARY_OP2(OP, OP)
#define BINARY_OP2(OP, NATIVE_OP)                                              \
  NIF(OP) {                                                                    \
    TENSOR_PARAM(0, a);                                                        \
    TENSOR_PARAM(1, b);                                                        \
    DEVICE_PARAM(2, device);                                                   \
    try {                                                                      \
      auto binary_result = mlx::core::NATIVE_OP(*a, *b, device);              \
      return nx::nif::ok(env,                                                  \
                         create_tensor_resource(env, std::move(binary_result))); \
    }                                                                          \
    CATCH()                                                                    \
  }
// Registers the EMLX resource types (MLX arrays, compiled functions).
// Returns 0 on success, -1 as soon as one registration fails.
static int open_resources(ErlNifEnv *env) {
  const char *const module_name = "EMLX";
  const bool registered =
      open_resource<mlx::core::array>(env, module_name, "MLXArray") &&
      open_resource<emlx::function>(env, module_name, "CompiledFunction");
  return registered ? 0 : -1;
}
// NIF load callback: registers resource types. Returns 0 on success, -1 on
// failure (which makes the library load fail).
static int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info) {
  (void)priv_data;
  (void)load_info;
  return open_resources(env) == 0 ? 0 : -1;
}
int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM load_info) {
// Silence "unused var" warnings.
(void)(env);
(void)(priv_data);
(void)(old_priv_data);
(void)(load_info);
return 0;
}
// Element-wise unary NIFs; OP2 forms map the exported name to a differently
// named mlx::core function.
UNARY_OP(abs)
UNARY_OP(ceil)
UNARY_OP(conjugate)
UNARY_OP(floor)
UNARY_OP2(negate, negative)
UNARY_OP(round)
UNARY_OP(sign)
UNARY_OP(real)
UNARY_OP(imag)
UNARY_OP2(is_nan, isnan)
UNARY_OP2(is_infinity, isinf)
UNARY_OP(logical_not)
UNARY_OP(sigmoid)
// Inverse trig / hyperbolic: exported with Nx-style short names.
UNARY_OP2(asin, arcsin)
UNARY_OP2(asinh, arcsinh)
UNARY_OP2(acos, arccos)
UNARY_OP2(acosh, arccosh)
UNARY_OP2(atan, arctan)
UNARY_OP2(atanh, arctanh)
UNARY_OP(cos)
UNARY_OP(cosh)
UNARY_OP(erf)
UNARY_OP2(erf_inv, erfinv)
UNARY_OP(exp)
UNARY_OP(expm1)
UNARY_OP(log)
UNARY_OP(log1p)
UNARY_OP(rsqrt)
UNARY_OP(sin)
UNARY_OP(sinh)
UNARY_OP(sqrt)
UNARY_OP(tan)
UNARY_OP(tanh)
// Element-wise arithmetic / bitwise binary NIFs.
BINARY_OP(add)
BINARY_OP(subtract)
BINARY_OP(multiply)
BINARY_OP2(pow, power)
BINARY_OP(remainder)
BINARY_OP(divide)
BINARY_OP2(atan2, arctan2)
BINARY_OP(minimum)
BINARY_OP(maximum)
BINARY_OP2(quotient, floor_divide)
BINARY_OP(bitwise_and)
BINARY_OP(bitwise_or)
BINARY_OP(bitwise_xor)
// bitwise_not(tensor, device): computes ~a as (mask - a). The mask is a
// scalar built from the all-ones 64-bit pattern, cast to a's dtype. This is
// presumably intended for integer dtypes.
// Both MLX calls now run inside try/CATCH(). Previously `full` could throw
// outside any handler and escape the NIF.
NIF(bitwise_not) {
  TENSOR_PARAM(0, a);
  DEVICE_PARAM(1, device);
  try {
    const auto dtype = a->dtype();
    auto mask = mlx::core::full({}, 0xFFFFFFFFFFFFFFFF, dtype, device);
    auto inverted = mlx::core::subtract(mask, *a, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(inverted)));
  }
  CATCH()
}
// Shift, comparison and logical binary NIFs; each takes (a, b, device) and
// forwards to the mlx::core function of the same name.
BINARY_OP(left_shift)
BINARY_OP(right_shift)
BINARY_OP(equal)
BINARY_OP(not_equal)
BINARY_OP(greater)
BINARY_OP(less)
BINARY_OP(greater_equal)
BINARY_OP(less_equal)
BINARY_OP(logical_and)
BINARY_OP(logical_or)
// logical_xor(a, b, device): (a OR b) AND NOT (a AND b), element-wise.
// All three MLX ops now run inside try/CATCH(), so errors such as
// non-broadcastable shapes become {:error, reason} instead of escaping the
// NIF.
NIF(logical_xor) {
  TENSOR_PARAM(0, a);
  TENSOR_PARAM(1, b);
  DEVICE_PARAM(2, device);
  try {
    auto either = mlx::core::logical_or(*a, *b, device);
    auto not_both =
        mlx::core::logical_not(mlx::core::logical_and(*a, *b, device), device);
    auto exclusive = mlx::core::logical_and(either, not_both, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(exclusive)));
  }
  CATCH()
}
// allclose(a, b, rtol, atol, equal_nan?, device): scalar bool tensor.
NIF(allclose) {
  TENSOR_PARAM(0, a);
  TENSOR_PARAM(1, b);
  PARAM(2, double, rtol);
  PARAM(3, double, atol);
  PARAM(4, bool, equal_nan);
  DEVICE_PARAM(5, device);
  try {
    auto verdict = mlx::core::allclose(*a, *b, rtol, atol, equal_nan, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(verdict)));
  }
  CATCH()
}

// isclose(a, b, rtol, atol, equal_nan?, device): element-wise bool tensor.
NIF(isclose) {
  TENSOR_PARAM(0, a);
  TENSOR_PARAM(1, b);
  PARAM(2, double, rtol);
  PARAM(3, double, atol);
  PARAM(4, bool, equal_nan);
  DEVICE_PARAM(5, device);
  try {
    auto verdict = mlx::core::isclose(*a, *b, rtol, atol, equal_nan, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(verdict)));
  }
  CATCH()
}
// item(tensor): evaluates the tensor and returns {:ok, number} for its single
// element. Bool and integer dtypes yield integers; float dtypes yield floats;
// complex64 yields an error.
// Evaluation and item<T>() (which MLX rejects for arrays that do not hold
// exactly one element) can throw. They now run inside try/CATCH() so failures
// become {:error, reason} instead of escaping the NIF.
NIF(item) {
  TENSOR_PARAM(0, t);
  try {
    mlx::core::eval(*t);
    // Fix for MLX scalar layout bug: Use the correct type when calling
    // item<T>() to avoid reading wrong number of bytes from potentially
    // invalid memory layouts.
    auto dtype = t->dtype();
    // Handle integer and boolean types with proper dtype matching
    if (dtype == mlx::core::bool_) {
      bool value = t->item<bool>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
    } else if (dtype == mlx::core::uint8) {
      uint8_t value = t->item<uint8_t>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
    } else if (dtype == mlx::core::uint16) {
      uint16_t value = t->item<uint16_t>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
    } else if (dtype == mlx::core::uint32) {
      uint32_t value = t->item<uint32_t>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
    } else if (dtype == mlx::core::uint64) {
      // NOTE(review): values above INT64_MAX wrap to negative numbers in this
      // cast.
      uint64_t value = t->item<uint64_t>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
    } else if (dtype == mlx::core::int8) {
      int8_t value = t->item<int8_t>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
    } else if (dtype == mlx::core::int16) {
      int16_t value = t->item<int16_t>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
    } else if (dtype == mlx::core::int32) {
      int32_t value = t->item<int32_t>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<int64_t>(value)));
    } else if (dtype == mlx::core::int64) {
      int64_t value = t->item<int64_t>();
      return nx::nif::ok(env, nx::nif::make(env, value));
    } else if (dtype == mlx::core::float16 || dtype == mlx::core::bfloat16) {
      // MLX handles float16/bfloat16 conversion internally
      float value = t->item<float>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<double>(value)));
    } else if (dtype == mlx::core::float32) {
      float value = t->item<float>();
      return nx::nif::ok(env, nx::nif::make(env, static_cast<double>(value)));
    } else if (dtype == mlx::core::complex64) {
      // Complex types need special handling - not supported via item()
      return nx::nif::error(
          env, "Complex scalar extraction not supported via item()");
    } else {
      // Fallback for any other types
      double value = t->item<double>();
      return nx::nif::ok(env, nx::nif::make(env, value));
    }
  }
  CATCH()
}
// slice(tensor, starts, stops, strides, device): strided slice.
NIF(slice) {
  TENSOR_PARAM(0, t);
  LIST_PARAM(1, std::vector<int>, starts);
  LIST_PARAM(2, std::vector<int>, stops);
  LIST_PARAM(3, std::vector<int>, strides);
  DEVICE_PARAM(4, device);
  try {
    auto window = mlx::core::slice(*t, to_shape(starts), to_shape(stops),
                                   to_shape(strides), device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(window)));
  }
  CATCH()
}

// slice_update(tensor, updates, starts, stops, device): returns a copy of
// `tensor` with the [starts, stops) region replaced by `updates`.
NIF(slice_update) {
  TENSOR_PARAM(0, t);
  TENSOR_PARAM(1, tensor_updates);
  LIST_PARAM(2, std::vector<int>, starts);
  LIST_PARAM(3, std::vector<int>, stops);
  DEVICE_PARAM(4, device);
  try {
    auto patched = mlx::core::slice_update(*t, *tensor_updates,
                                           to_shape(starts), to_shape(stops),
                                           device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(patched)));
  }
  CATCH()
}

// squeeze(tensor, axes, device): removes the listed size-1 axes.
NIF(squeeze) {
  TENSOR_PARAM(0, t);
  LIST_PARAM(1, std::vector<int>, axes);
  DEVICE_PARAM(2, device);
  try {
    auto squeezed = mlx::core::squeeze(*t, axes, device);
    return nx::nif::ok(env, create_tensor_resource(env, std::move(squeezed)));
  }
  CATCH()
}
NIF(emlx_fft) {
TENSOR_PARAM(0, t);
PARAM(1, int, n);
PARAM(2, int, axis);
DEVICE_PARAM(3, device);