|
| 1 | +""" |
| 2 | +Callback creation utilities for UNU.RAN sampler bindings. |
| 3 | +
|
| 4 | +Provides callback function creation for PDF/PMF evaluation needed during |
| 5 | +UNU.RAN distribution setup. |
| 6 | +""" |
| 7 | + |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +__author__ = "Artem Romanyuk" |
| 11 | +__copyright__ = "Copyright (c) 2025 PySATL project" |
| 12 | +__license__ = "SPDX-License-Identifier: MIT" |
| 13 | + |
| 14 | +from collections.abc import Mapping |
| 15 | +from typing import TYPE_CHECKING, Any |
| 16 | + |
| 17 | +import numpy as np |
| 18 | + |
| 19 | +from pysatl_core.types import CharacteristicName, Kind |
| 20 | + |
| 21 | +if TYPE_CHECKING: |
| 22 | + from pysatl_core.types import Method |
| 23 | + |
| 24 | +_CONT_SIG = "double(double, const struct unur_distr*)" |
| 25 | +_DISCR_SIG = "double(int, const struct unur_distr*)" |
| 26 | + |
| 27 | + |
| 28 | +class UnuranCallback: |
| 29 | + """ |
| 30 | + Factory and registry for CFFI callbacks wired into a UNU.RAN distribution object. |
| 31 | +
|
| 32 | + Creates closures over distribution characteristic functions (PDF, dPDF, CDF, |
| 33 | + PPF, PMF) and registers them with the corresponding UNU.RAN setter. All live |
| 34 | + callback objects are kept in an internal list so the garbage collector cannot |
| 35 | + reclaim them while the UNU.RAN generator is alive. |
| 36 | + """ |
| 37 | + |
| 38 | + def __init__( |
| 39 | + self, |
| 40 | + unuran_distr: Any, |
| 41 | + kind: Kind, |
| 42 | + lib: Any, |
| 43 | + ffi: Any, |
| 44 | + characteristics: Mapping[CharacteristicName, Method[Any, Any]], |
| 45 | + ) -> None: |
| 46 | + """ |
| 47 | + Parameters |
| 48 | + ---------- |
| 49 | + unuran_distr : Any |
| 50 | + CFFI pointer to the UNU.RAN distribution object. |
| 51 | + kind : Kind |
| 52 | + Whether the distribution is continuous or discrete. |
| 53 | + lib : Any |
| 54 | + CFFI library handle. |
| 55 | + ffi : Any |
| 56 | + CFFI FFI instance used to create callbacks. |
| 57 | + characteristics : Mapping[CharacteristicName, Method[Any, Any]] |
| 58 | + Map of available characteristic names to their callables. |
| 59 | + """ |
| 60 | + self._unuran_distr = unuran_distr |
| 61 | + self._kind = kind |
| 62 | + self._lib = lib |
| 63 | + self._ffi = ffi |
| 64 | + self._characteristics = dict(characteristics) |
| 65 | + self._callbacks: list[Any] = [] |
| 66 | + |
| 67 | + @property |
| 68 | + def callbacks(self) -> list[Any]: |
| 69 | + """Live CFFI callback objects that must remain referenced for UNU.RAN.""" |
| 70 | + return self._callbacks |
| 71 | + |
| 72 | + def _create_callback(self, char_name: CharacteristicName, signature: str) -> Any | None: |
| 73 | + """ |
| 74 | + Create a CFFI callback for the given characteristic and UNU.RAN signature. |
| 75 | +
|
| 76 | + The wrapper converts the scalar input to a 1-D ``float64`` numpy array |
| 77 | + before calling the characteristic, then converts the result to ``float``. |
| 78 | + This ensures compatibility with characteristics that use numpy array |
| 79 | + methods internally, while satisfying UNU.RAN's ``double`` return type. |
| 80 | +
|
| 81 | + Parameters |
| 82 | + ---------- |
| 83 | + char_name: |
| 84 | + Characteristic to look up in the distribution. |
| 85 | + signature: |
| 86 | + CFFI function type string, e.g. ``"double(double, const struct unur_distr*)"``. |
| 87 | +
|
| 88 | + Returns |
| 89 | + ------- |
| 90 | + CFFI callback or None |
| 91 | + None if the characteristic is not available. |
| 92 | + """ |
| 93 | + func = self._characteristics.get(char_name) |
| 94 | + if func is None: |
| 95 | + return None |
| 96 | + |
| 97 | + def cb(x: Any, _: Any) -> float: |
| 98 | + return float(func(np.asarray(x, dtype=float))) |
| 99 | + |
| 100 | + return self._ffi.callback(signature, cb) |
| 101 | + |
| 102 | + def setup_callback(self, func: Any | None, cffi_func: Any, error_text: str) -> None: |
| 103 | + """ |
| 104 | + Register a single CFFI callback with its UNU.RAN setter. |
| 105 | +
|
| 106 | + Parameters |
| 107 | + ---------- |
| 108 | + func : Any or None |
| 109 | + CFFI callback to register. Skipped silently when ``None``. |
| 110 | + cffi_func : Any |
| 111 | + UNU.RAN setter function (e.g. ``unur_distr_cont_set_pdf``). |
| 112 | + error_text : str |
| 113 | + Message template passed to :class:`RuntimeError` on failure; must |
| 114 | + contain one ``{}`` placeholder for the error code. |
| 115 | +
|
| 116 | + Raises |
| 117 | + ------ |
| 118 | + RuntimeError |
| 119 | + If the setter returns a non-zero error code. |
| 120 | + """ |
| 121 | + if func: |
| 122 | + self._callbacks.append(func) |
| 123 | + result = cffi_func(self._unuran_distr, func) |
| 124 | + if result != 0: |
| 125 | + raise RuntimeError(error_text.format(result)) |
| 126 | + |
| 127 | + def setup_continuous_callbacks(self) -> None: |
| 128 | + """ |
| 129 | + Set up callbacks for continuous distributions. |
| 130 | +
|
| 131 | + Configures PDF, dPDF (if available), CDF, and PPF callbacks for the UNURAN |
| 132 | + continuous distribution object. |
| 133 | +
|
| 134 | + Raises |
| 135 | + ------ |
| 136 | + RuntimeError |
| 137 | + If setting any callback fails (non-zero return code). |
| 138 | +
|
| 139 | + Notes |
| 140 | + ----- |
| 141 | + All created callbacks are appended to ``_callbacks`` list to prevent |
| 142 | + garbage collection. Only available callbacks are set (missing |
| 143 | + characteristics are skipped). |
| 144 | + """ |
| 145 | + self.setup_callback( |
| 146 | + self._create_callback(CharacteristicName.PDF, _CONT_SIG), |
| 147 | + self._lib.unur_distr_cont_set_pdf, |
| 148 | + "Failed to set PDF callback (error code: {})", |
| 149 | + ) |
| 150 | + self.setup_callback( |
| 151 | + self._create_callback(CharacteristicName.DPDF, _CONT_SIG), |
| 152 | + self._lib.unur_distr_cont_set_dpdf, |
| 153 | + "Failed to set dPDF callback (error code: {})", |
| 154 | + ) |
| 155 | + self.setup_callback( |
| 156 | + self._create_callback(CharacteristicName.CDF, _CONT_SIG), |
| 157 | + self._lib.unur_distr_cont_set_cdf, |
| 158 | + "Failed to set CDF callback (error code: {})", |
| 159 | + ) |
| 160 | + self.setup_callback( |
| 161 | + self._create_callback(CharacteristicName.PPF, _CONT_SIG), |
| 162 | + self._lib.unur_distr_cont_set_invcdf, |
| 163 | + "Failed to set PPF callback (error code: {})", |
| 164 | + ) |
| 165 | + |
| 166 | + def _create_indexed_pmf_callback(self, points: np.ndarray) -> Any | None: |
| 167 | + """ |
| 168 | + Create a CFFI PMF callback that maps integer indices to support values. |
| 169 | +
|
| 170 | + UNU.RAN will call this with indices ``0, 1, ..., n-1``. The callback |
| 171 | + converts each index to the corresponding support value via ``points[i]`` |
| 172 | + before evaluating the PMF, enabling DGT to work with arbitrary |
| 173 | + (non-integer, sparse) supports. |
| 174 | +
|
| 175 | + Parameters |
| 176 | + ---------- |
| 177 | + points : np.ndarray |
| 178 | + Array of support values; ``points[i]`` is the actual value for index ``i``. |
| 179 | +
|
| 180 | + Returns |
| 181 | + ------- |
| 182 | + CFFI callback or None |
| 183 | + None if PMF is not available. |
| 184 | + """ |
| 185 | + func = self._characteristics.get(CharacteristicName.PMF) |
| 186 | + if func is None: |
| 187 | + return None |
| 188 | + |
| 189 | + def cb(i: Any, _: Any) -> float: |
| 190 | + return float(func(np.asarray(points[i], dtype=float))) |
| 191 | + |
| 192 | + return self._ffi.callback(_DISCR_SIG, cb) |
| 193 | + |
| 194 | + def setup_discrete_callbacks(self, index_remap_points: np.ndarray | None = None) -> None: |
| 195 | + """ |
| 196 | + Set up callbacks for discrete distributions. |
| 197 | +
|
| 198 | + Configures PMF and CDF callbacks for the UNURAN discrete distribution |
| 199 | + object. |
| 200 | +
|
| 201 | + Parameters |
| 202 | + ---------- |
| 203 | + index_remap_points : np.ndarray or None |
| 204 | + When provided, the PMF callback treats its integer argument as an |
| 205 | + index into this array and evaluates the PMF at ``points[i]`` instead |
| 206 | + of at ``i`` directly. Use this when the UNU.RAN domain is set to |
| 207 | + ``[0, n-1]`` (index space) rather than the actual support values. |
| 208 | +
|
| 209 | + Raises |
| 210 | + ------ |
| 211 | + RuntimeError |
| 212 | + If setting any callback fails (non-zero return code). |
| 213 | +
|
| 214 | + Notes |
| 215 | + ----- |
| 216 | + All created callbacks are appended to ``_callbacks`` list to prevent |
| 217 | + garbage collection. Only available callbacks are set (missing |
| 218 | + characteristics are skipped). |
| 219 | + """ |
| 220 | + if index_remap_points is not None: |
| 221 | + pmf_cb = self._create_indexed_pmf_callback(index_remap_points) |
| 222 | + else: |
| 223 | + pmf_cb = self._create_callback(CharacteristicName.PMF, _DISCR_SIG) |
| 224 | + |
| 225 | + self.setup_callback( |
| 226 | + pmf_cb, |
| 227 | + self._lib.unur_distr_discr_set_pmf, |
| 228 | + "Failed to set PMF callback (error code: {})", |
| 229 | + ) |
| 230 | + self.setup_callback( |
| 231 | + self._create_callback(CharacteristicName.CDF, _DISCR_SIG), |
| 232 | + self._lib.unur_distr_discr_set_cdf, |
| 233 | + "Failed to set CDF callback (error code: {})", |
| 234 | + ) |
0 commit comments