|
7 | 7 |
|
8 | 8 | from __future__ import annotations |
9 | 9 |
|
| 10 | +import dataclasses |
| 11 | + |
| 12 | +from pysatl_core.families.registry_graph import BinaryOperationType |
| 13 | + |
10 | 14 | __author__ = "Leonid Elkin, Mikhail Mikhailov" |
11 | 15 | __copyright__ = "Copyright (c) 2025 PySATL project" |
12 | 16 | __license__ = "SPDX-License-Identifier: MIT" |
13 | 17 |
|
| 18 | +from types import NotImplementedType |
14 | 19 | from typing import TYPE_CHECKING |
15 | 20 |
|
16 | 21 | from pysatl_core.distributions.distribution import _KEEP, Distribution |
@@ -193,3 +198,112 @@ def sample(self, n: int, **options: Any) -> NumericArray: |
193 | 198 | the sampling strategy. |
194 | 199 | """ |
195 | 200 | return self.sampling_strategy.sample(n, distr=self, **options) |
| 201 | + |
| 202 | + def _try_to_transform_with_optimization( |
| 203 | + self, other: ParametricFamilyDistribution, kind: BinaryOperationType |
| 204 | + ) -> None | ParametricFamilyDistribution: |
| 205 | + registry = ParametricFamilyRegister() |
| 206 | + transform_result = registry.find_binary_transformation( |
| 207 | + self.family_name, other.family_name, kind |
| 208 | + ) |
| 209 | + if transform_result is None: |
| 210 | + return None |
| 211 | + |
| 212 | + family, transform_parametrization = transform_result |
| 213 | + new_parametrization = transform_parametrization( |
| 214 | + self.parametrization.transform_to_base_parametrization(), |
| 215 | + other.parametrization.transform_to_base_parametrization(), |
| 216 | + ) |
| 217 | + return family(new_parametrization.name, **dataclasses.asdict(new_parametrization)) # type:ignore[call-overload] |
| 218 | + |
| 219 | + def __add__( |
| 220 | + self, other: object |
| 221 | + ) -> ParametricFamilyDistribution | Distribution | NotImplementedType: |
| 222 | + """Return ``self + other`` for scalar or distribution operands.""" |
| 223 | + if isinstance(other, ParametricFamilyDistribution): |
| 224 | + transformation_result = self._try_to_transform_with_optimization( |
| 225 | + other, BinaryOperationType.ADD |
| 226 | + ) |
| 227 | + if transformation_result is not None: |
| 228 | + return transformation_result |
| 229 | + |
| 230 | + return TransformationOperatorsMixin.__add__(self, other) |
| 231 | + |
| 232 | + def __radd__( |
| 233 | + self, other: object |
| 234 | + ) -> ParametricFamilyDistribution | Distribution | NotImplementedType: |
| 235 | + """Return ``other + self`` for scalar or distribution operands.""" |
| 236 | + if isinstance(other, ParametricFamilyDistribution): |
| 237 | + transformation_result = other._try_to_transform_with_optimization( |
| 238 | + self, BinaryOperationType.ADD |
| 239 | + ) |
| 240 | + if transformation_result is not None: |
| 241 | + return transformation_result |
| 242 | + |
| 243 | + return TransformationOperatorsMixin.__radd__(self, other) |
| 244 | + |
| 245 | + def __sub__(self, other: object) -> Distribution | NotImplementedType: |
| 246 | + """Return ``self - other`` for scalar or distribution operands.""" |
| 247 | + if isinstance(other, ParametricFamilyDistribution): |
| 248 | + transformation_result = self._try_to_transform_with_optimization( |
| 249 | + other, BinaryOperationType.SUB |
| 250 | + ) |
| 251 | + if transformation_result is not None: |
| 252 | + return transformation_result |
| 253 | + |
| 254 | + return TransformationOperatorsMixin.__sub__(self, other) |
| 255 | + |
| 256 | + def __rsub__(self, other: object) -> Distribution | NotImplementedType: |
| 257 | + """Return ``other - self`` for scalar or distribution operands.""" |
| 258 | + if isinstance(other, ParametricFamilyDistribution): |
| 259 | + transformation_result = other._try_to_transform_with_optimization( |
| 260 | + self, BinaryOperationType.SUB |
| 261 | + ) |
| 262 | + if transformation_result is not None: |
| 263 | + return transformation_result |
| 264 | + |
| 265 | + return TransformationOperatorsMixin.__rsub__(self, other) |
| 266 | + |
| 267 | + def __mul__(self, other: object) -> Distribution | NotImplementedType: |
| 268 | + """Return ``self * other`` for scalar or distribution operands.""" |
| 269 | + if isinstance(other, ParametricFamilyDistribution): |
| 270 | + transformation_result = self._try_to_transform_with_optimization( |
| 271 | + other, BinaryOperationType.MUL |
| 272 | + ) |
| 273 | + if transformation_result is not None: |
| 274 | + return transformation_result |
| 275 | + |
| 276 | + return TransformationOperatorsMixin.__mul__(self, other) |
| 277 | + |
| 278 | + def __rmul__(self, other: object) -> Distribution | NotImplementedType: |
| 279 | + """Return ``other * self`` for scalar or distribution operands.""" |
| 280 | + if isinstance(other, ParametricFamilyDistribution): |
| 281 | + transformation_result = other._try_to_transform_with_optimization( |
| 282 | + self, BinaryOperationType.MUL |
| 283 | + ) |
| 284 | + if transformation_result is not None: |
| 285 | + return transformation_result |
| 286 | + |
| 287 | + return TransformationOperatorsMixin.__rmul__(self, other) |
| 288 | + |
| 289 | + def __truediv__(self, other: object) -> Distribution | NotImplementedType: |
| 290 | + """Return ``self / other`` for scalar or distribution operands.""" |
| 291 | + if isinstance(other, ParametricFamilyDistribution): |
| 292 | + transformation_result = self._try_to_transform_with_optimization( |
| 293 | + other, BinaryOperationType.DIV |
| 294 | + ) |
| 295 | + if transformation_result is not None: |
| 296 | + return transformation_result |
| 297 | + |
| 298 | + return TransformationOperatorsMixin.__truediv__(self, other) |
| 299 | + |
| 300 | + def __rtruediv__(self, other: object) -> Distribution | NotImplementedType: |
| 301 | + """Return ``other / self`` for distribution operands.""" |
| 302 | + if isinstance(other, ParametricFamilyDistribution): |
| 303 | + transformation_result = other._try_to_transform_with_optimization( |
| 304 | + self, BinaryOperationType.DIV |
| 305 | + ) |
| 306 | + if transformation_result is not None: |
| 307 | + return transformation_result |
| 308 | + |
| 309 | + return TransformationOperatorsMixin.__rtruediv__(self, other) |
0 commit comments