Skip to content

Commit e34d415

Browse files
MNT: move weights validation to generic cov
1 parent cb717b0 commit e34d415

3 files changed

Lines changed: 17 additions & 26 deletions

File tree

src/array_api_extra/_delegation.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -222,20 +222,6 @@ def cov(
222222
if m.ndim >= 2 and axis not in (-1, m.ndim - 1):
223223
m = xp.moveaxis(m, axis, -1)
224224

225-
# Validate weight shapes (eager metadata, lazy-safe). Value-based
226-
# checks (non-negative, integer dtype) are intentionally skipped so
227-
# lazy backends don't trigger compute -- same tradeoff as dask.cov.
228-
n_obs = m.shape[-1]
229-
for name, w in (("fweights", fweights), ("aweights", aweights)):
230-
if w is None:
231-
continue
232-
if w.ndim != 1:
233-
msg = f"`{name}` must be 1-D, got ndim={w.ndim}"
234-
raise ValueError(msg)
235-
if w.shape[0] != n_obs:
236-
msg = f"`{name}` has length {w.shape[0]} but `m` has {n_obs} observations"
237-
raise ValueError(msg)
238-
239225
# `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov`
240226
# requires integer `correction`. For non-integer-valued `correction`,
241227
# fall through to the generic implementation.

src/array_api_extra/_lib/_funcs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,23 @@ def cov(
299299
m = atleast_nd(m, ndim=2, xp=xp)
300300
m = xp.astype(m, dtype)
301301

302+
# Validate weight shapes (eager metadata, lazy-safe). Native backends
303+
# validate themselves; this covers the generic path (array-api-strict,
304+
# sparse, and the dask+weights fallback where the native check is
305+
# bypassed to preserve laziness).
306+
n_obs = m.shape[-1]
307+
for name, w_in in (("fweights", fweights), ("aweights", aweights)):
308+
if w_in is None:
309+
continue
310+
if w_in.ndim != 1:
311+
msg = f"`{name}` must be 1-D, got ndim={w_in.ndim}"
312+
raise ValueError(msg)
313+
if w_in.shape[0] != n_obs:
314+
msg = (
315+
f"`{name}` has length {w_in.shape[0]} but `m` has {n_obs} observations"
316+
)
317+
raise ValueError(msg)
318+
302319
fw = None
303320
if fweights is not None:
304321
fw = xp.astype(xp.asarray(fweights), dtype)

tests/test_funcs.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -723,18 +723,6 @@ def test_axis_out_of_bounds(self, xp: ModuleType):
723723
with pytest.raises(IndexError):
724724
_ = cov(m, axis=5)
725725

726-
def test_weights_shape_validation(self, xp: ModuleType):
727-
m = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
728-
# Wrong length.
729-
with pytest.raises(ValueError, match="`fweights` has length"):
730-
_ = cov(m, fweights=xp.asarray([1, 2]))
731-
with pytest.raises(ValueError, match="`aweights` has length"):
732-
_ = cov(m, aweights=xp.asarray([0.1, 0.2]))
733-
# Wrong ndim.
734-
with pytest.raises(ValueError, match="`fweights` must be 1-D"):
735-
_ = cov(m, fweights=xp.asarray([[1, 2, 3]]))
736-
with pytest.raises(ValueError, match="`aweights` must be 1-D"):
737-
_ = cov(m, aweights=xp.asarray([[0.1, 0.2, 0.3]]))
738726

739727

740728
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)

0 commit comments

Comments
 (0)