Skip to content

Commit 5e72e78

Browse files
committed
ENH: arange default dtype
1 parent 1470505 commit 5e72e78

2 files changed

Lines changed: 18 additions & 2 deletions

File tree

array_api_strict/_creation_functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@ def arange(
127127

128128
_check_device(device)
129129
_check_valid_dtype(dtype, device)
130+
if dtype is None:
131+
if any(isinstance(x, float) for x in (start, stop, step)):
132+
dtype = get_default_dtypes(device)["real floating"]
133+
else:
134+
dtype = get_default_dtypes(device)["integral"]
130135

131136
return Array._new(
132137
np.arange(start, stop, step, dtype=_np_dtype(dtype)),

array_api_strict/tests/test_creation_functions.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ def test_eye(self):
281281
with pytest.raises((TypeError, ValueError)):
282282
eye(3, device=device, dtype=float64)
283283

284-
285284
def test_linspace(self):
286285
device = Device('F32_device')
287286

@@ -294,11 +293,23 @@ def test_linspace(self):
294293
with pytest.raises((TypeError, ValueError)):
295294
linspace(1, 10, 11, device=device, dtype=float64)
296295

296+
def test_arange(self):
297+
device = Device('F32_device')
298+
299+
a = arange(0, 10, 1, device=device)
300+
assert a.dtype == self.info.default_dtypes(device=device)["integral"]
301+
302+
a = arange(0.0, 10, 1, device=device)
303+
assert a.dtype == self.info.default_dtypes(device=device)["real floating"]
304+
305+
with pytest.raises((TypeError, ValueError)):
306+
arange(0, 10, 1, device=device, dtype=float64)
297307

308+
with pytest.raises((TypeError, ValueError)):
309+
arange(0.0, 10, 1, device=device, dtype=float64)
298310

299311
# TODO:
300312
# def asarray(
301-
# def arange(
302313

303314

304315
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])

0 commit comments

Comments
 (0)