Skip to content

Commit f6a216d

Browse files
committed
Comment
1 parent ab421e9 commit f6a216d

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,12 +392,11 @@ def one_hot(
392392
# specification.
393393
msg = "x must have a concrete size."
394394
raise TypeError(msg)
395-
# TODO: Benchmark whether this is faster on the numpy backend:
395+
# TODO: Benchmark whether this is faster on the Numpy backend:
396396
# x_flattened = xp.reshape(x, (-1,))
397397
# out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
398-
# at(out)[xp.arange(x_size), x_flattened].set(1)
399-
# if x.ndim != 1:
400-
# out = xp.reshape(out, (*x.shape, num_classes))
398+
# out = at(out)[xp.arange(x_size), x_flattened].set(1)
399+
# out = xp.reshape(out, (*x.shape, num_classes))
401400
out = x[..., None] == xp.arange(
402401
num_classes, dtype=x.dtype, device=_compat.device(x)
403402
)

0 commit comments

Comments
 (0)