|
8 | 8 |
|
9 | 9 | from ._at import at |
10 | 10 | from ._utils import _compat, _helpers |
11 | | -from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array |
| 11 | +from ._utils._compat import ( |
| 12 | + array_namespace, |
| 13 | + is_dask_namespace, |
| 14 | + is_jax_array, |
| 15 | + is_torch_array, |
| 16 | +) |
12 | 17 | from ._utils._helpers import ( |
13 | 18 | asarrays, |
14 | 19 | capabilities, |
15 | 20 | eager_shape, |
16 | 21 | meta_namespace, |
17 | 22 | ndindex, |
18 | 23 | ) |
19 | | -from ._utils._typing import Array |
| 24 | +from ._utils._typing import Array, DType |
20 | 25 |
|
21 | 26 | __all__ = [ |
22 | 27 | "apply_where", |
@@ -375,6 +380,39 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: |
375 | 380 | return xp.squeeze(c, axis=axes) |
376 | 381 |
|
377 | 382 |
|
| 383 | +def one_hot( |
| 384 | + x: Array, |
| 385 | + /, |
| 386 | + num_classes: int, |
| 387 | + *, |
| 388 | + dtype: DType | None = None, |
| 389 | + axis: int = -1, |
| 390 | + xp: ModuleType | None = None, |
| 391 | +) -> Array: |
| 392 | + if xp is None: |
| 393 | + xp = array_namespace(x) |
| 394 | + if is_jax_array(x): |
| 395 | + from jax.nn import one_hot |
| 396 | + if dtype is None: |
| 397 | + dtype = xp.float_ |
| 398 | + return one_hot(x, num_classes, dtype=dtype, axis=axis) |
| 399 | + if is_torch_array(x): |
| 400 | + from torch.nn.functional import one_hot |
| 401 | + out = one_hot(x, num_classes) |
| 402 | + if dtype is not None: |
| 403 | + out = xp.astype(out, dtype) |
| 404 | + else: |
| 405 | + if dtype is None: |
| 406 | + dtype = xp.float64 |
| 407 | + out = xp.zeros((x.size, num_classes), dtype=dtype) |
| 408 | + at(out)[xp.arange(x.size), xp.reshape(x, -1)].set(1) |
| 409 | + if x.ndim != 1: |
| 410 | + out = xp.reshape(out, (*x.shape, num_classes)) |
| 411 | + if axis != -1: |
| 412 | + out = xp.moveaxis(out, -1, axis) |
| 413 | + return out |
| 414 | + |
| 415 | + |
378 | 416 | def create_diagonal( |
379 | 417 | x: Array, /, *, offset: int = 0, xp: ModuleType | None = None |
380 | 418 | ) -> Array: |
|
0 commit comments