Skip to content

Commit d5d52a8

Browse files
committed
feat: UNURAN Sampler realization
1 parent db20e75 commit d5d52a8

31 files changed

Lines changed: 4022 additions & 405 deletions

src/pysatl_core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from .distributions import __all__ as _distr_all
1515
from .families import *
1616
from .families import __all__ as _family_all
17+
from .sampling import *
18+
from .sampling import __all__ as _sampling_all
1719
from .types import *
1820
from .types import __all__ as _types_all
1921

@@ -23,8 +25,10 @@
2325
*_distr_all,
2426
*_family_all,
2527
*_types_all,
28+
*_sampling_all,
2629
]
2730

2831
del _distr_all
2932
del _family_all
3033
del _types_all
34+
del _sampling_all

src/pysatl_core/distributions/distribution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from collections.abc import Mapping
2424
from typing import Any
2525

26-
from pysatl_core.distributions.computation import AnalyticalComputation, Method
26+
from pysatl_core.distributions.computation import AnalyticalComputation
2727
from pysatl_core.distributions.strategies import (
2828
ComputationStrategy,
2929
SamplingStrategy,
@@ -32,6 +32,7 @@
3232
from pysatl_core.types import (
3333
DistributionType,
3434
GenericCharacteristicName,
35+
Method,
3536
)
3637

3738

src/pysatl_core/distributions/strategies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
import numpy as np
1717

1818
from pysatl_core.distributions.registry import characteristic_registry
19-
from pysatl_core.types import CharacteristicName, NumericArray
19+
from pysatl_core.types import CharacteristicName, Method, NumericArray
2020

2121
if TYPE_CHECKING:
2222
from typing import Any
2323

24-
from pysatl_core.distributions.computation import FittedComputationMethod, Method
24+
from pysatl_core.distributions.computation import FittedComputationMethod
2525
from pysatl_core.distributions.distribution import Distribution
2626
from pysatl_core.types import GenericCharacteristicName
2727

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
Public sampling interface re-exporting UNURAN-based defaults.
3+
"""
4+
5+
__author__ = "Artem Romanyuk"
6+
__copyright__ = "Copyright (c) 2025 PySATL project"
7+
__license__ = "SPDX-License-Identifier: MIT"
8+
9+
from .unuran import (
10+
DefaultUnuranSampler,
11+
DefaultUnuranSamplingStrategy,
12+
UnuranMethod,
13+
UnuranMethodConfig,
14+
)
15+
16+
SamplingMethod = UnuranMethod
17+
SamplingMethodConfig = UnuranMethodConfig
18+
DefaultSampler = DefaultUnuranSampler
19+
DefaultSamplingStrategy = DefaultUnuranSamplingStrategy
20+
21+
__all__ = [
22+
"SamplingMethod",
23+
"SamplingMethodConfig",
24+
"DefaultSampler",
25+
"DefaultSamplingStrategy",
26+
]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
Expose the UNU.RAN sampling API interfaces alongside their default
3+
implementations backed by our C bindings.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
__author__ = "Artem Romanyuk"
9+
__copyright__ = "Copyright (c) 2025 PySATL project"
10+
__license__ = "SPDX-License-Identifier: MIT"
11+
12+
from .core import (
13+
DefaultUnuranSampler,
14+
DefaultUnuranSamplingStrategy,
15+
)
16+
from .method_config import (
17+
UnuranMethod,
18+
UnuranMethodConfig,
19+
)
20+
21+
__all__ = [
22+
"UnuranMethod",
23+
"UnuranMethodConfig",
24+
"DefaultUnuranSampler",
25+
"DefaultUnuranSamplingStrategy",
26+
]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Core UNU.RAN sampler and sampling strategy implementations.
2+
3+
Exposes :class:`DefaultUnuranSampler`, which wraps the UNU.RAN C library via
4+
CFFI and automatically selects a sampling method (PINV, NINV, DGT, …) based on
5+
the available distribution characteristics, and
6+
:class:`DefaultUnuranSamplingStrategy`, the high-level strategy that integrates
7+
the sampler with the PySATL distribution protocol.
8+
"""
9+
10+
__author__ = "Artem Romanyuk"
11+
__copyright__ = "Copyright (c) 2025 PySATL project"
12+
__license__ = "SPDX-License-Identifier: MIT"
13+
14+
from .unuran_sampler import DefaultUnuranSampler
15+
from .unuran_sampling_strategy import DefaultUnuranSamplingStrategy
16+
17+
__all__ = [
18+
"DefaultUnuranSampler",
19+
"DefaultUnuranSamplingStrategy",
20+
]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
PySATL Core — UNU.RAN Sampler Internals
3+
=========================================
4+
5+
Internal sub-package that wires together the two building blocks of the
6+
UNU.RAN sampler:
7+
8+
- ``UnuranSamplerInitializer`` — creates the UNU.RAN distribution object,
9+
registers distribution callbacks (PDF, CDF, PMF, …) and initialises
10+
the generator for a pre-selected sampling method (PINV, NINV, DGT, …).
11+
- ``ensure_default_urng`` — registers a NumPy-backed uniform RNG as the
12+
UNU.RAN default URNG so that seeding and reproducibility work through
13+
the standard NumPy interface.
14+
"""
15+
16+
from __future__ import annotations
17+
18+
__author__ = "Artem Romanyuk"
19+
__copyright__ = "Copyright (c) 2025 PySATL project"
20+
__license__ = "SPDX-License-Identifier: MIT"
21+
22+
from .initialization import UnuranSamplerInitializer
23+
from .urng import ensure_default_urng
24+
25+
__all__ = [
26+
"UnuranSamplerInitializer",
27+
"ensure_default_urng",
28+
]
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
"""
2+
Callback creation utilities for UNU.RAN sampler bindings.
3+
4+
Provides callback function creation for PDF/PMF evaluation needed during
5+
UNU.RAN distribution setup.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
__author__ = "Artem Romanyuk"
11+
__copyright__ = "Copyright (c) 2025 PySATL project"
12+
__license__ = "SPDX-License-Identifier: MIT"
13+
14+
from collections.abc import Mapping
15+
from typing import TYPE_CHECKING, Any
16+
17+
import numpy as np
18+
19+
from pysatl_core.types import CharacteristicName, Kind
20+
21+
if TYPE_CHECKING:
22+
from pysatl_core.types import Method
23+
24+
_CONT_SIG = "double(double, const struct unur_distr*)"
25+
_DISCR_SIG = "double(int, const struct unur_distr*)"
26+
27+
28+
class UnuranCallback:
29+
"""
30+
Factory and registry for CFFI callbacks wired into a UNU.RAN distribution object.
31+
32+
Creates closures over distribution characteristic functions (PDF, dPDF, CDF,
33+
PPF, PMF) and registers them with the corresponding UNU.RAN setter. All live
34+
callback objects are kept in an internal list so the garbage collector cannot
35+
reclaim them while the UNU.RAN generator is alive.
36+
"""
37+
38+
def __init__(
39+
self,
40+
unuran_distr: Any,
41+
kind: Kind,
42+
lib: Any,
43+
ffi: Any,
44+
characteristics: Mapping[CharacteristicName, Method[Any, Any]],
45+
) -> None:
46+
"""
47+
Parameters
48+
----------
49+
unuran_distr : Any
50+
CFFI pointer to the UNU.RAN distribution object.
51+
kind : Kind
52+
Whether the distribution is continuous or discrete.
53+
lib : Any
54+
CFFI library handle.
55+
ffi : Any
56+
CFFI FFI instance used to create callbacks.
57+
characteristics : Mapping[CharacteristicName, Method[Any, Any]]
58+
Map of available characteristic names to their callables.
59+
"""
60+
self._unuran_distr = unuran_distr
61+
self._kind = kind
62+
self._lib = lib
63+
self._ffi = ffi
64+
self._characteristics = dict(characteristics)
65+
self._callbacks: list[Any] = []
66+
67+
@property
68+
def callbacks(self) -> list[Any]:
69+
"""Live CFFI callback objects that must remain referenced for UNU.RAN."""
70+
return self._callbacks
71+
72+
def _create_callback(self, char_name: CharacteristicName, signature: str) -> Any | None:
73+
"""
74+
Create a CFFI callback for the given characteristic and UNU.RAN signature.
75+
76+
The wrapper converts the scalar input to a 1-D ``float64`` numpy array
77+
before calling the characteristic, then converts the result to ``float``.
78+
This ensures compatibility with characteristics that use numpy array
79+
methods internally, while satisfying UNU.RAN's ``double`` return type.
80+
81+
Parameters
82+
----------
83+
char_name:
84+
Characteristic to look up in the distribution.
85+
signature:
86+
CFFI function type string, e.g. ``"double(double, const struct unur_distr*)"``.
87+
88+
Returns
89+
-------
90+
CFFI callback or None
91+
None if the characteristic is not available.
92+
"""
93+
func = self._characteristics.get(char_name)
94+
if func is None:
95+
return None
96+
97+
def cb(x: Any, _: Any) -> float:
98+
return float(func(np.asarray(x, dtype=float)))
99+
100+
return self._ffi.callback(signature, cb)
101+
102+
def setup_callback(self, func: Any | None, cffi_func: Any, error_text: str) -> None:
103+
"""
104+
Register a single CFFI callback with its UNU.RAN setter.
105+
106+
Parameters
107+
----------
108+
func : Any or None
109+
CFFI callback to register. Skipped silently when ``None``.
110+
cffi_func : Any
111+
UNU.RAN setter function (e.g. ``unur_distr_cont_set_pdf``).
112+
error_text : str
113+
Message template passed to :class:`RuntimeError` on failure; must
114+
contain one ``{}`` placeholder for the error code.
115+
116+
Raises
117+
------
118+
RuntimeError
119+
If the setter returns a non-zero error code.
120+
"""
121+
if func:
122+
self._callbacks.append(func)
123+
result = cffi_func(self._unuran_distr, func)
124+
if result != 0:
125+
raise RuntimeError(error_text.format(result))
126+
127+
def setup_continuous_callbacks(self) -> None:
128+
"""
129+
Set up callbacks for continuous distributions.
130+
131+
Configures PDF, dPDF (if available), CDF, and PPF callbacks for the UNURAN
132+
continuous distribution object.
133+
134+
Raises
135+
------
136+
RuntimeError
137+
If setting any callback fails (non-zero return code).
138+
139+
Notes
140+
-----
141+
All created callbacks are appended to ``_callbacks`` list to prevent
142+
garbage collection. Only available callbacks are set (missing
143+
characteristics are skipped).
144+
"""
145+
self.setup_callback(
146+
self._create_callback(CharacteristicName.PDF, _CONT_SIG),
147+
self._lib.unur_distr_cont_set_pdf,
148+
"Failed to set PDF callback (error code: {})",
149+
)
150+
self.setup_callback(
151+
self._create_callback(CharacteristicName.DPDF, _CONT_SIG),
152+
self._lib.unur_distr_cont_set_dpdf,
153+
"Failed to set dPDF callback (error code: {})",
154+
)
155+
self.setup_callback(
156+
self._create_callback(CharacteristicName.CDF, _CONT_SIG),
157+
self._lib.unur_distr_cont_set_cdf,
158+
"Failed to set CDF callback (error code: {})",
159+
)
160+
self.setup_callback(
161+
self._create_callback(CharacteristicName.PPF, _CONT_SIG),
162+
self._lib.unur_distr_cont_set_invcdf,
163+
"Failed to set PPF callback (error code: {})",
164+
)
165+
166+
def _create_indexed_pmf_callback(self, points: np.ndarray) -> Any | None:
167+
"""
168+
Create a CFFI PMF callback that maps integer indices to support values.
169+
170+
UNU.RAN will call this with indices ``0, 1, ..., n-1``. The callback
171+
converts each index to the corresponding support value via ``points[i]``
172+
before evaluating the PMF, enabling DGT to work with arbitrary
173+
(non-integer, sparse) supports.
174+
175+
Parameters
176+
----------
177+
points : np.ndarray
178+
Array of support values; ``points[i]`` is the actual value for index ``i``.
179+
180+
Returns
181+
-------
182+
CFFI callback or None
183+
None if PMF is not available.
184+
"""
185+
func = self._characteristics.get(CharacteristicName.PMF)
186+
if func is None:
187+
return None
188+
189+
def cb(i: Any, _: Any) -> float:
190+
return float(func(np.asarray(points[i], dtype=float)))
191+
192+
return self._ffi.callback(_DISCR_SIG, cb)
193+
194+
def setup_discrete_callbacks(self, index_remap_points: np.ndarray | None = None) -> None:
195+
"""
196+
Set up callbacks for discrete distributions.
197+
198+
Configures PMF and CDF callbacks for the UNURAN discrete distribution
199+
object.
200+
201+
Parameters
202+
----------
203+
index_remap_points : np.ndarray or None
204+
When provided, the PMF callback treats its integer argument as an
205+
index into this array and evaluates the PMF at ``points[i]`` instead
206+
of at ``i`` directly. Use this when the UNU.RAN domain is set to
207+
``[0, n-1]`` (index space) rather than the actual support values.
208+
209+
Raises
210+
------
211+
RuntimeError
212+
If setting any callback fails (non-zero return code).
213+
214+
Notes
215+
-----
216+
All created callbacks are appended to ``_callbacks`` list to prevent
217+
garbage collection. Only available callbacks are set (missing
218+
characteristics are skipped).
219+
"""
220+
if index_remap_points is not None:
221+
pmf_cb = self._create_indexed_pmf_callback(index_remap_points)
222+
else:
223+
pmf_cb = self._create_callback(CharacteristicName.PMF, _DISCR_SIG)
224+
225+
self.setup_callback(
226+
pmf_cb,
227+
self._lib.unur_distr_discr_set_pmf,
228+
"Failed to set PMF callback (error code: {})",
229+
)
230+
self.setup_callback(
231+
self._create_callback(CharacteristicName.CDF, _DISCR_SIG),
232+
self._lib.unur_distr_discr_set_cdf,
233+
"Failed to set CDF callback (error code: {})",
234+
)

0 commit comments

Comments
 (0)