Skip to content

Commit 06b4007

Browse files
MNT: rename weights params to fweights/aweights
1 parent 98f216a commit 06b4007

3 files changed

Lines changed: 28 additions & 26 deletions

File tree

src/array_api_extra/_delegation.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def cov(
8787
*,
8888
axis: int = -1,
8989
correction: int | float = 1,
90-
frequency_weights: Array | None = None,
91-
weights: Array | None = None,
90+
fweights: Array | None = None,
91+
aweights: Array | None = None,
9292
xp: ModuleType | None = None,
9393
) -> Array:
9494
"""
@@ -126,12 +126,12 @@ def cov(
126126
``correction`` in ``numpy.var``/``std`` and ``torch.cov``.
127127
fweights : array, optional
128128
1-D array of integer frequency weights: the number of times each
129-
observation is repeated. Corresponds to ``fweights`` in
129+
observation is repeated. Same as ``fweights`` in
130130
``numpy.cov``/``torch.cov``.
131131
aweights : array, optional
132132
1-D array of observation-vector weights (analytic weights). Larger
133-
values mark more important observations. Corresponds to
134-
``aweights`` in ``numpy.cov``/``torch.cov``.
133+
values mark more important observations. Same as ``aweights`` in
134+
``numpy.cov``/``torch.cov``.
135135
xp : array_namespace, optional
136136
The standard-compatible namespace for `m`. Default: infer.
137137
@@ -149,8 +149,8 @@ def cov(
149149
numpy.cov(m, rowvar=False) -> cov(m, axis=-2)
150150
numpy.cov(m, bias=True) -> cov(m, correction=0)
151151
numpy.cov(m, ddof=k) -> cov(m, correction=k)
152-
numpy.cov(m, fweights=f) -> cov(m, frequency_weights=f)
153-
numpy.cov(m, aweights=a) -> cov(m, weights=a)
152+
numpy.cov(m, fweights=f) -> cov(m, fweights=f)
153+
numpy.cov(m, aweights=a) -> cov(m, aweights=a)
154154
155155
Unlike ``numpy.cov``, a ``RuntimeWarning`` for non-positive effective
156156
degrees of freedom is only emitted on the unweighted path. The
@@ -226,12 +226,12 @@ def cov(
226226
# requires integer `correction`. For non-integer-valued `correction`,
227227
# fall through to the generic implementation.
228228
integer_correction = isinstance(correction, int) or correction.is_integer()
229-
has_weights = frequency_weights is not None or weights is not None
229+
has_weights = fweights is not None or aweights is not None
230230

231231
if m.ndim <= 2 and integer_correction:
232232
if is_torch_namespace(xp):
233-
fw = None if frequency_weights is None else xp.asarray(frequency_weights)
234-
aw = None if weights is None else xp.asarray(weights)
233+
fw = None if fweights is None else xp.asarray(fweights)
234+
aw = None if aweights is None else xp.asarray(aweights)
235235
return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw)
236236
# `dask.array.cov` forces `.compute()` whenever weights are given:
237237
# its internal `if fact <= 0` check on a lazy 0-D scalar triggers
@@ -246,15 +246,15 @@ def cov(
246246
return xp.cov(
247247
m,
248248
ddof=int(correction),
249-
fweights=frequency_weights,
250-
aweights=weights,
249+
fweights=fweights,
250+
aweights=aweights,
251251
)
252252

253253
return _funcs.cov(
254254
m,
255255
correction=correction,
256-
frequency_weights=frequency_weights,
257-
weights=weights,
256+
fweights=fweights,
257+
aweights=aweights,
258258
xp=xp,
259259
)
260260

src/array_api_extra/_lib/_funcs.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,8 @@ def cov(
286286
/,
287287
*,
288288
correction: int | float = 1,
289-
frequency_weights: Array | None = None,
290-
weights: Array | None = None,
289+
fweights: Array | None = None,
290+
aweights: Array | None = None,
291291
xp: ModuleType,
292292
) -> Array: # numpydoc ignore=PR01,RT01
293293
"""See docstring in array_api_extra._delegation."""
@@ -300,9 +300,11 @@ def cov(
300300
m = xp.astype(m, dtype)
301301

302302
fw = None
303-
if frequency_weights is not None:
304-
fw = xp.astype(xp.asarray(frequency_weights), dtype)
305-
aw = None if weights is None else xp.astype(xp.asarray(weights), dtype)
303+
if fweights is not None:
304+
fw = xp.astype(xp.asarray(fweights), dtype)
305+
aw = None
306+
if aweights is not None:
307+
aw = xp.astype(xp.asarray(aweights), dtype)
306308
if fw is None and aw is None:
307309
w = None
308310
elif fw is None:

tests/test_funcs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -665,15 +665,15 @@ def test_frequency_weights(self, xp: ModuleType):
665665
m = rng.random((3, 10))
666666
fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64)
667667
ref = np.cov(m, fweights=fw)
668-
res = cov(xp.asarray(m), frequency_weights=xp.asarray(fw))
668+
res = cov(xp.asarray(m), fweights=xp.asarray(fw))
669669
xp_assert_close(res, xp.asarray(ref))
670670

671671
def test_weights(self, xp: ModuleType):
672672
rng = np.random.default_rng(20260417)
673673
m = rng.random((3, 10))
674674
aw = rng.random(10)
675675
ref = np.cov(m, aweights=aw)
676-
res = cov(xp.asarray(m), weights=xp.asarray(aw))
676+
res = cov(xp.asarray(m), aweights=xp.asarray(aw))
677677
xp_assert_close(res, xp.asarray(ref))
678678

679679
def test_both_weights(self, xp: ModuleType):
@@ -686,8 +686,8 @@ def test_both_weights(self, xp: ModuleType):
686686
res = cov(
687687
xp.asarray(m),
688688
correction=correction,
689-
frequency_weights=xp.asarray(fw),
690-
weights=xp.asarray(aw),
689+
fweights=xp.asarray(fw),
690+
aweights=xp.asarray(aw),
691691
)
692692
xp_assert_close(res, xp.asarray(ref))
693693

@@ -697,7 +697,7 @@ def test_batch_with_weights(self, xp: ModuleType):
697697
n_var, n_obs = 3, 15
698698
m = rng.random((*batch_shape, n_var, n_obs))
699699
aw = rng.random(n_obs)
700-
res = cov(xp.asarray(m), weights=xp.asarray(aw))
700+
res = cov(xp.asarray(m), aweights=xp.asarray(aw))
701701
ref_list = [np.cov(m_, aweights=aw) for m_ in np.reshape(m, (-1, n_var, n_obs))]
702702
ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var))
703703
xp_assert_close(res, xp.asarray(ref))
@@ -713,8 +713,8 @@ def test_axis_with_weights(self, xp: ModuleType):
713713
res = cov(
714714
xp.asarray(m),
715715
axis=-2,
716-
frequency_weights=xp.asarray(fw),
717-
weights=xp.asarray(aw),
716+
fweights=xp.asarray(fw),
717+
aweights=xp.asarray(aw),
718718
)
719719
xp_assert_close(res, xp.asarray(ref))
720720

0 commit comments

Comments
 (0)