We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 13a5507 commit 007a61fCopy full SHA for 007a61f
1 file changed
src/array_api_extra/_lib/_quantile.py
@@ -91,13 +91,9 @@ def quantile(
91
# Move axis back to original position
92
res = xp.moveaxis(res, -1, axis)
93
94
- # Handle keepdims
95
if not keepdims and res.shape[axis] == 1:
96
res = xp.squeeze(res, axis=axis)
97
98
- # For scalar q, ensure we return a scalar result
99
- # if q_is_scalar and hasattr(res, "shape") and res.shape != ():
100
- # res = res[()]
101
if res.ndim == 0:
102
return res[()]
103
return res
@@ -148,6 +144,7 @@ def _quantile_hf(
148
144
jp1 = xp.broadcast_to(jp1, broadcast_shape)
149
145
g = xp.broadcast_to(g, broadcast_shape)
150
146
151
- return (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
147
+ res = (1 - g) * xp.take_along_axis(y, j, axis=-1) + g * xp.take_along_axis(
152
y, jp1, axis=-1
153
)
+ return res
0 commit comments