@@ -75,13 +75,12 @@ def quantile(
7575
7676 n = xp .asarray (y .shape [- 1 ], dtype = dtype , device = _compat .device (y ))
7777
78- res = _quantile_hf (y , q_arr , n , method , xp )
78+ # Validate that q values are in the range [0, 1]
79+ if xp .any ((q_arr < 0 ) | (q_arr > 1 )):
80+ msg = "`q` must contain values between 0 and 1 inclusive."
81+ raise ValueError (msg )
7982
80- # Handle NaN output for invalid q values
81- p_mask = (q_arr > 1 ) | (q_arr < 0 ) | xp .isnan (q_arr )
82- if xp .any (p_mask ):
83- res = xp .asarray (res , copy = True )
84- res = at (res , p_mask ).set (xp .nan )
83+ res = _quantile_hf (y , q_arr , n , method , xp )
8584
8685 # Reshape per axis/keepdims
8786 if axis_none and keepdims :
@@ -97,9 +96,10 @@ def quantile(
9796 res = xp .squeeze (res , axis = axis )
9897
9998 # For scalar q, ensure we return a scalar result
100- if q_is_scalar and hasattr (res , "shape" ) and res .shape != ():
101- res = res [()]
102-
99+ # if q_is_scalar and hasattr(res, "shape") and res.shape != ():
100+ # res = res[()]
101+ if res .ndim == 0 :
102+ return res [()]
103103 return res
104104
105105
@@ -121,7 +121,10 @@ def _quantile_hf(
121121 m = ms [method ]
122122
123123 jg = p * n + m - 1
124- j = xp .astype (jg // 1 , xp .int64 ) # Convert to integer
124+ # Convert both to integers, the type of j and n must be the same
125+ # for us to be able to `xp.clip` them.
126+ j = xp .astype (jg // 1 , xp .int64 )
127+ n = xp .astype (n , xp .int64 )
125128 g = jg % 1
126129
127130 if method == "inverted_cdf" :
0 commit comments