Skip to content

Commit 8b5c471

Browse files
ENH: validate weights shape in cov
1 parent 06b4007 commit 8b5c471

2 files changed

Lines changed: 29 additions & 0 deletions

File tree

src/array_api_extra/_delegation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,22 @@ def cov(
222222
if m.ndim >= 2 and axis not in (-1, m.ndim - 1):
223223
m = xp.moveaxis(m, axis, -1)
224224

225+
# Validate weight shapes (eager metadata, lazy-safe). Value-based
226+
# checks (non-negative, integer dtype) are intentionally skipped so
227+
# lazy backends don't trigger compute -- same tradeoff as dask.cov.
228+
n_obs = m.shape[-1]
229+
for name, w in (("fweights", fweights), ("aweights", aweights)):
230+
if w is None:
231+
continue
232+
if w.ndim != 1:
233+
msg = f"`{name}` must be 1-D, got ndim={w.ndim}"
234+
raise ValueError(msg)
235+
if w.shape[0] != n_obs:
236+
msg = (
237+
f"`{name}` has length {w.shape[0]} but `m` has {n_obs} observations"
238+
)
239+
raise ValueError(msg)
240+
225241
# `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov`
226242
# requires integer `correction`. For non-integer-valued `correction`,
227243
# fall through to the generic implementation.

tests/test_funcs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,19 @@ def test_axis_out_of_bounds(self, xp: ModuleType):
723723
with pytest.raises(IndexError):
724724
_ = cov(m, axis=5)
725725

726+
def test_weights_shape_validation(self, xp: ModuleType):
727+
m = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
728+
# Wrong length.
729+
with pytest.raises(ValueError, match="`fweights` has length"):
730+
_ = cov(m, fweights=xp.asarray([1, 2]))
731+
with pytest.raises(ValueError, match="`aweights` has length"):
732+
_ = cov(m, aweights=xp.asarray([0.1, 0.2]))
733+
# Wrong ndim.
734+
with pytest.raises(ValueError, match="`fweights` must be 1-D"):
735+
_ = cov(m, fweights=xp.asarray([[1, 2, 3]]))
736+
with pytest.raises(ValueError, match="`aweights` must be 1-D"):
737+
_ = cov(m, aweights=xp.asarray([[0.1, 0.2, 0.3]]))
738+
726739

727740
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
728741
class TestOneHot:

0 commit comments

Comments
 (0)