@@ -2655,6 +2655,27 @@ def test_inner(self):
26552655 self .assertCmpNumpy ([(1 , 1 , 2 ), (3 , 2 )], mx .inner , np .inner )
26562656 self .assertCmpNumpy ([(2 , 3 , 4 ), (4 ,)], mx .inner , np .inner )
26572657
2658+ def test_vecdot (self ):
2659+ a = mx .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
2660+ b = mx .array ([[7 , 8 , 9 ], [10 , 11 , 12 ]])
2661+ self .assertEqual (mx .vecdot (a , b ).tolist (), [50 , 167 ])
2662+ self .assertEqual (mx .vecdot (a , b , axis = 0 ).tolist (), [47 , 71 , 99 ])
2663+
2664+ a = mx .array ([1 + 2j , 3 + 4j ])
2665+ b = mx .array ([5 + 6j , 7 + 8j ])
2666+ expected = np .vdot (np .array ([1 + 2j , 3 + 4j ]), np .array ([5 + 6j , 7 + 8j ]))
2667+ self .assertTrue (np .allclose (mx .vecdot (a , b ), expected ))
2668+
2669+ xp = mx .array (1.0 ).__array_namespace__ ()
2670+ self .assertEqual (xp .vecdot (mx .array ([1 , 2 ]), mx .array ([3 , 4 ])).item (), 11 )
2671+
2672+ with self .assertRaises (ValueError ):
2673+ mx .vecdot (mx .array (1 ), mx .array ([1 ]))
2674+ with self .assertRaises (ValueError ):
2675+ mx .vecdot (mx .array ([1 , 2 ]), mx .array ([1 , 2 ]), axis = 1 )
2676+ with self .assertRaises (ValueError ):
2677+ mx .vecdot (mx .array ([1 , 2 ]), mx .array ([1 ]))
2678+
26582679 def test_outer (self ):
26592680 self .assertCmpNumpy ([(3 ,), (3 ,)], mx .outer , np .outer )
26602681 self .assertCmpNumpy (
0 commit comments