@@ -197,6 +197,55 @@ def expand_dims(
197197 return _funcs .expand_dims (a , axis = axis , xp = xp )
198198
199199
200+ def atleast_nd (x : Array , / , * , ndim : int , xp : ModuleType | None = None ) -> Array :
201+ """
202+ Recursively expand the dimension of an array to at least `ndim`.
203+
204+ Parameters
205+ ----------
206+ x : array
207+ Input array.
208+ ndim : int
209+ The minimum number of dimensions for the result.
210+ xp : array_namespace, optional
211+ The standard-compatible namespace for `x`. Default: infer.
212+
213+ Returns
214+ -------
215+ array
216+ An array with ``res.ndim`` >= `ndim`.
217+ If ``x.ndim`` >= `ndim`, `x` is returned.
218+ If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
219+ until ``res.ndim`` equals `ndim`.
220+
221+ Examples
222+ --------
223+ >>> import array_api_strict as xp
224+ >>> import array_api_extra as xpx
225+ >>> x = xp.asarray([1])
226+ >>> xpx.atleast_nd(x, ndim=3, xp=xp)
227+ Array([[[1]]], dtype=array_api_strict.int64)
228+
229+ >>> x = xp.asarray([[[1, 2],
230+ ... [3, 4]]])
231+ >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
232+ True
233+ """
234+ if xp is None :
235+ xp = array_namespace (x )
236+
237+ if 1 <= ndim <= 3 and (
238+ is_numpy_namespace (xp )
239+ or is_jax_namespace (xp )
240+ or is_dask_namespace (xp )
241+ or is_cupy_namespace (xp )
242+ or is_torch_namespace (xp )
243+ ):
244+ return getattr (xp , f"atleast_{ ndim } d" )(x )
245+
246+ return _funcs .atleast_nd (x , ndim = ndim , xp = xp )
247+
248+
200249def isclose (
201250 a : Array | complex ,
202251 b : Array | complex ,
0 commit comments