Skip to content

Commit 3ec5131

Browse files
committed
Break
1 parent f6a216d commit 3ec5131

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,8 @@ def one_hot(
397397
# out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
398398
# out = at(out)[xp.arange(x_size), x_flattened].set(1)
399399
# out = xp.reshape(out, (*x.shape, num_classes))
400-
out = x[..., None] == xp.arange(
401-
num_classes, dtype=x.dtype, device=_compat.device(x)
402-
)
400+
range_num_classes = xp.arange(num_classes, dtype=x.dtype, device=_compat.device(x))
401+
out = x[..., xp.newaxis] == range_num_classes
403402
return xp.astype(out, dtype)
404403

405404

0 commit comments

Comments
 (0)