Skip to content

Commit a3dcce9

Browse files
committed
Add Generator
1 parent 7b4887e commit a3dcce9

2 files changed

Lines changed: 159 additions & 0 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
sinc,
1717
)
1818
from ._lib._lazy import lazy_apply
19+
from ._random import Generator, JaxGenerator, TorchGenerator
1920

2021
__version__ = "0.8.1.dev0"
2122

2223
# pylint: disable=duplicate-code
2324
__all__ = [
25+
"Generator",
26+
"JaxGenerator",
27+
"TorchGenerator",
2428
"__version__",
2529
"apply_where",
2630
"at",

src/array_api_extra/_random.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
from types import Any, ModuleType
2+
from typing import TYPE_CHECKING
3+
4+
from ._lib._utils._compat import (
5+
is_jax_namespace,
6+
is_torch_namespace,
7+
)
8+
from ._lib._utils._typing import Array, Device, DType
9+
10+
if TYPE_CHECKING:
11+
import jax
12+
import torch
13+
14+
15+
class Generator:
16+
@classmethod
17+
def create(cls, seed: int, device: Device | None = None) -> "Generator":
18+
raise NotImplementedError
19+
20+
def get_state(self) -> Any:
21+
raise NotImplementedError
22+
23+
def set_state(self, state: object):
24+
raise NotImplementedError
25+
26+
def uniform(
27+
self,
28+
shape: tuple[int, ...] = (),
29+
dtype: DType | None = None,
30+
minval: float | Array = 0.0,
31+
maxval: float | Array = 1.0,
32+
) -> Array:
33+
raise NotImplementedError
34+
35+
36+
class JaxGenerator(Generator):
37+
def __init__(self, key: Array, count: Array | None = None) -> None:
38+
super().__init__()
39+
import jax
40+
import jax.numpy as jnp
41+
42+
if count is None:
43+
count = jnp.zeros((), dtype=jnp.uint32)
44+
else:
45+
assert isinstance(count, jax.Array)
46+
assert count.ndim == 0
47+
assert isinstance(key, jax.Array)
48+
self._key = key
49+
self._count = count
50+
51+
@classmethod
52+
def create(cls, seed: int, device: Device | None = None) -> "JaxGenerator":
53+
import jax.random as jr
54+
55+
key = jr.key(seed).to_device(device)
56+
return JaxGenerator(key)
57+
58+
def get_state(self) -> Any:
59+
import jax.random as jr
60+
61+
return (jr.key_data(self._key), self._count)
62+
63+
def set_state(self, state: object):
64+
import jax
65+
import jax.random as jr
66+
67+
assert isinstance(state, tuple)
68+
key_data, count = state
69+
assert isinstance(key_data, jax.Array)
70+
assert isinstance(count, int)
71+
self._key = jr.wrap_key_data(key_data)
72+
self._count = count
73+
74+
def key(self) -> jax.Array:
75+
"""This should be passed to traced functions instead of the generator."""
76+
import jax.random as jr
77+
78+
key = jr.fold_in(self._key, self._count)
79+
self._count += 1
80+
return key
81+
82+
def fork(self, samples: int) -> Array:
83+
"""This should be passed to vmapped functions instead of the generator."""
84+
import jax.random as jr
85+
86+
return jr.split(self.key(), samples)
87+
88+
def uniform(
89+
self,
90+
shape: tuple[int, ...] = (),
91+
dtype: DType | None = None,
92+
minval: float | Array = 0.0,
93+
maxval: float | Array = 1.0,
94+
) -> Array:
95+
import jax
96+
import jax.random as jr
97+
98+
if dtype is None:
99+
dtype = float
100+
assert isinstance(minval, float | jax.Array)
101+
assert isinstance(maxval, float | jax.Array)
102+
return jr.uniform(self.key(), shape, dtype, minval, maxval)
103+
104+
105+
class TorchGenerator(Generator):
106+
def __init__(self, generator: "torch.Generator") -> None:
107+
super().__init__()
108+
self._generator = generator
109+
110+
@classmethod
111+
def create(cls, seed: int, device: Device | None = None) -> "TorchGenerator":
112+
import torch
113+
114+
device = "cpu" if device is None else device
115+
generator = torch.Generator(device)
116+
generator = generator.manual_seed(seed)
117+
return TorchGenerator(generator)
118+
119+
def get_state(self) -> Any:
120+
return self._generator.get_state()
121+
122+
def set_state(self, state: object):
123+
import torch
124+
assert isinstance(state, torch.Tensor)
125+
self._generator.set_state(state)
126+
127+
def uniform(
128+
self,
129+
shape: tuple[int, ...] = (),
130+
dtype: DType | None = None,
131+
minval: float | Array = 0.0,
132+
maxval: float | Array = 1.0,
133+
) -> Array:
134+
import torch
135+
136+
u = torch.rand(*shape, generator=self._generator, dtype=dtype)
137+
return u * (maxval - minval) + minval
138+
139+
140+
def create_generator(
141+
xp: ModuleType,
142+
seed: int,
143+
*,
144+
device: Device | None = None,
145+
) -> Generator:
146+
cls = (
147+
JaxGenerator
148+
if is_jax_namespace(xp)
149+
else TorchGenerator
150+
if is_torch_namespace(xp)
151+
else None
152+
)
153+
if cls is None:
154+
raise TypeError
155+
return cls.create(seed, device)

0 commit comments

Comments
 (0)