Skip to content

Commit e37e926

Browse files
authored
Add vecdot to array API namespace (#3748)
1 parent da38f3e commit e37e926

6 files changed

Lines changed: 70 additions & 1 deletion

File tree

docs/src/python/ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ Operations
197197
triu
198198
unflatten
199199
unstack
200+
vecdot
200201
var
201202
view
202203
where

mlx/ops.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5534,6 +5534,28 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
55345534
return tensordot(a, b, {-1}, {-1}, s);
55355535
}
55365536

5537+
array vecdot(
5538+
const array& a,
5539+
const array& b,
5540+
int axis /* = -1 */,
5541+
StreamOrDevice s /* = {} */) {
5542+
if (a.ndim() == 0 || b.ndim() == 0) {
5543+
throw std::invalid_argument("[vecdot] inputs must be at least 1D.");
5544+
}
5545+
int ax = axis < 0 ? axis + a.ndim() : axis;
5546+
if (ax < 0 || ax >= a.ndim()) {
5547+
throw std::invalid_argument("[vecdot] axis is out of bounds.");
5548+
}
5549+
if (axis < 0 ? axis + b.ndim() != ax : axis >= b.ndim()) {
5550+
throw std::invalid_argument("[vecdot] axis is out of bounds.");
5551+
}
5552+
if (a.shape(ax) != b.shape(ax)) {
5553+
throw std::invalid_argument(
5554+
"[vecdot] a and b must have the same size along axis.");
5555+
}
5556+
return sum(multiply(conjugate(a, s), b, s), ax, false, s);
5557+
}
5558+
55375559
/** Compute D = beta * C + alpha * (A @ B) */
55385560
array addmm(
55395561
array c,

mlx/ops.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,6 +1574,10 @@ MLX_API array outer(const array& a, const array& b, StreamOrDevice s = {});
15741574
/** Compute the inner product of two vectors. */
15751575
MLX_API array inner(const array& a, const array& b, StreamOrDevice s = {});
15761576

1577+
/** Compute a vector dot product along an axis. */
1578+
MLX_API array
1579+
vecdot(const array& a, const array& b, int axis = -1, StreamOrDevice s = {});
1580+
15771581
/** Compute D = beta * C + alpha * (A @ B) */
15781582
MLX_API array addmm(
15791583
array c,

python/src/ops.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4737,6 +4737,27 @@ void init_ops(nb::module_& m) {
47374737
Returns:
47384738
array: The inner product.
47394739
)pbdoc");
4740+
m.def(
4741+
"vecdot",
4742+
&mx::vecdot,
4743+
nb::arg(),
4744+
nb::arg(),
4745+
"axis"_a = -1,
4746+
nb::kw_only(),
4747+
"stream"_a = nb::none(),
4748+
nb::sig(
4749+
"def vecdot(a: array, b: array, /, *, axis: int = -1, stream: Union[None, Stream, Device] = None) -> array"),
4750+
R"pbdoc(
4751+
Compute the vector dot product of two arrays along an axis.
4752+
4753+
Args:
4754+
a (array): Input array
4755+
b (array): Input array
4756+
axis (int, optional): Axis over which to compute the dot product. Default: ``-1``.
4757+
4758+
Returns:
4759+
array: The vector dot product.
4760+
)pbdoc");
47404761
m.def(
47414762
"outer",
47424763
&mx::outer,

python/tests/test_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def test_isdtype(self):
192192

193193
# Reachable through the array API namespace.
194194
xp = mx.array(1.0).__array_namespace__()
195-
for name in ("result_type", "can_cast", "isdtype"):
195+
for name in ("result_type", "can_cast", "isdtype", "vecdot"):
196196
self.assertTrue(hasattr(xp, name), msg=name)
197197

198198

python/tests/test_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2655,6 +2655,27 @@ def test_inner(self):
26552655
self.assertCmpNumpy([(1, 1, 2), (3, 2)], mx.inner, np.inner)
26562656
self.assertCmpNumpy([(2, 3, 4), (4,)], mx.inner, np.inner)
26572657

2658+
def test_vecdot(self):
2659+
a = mx.array([[1, 2, 3], [4, 5, 6]])
2660+
b = mx.array([[7, 8, 9], [10, 11, 12]])
2661+
self.assertEqual(mx.vecdot(a, b).tolist(), [50, 167])
2662+
self.assertEqual(mx.vecdot(a, b, axis=0).tolist(), [47, 71, 99])
2663+
2664+
a = mx.array([1 + 2j, 3 + 4j])
2665+
b = mx.array([5 + 6j, 7 + 8j])
2666+
expected = np.vdot(np.array([1 + 2j, 3 + 4j]), np.array([5 + 6j, 7 + 8j]))
2667+
self.assertTrue(np.allclose(mx.vecdot(a, b), expected))
2668+
2669+
xp = mx.array(1.0).__array_namespace__()
2670+
self.assertEqual(xp.vecdot(mx.array([1, 2]), mx.array([3, 4])).item(), 11)
2671+
2672+
with self.assertRaises(ValueError):
2673+
mx.vecdot(mx.array(1), mx.array([1]))
2674+
with self.assertRaises(ValueError):
2675+
mx.vecdot(mx.array([1, 2]), mx.array([1, 2]), axis=1)
2676+
with self.assertRaises(ValueError):
2677+
mx.vecdot(mx.array([1, 2]), mx.array([1]))
2678+
26582679
def test_outer(self):
26592680
self.assertCmpNumpy([(3,), (3,)], mx.outer, np.outer)
26602681
self.assertCmpNumpy(

0 commit comments

Comments
 (0)