Skip to content

Commit 8cd377b

Browse files
authored
Add the bartlett function (#3155)
1 parent f145ece commit 8cd377b

4 files changed

Lines changed: 52 additions & 0 deletions

File tree

mlx/ops.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,6 +2312,21 @@ array argmax(
23122312
return out;
23132313
}
23142314

2315+
array bartlett(int M, StreamOrDevice s /* = {} */) {
2316+
if (M < 1) {
2317+
return array({});
2318+
}
2319+
if (M == 1) {
2320+
return ones({1}, float32, s);
2321+
}
2322+
2323+
auto n = arange(0, M, float32, s);
2324+
float factor_val = 2.0f / (M - 1);
2325+
auto factor = array(factor_val, float32);
2326+
auto term = subtract(multiply(factor, n, s), array(1.0f, float32), s);
2327+
return subtract(array(1.0f, float32), abs(term, s), s);
2328+
}
2329+
23152330
array hanning(int M, StreamOrDevice s /* = {} */) {
23162331
if (M < 1) {
23172332
return array({});

mlx/ops.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,9 @@ MLX_API array hanning(int M, StreamOrDevice s = {});
672672
/** Returns the Hamming window of size M. */
673673
MLX_API array hamming(int M, StreamOrDevice s = {});
674674

675+
/** Returns the bartlett window of size M. */
676+
MLX_API array bartlett(int M, StreamOrDevice s = {});
677+
675678
/** Returns the Blackmann window of size M. */
676679
MLX_API array blackman(int M, StreamOrDevice s = {});
677680

python/src/ops.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,6 +1428,28 @@ void init_ops(nb::module_& m) {
14281428
"stream"_a = nb::none(),
14291429
nb::sig(
14301430
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
1431+
m.def(
1432+
"bartlett",
1433+
&mlx::core::bartlett,
1434+
"M"_a,
1435+
nb::kw_only(),
1436+
"stream"_a = nb::none(),
1437+
R"pbdoc(
1438+
Return the Bartlett window.
1439+
1440+
The Bartlett window is a taper formed by using a weighted cosine.
1441+
1442+
.. math::
1443+
w(n) = 1 - \frac{2|n - (M-1)/2|}{M-1}
1444+
\qquad 0 \le n \le M-1
1445+
1446+
Args:
1447+
M (int): Number of points in the output window.
1448+
1449+
Returns:
1450+
array: The window, with the maximum value normalized to one (the value one
1451+
appears only if the number of samples is odd).
1452+
)pbdoc");
14311453
m.def(
14321454
"hanning",
14331455
&mlx::core::hanning,

python/tests/test_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,18 @@ def test_hamming_general(self):
14741474
self.assertEqual(a.size, 0)
14751475
self.assertEqual(a.dtype, mx.float32)
14761476

1477+
def test_bartlett_general(self):
1478+
a = mx.bartlett(10)
1479+
expected = np.bartlett(10)
1480+
self.assertTrue(np.allclose(a, expected, atol=1e-5))
1481+
1482+
a = mx.bartlett(1)
1483+
self.assertEqual(a.item(), 1.0)
1484+
1485+
a = mx.bartlett(0)
1486+
self.assertEqual(a.size, 0)
1487+
self.assertEqual(a.dtype, mx.float32)
1488+
14771489
def test_blackman_general(self):
14781490
a = mx.blackman(10)
14791491
expected = np.blackman(10)

0 commit comments

Comments
 (0)