We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ab421e9 commit f6a216dCopy full SHA for f6a216d
1 file changed
src/array_api_extra/_lib/_funcs.py
@@ -392,12 +392,11 @@ def one_hot(
392
# specification.
393
msg = "x must have a concrete size."
394
raise TypeError(msg)
395
- # TODO: Benchmark whether this is faster on the numpy backend:
+ # TODO: Benchmark whether this is faster on the Numpy backend:
396
# x_flattened = xp.reshape(x, (-1,))
397
# out = xp.zeros((x.size, num_classes), dtype=dtype, device=_compat.device(x))
398
- # at(out)[xp.arange(x_size), x_flattened].set(1)
399
- # if x.ndim != 1:
400
- # out = xp.reshape(out, (*x.shape, num_classes))
+ # out = at(out)[xp.arange(x_size), x_flattened].set(1)
+ # out = xp.reshape(out, (*x.shape, num_classes))
401
out = x[..., None] == xp.arange(
402
num_classes, dtype=x.dtype, device=_compat.device(x)
403
)
0 commit comments