|
10 | 10 | from devito.tools import (UnboundedMultiTuple, ctypes_to_cstr, toposort, |
11 | 11 | filter_ordered, transitive_closure, UnboundTuple) |
12 | 12 | from devito.tools.abc import WeakValueCache |
| 13 | +from devito.tools.memoization import _memoized_instances |
13 | 14 | from devito.types.basic import Symbol |
14 | 15 |
|
15 | 16 |
|
@@ -222,6 +223,7 @@ def test_safety(self) -> None: |
222 | 223 | barrier = Barrier(num_threads) |
223 | 224 |
|
224 | 225 | def worker(_: int) -> CacheObject: |
| 226 | + # Wait until all threads can try to access the cache at once |
225 | 227 | barrier.wait() |
226 | 228 | return cache.get_or_create(1, CacheObject) |
227 | 229 |
|
@@ -323,3 +325,126 @@ def worker(index: int): |
323 | 325 | # Ensure all threads received the exception |
324 | 326 | for exc in exceptions: |
325 | 327 | assert isinstance(exc, SupplierException) |
| 328 | + |
| 329 | + |
| 330 | +class TestMemoizedInstances: |
| 331 | + """ |
| 332 | + Tests for the `memoized_instances` decorator. |
| 333 | + """ |
| 334 | + |
| 335 | + def test_memo(self): |
| 336 | + """ |
| 337 | + Tests basic functionality of memoized instances. |
| 338 | + """ |
| 339 | + @_memoized_instances |
| 340 | + class Box: |
| 341 | + def __init__(self, value: int): |
| 342 | + self.value = value |
| 343 | + self.init_calls = getattr(self, 'init_calls', 0) + 1 |
| 344 | + |
| 345 | + # Create instances with the same value |
| 346 | + box1 = Box(10) |
| 347 | + box2 = Box(10) |
| 348 | + box3 = Box(20) |
| 349 | + |
| 350 | + # Ensure they are the same instance |
| 351 | + assert box1.value == 10 |
| 352 | + assert box1 is box2 |
| 353 | + |
| 354 | + # Ensure initialization only happened once |
| 355 | + assert box1.init_calls == 1 |
| 356 | + |
| 357 | + # Ensure different values create different instances |
| 358 | + assert box1 is not box3 |
| 359 | + assert box3.init_calls == 1 |
| 360 | + |
| 361 | + def test_memo_with_new(self): |
| 362 | + """ |
| 363 | + Tests that `memoized_instances` works correctly with `__new__`. |
| 364 | + """ |
| 365 | + @_memoized_instances |
| 366 | + class BoxWithNew: |
| 367 | + def __new__(cls, value: int): |
| 368 | + instance = super().__new__(cls) |
| 369 | + instance.value = value |
| 370 | + return instance |
| 371 | + |
| 372 | + def __init__(self, value: int): |
| 373 | + self.value += value |
| 374 | + self.init_calls = getattr(self, 'init_calls', 0) + 1 |
| 375 | + |
| 376 | + # Create instances with the same value |
| 377 | + box1 = BoxWithNew(10) |
| 378 | + box2 = BoxWithNew(10) |
| 379 | + box3 = BoxWithNew(20) |
| 380 | + |
| 381 | + # Ensure they are the same instance |
| 382 | + assert box1.value == 20 |
| 383 | + assert box1 is box2 |
| 384 | + |
| 385 | + # Ensure initialization only happened once |
| 386 | + assert box1.init_calls == 1 |
| 387 | + |
| 388 | + # Ensure different values create different instances |
| 389 | + assert box1 is not box3 |
| 390 | + assert box3.init_calls == 1 |
| 391 | + |
| 392 | + def test_idempotency(self): |
| 393 | + """ |
| 394 | + Tests that applying the decorator multiple times in an inheritance chain |
| 395 | + does not change the behavior. |
| 396 | + """ |
| 397 | + @_memoized_instances |
| 398 | + @_memoized_instances |
| 399 | + class Box: |
| 400 | + def __init__(self, value: int): |
| 401 | + self.value = value |
| 402 | + self.init_calls = getattr(self, 'init_calls', 0) + 1 |
| 403 | + |
| 404 | + @_memoized_instances |
| 405 | + class SubBox(Box): |
| 406 | + def __init__(self, value: int): |
| 407 | + super().__init__(value) |
| 408 | + self.sub_init_calls = getattr(self, 'sub_init_calls', 0) + 1 |
| 409 | + |
| 410 | + # Create instances with the same value |
| 411 | + box = Box(10) |
| 412 | + subbox1 = SubBox(10) |
| 413 | + subbox2 = SubBox(10) |
| 414 | + subbox3 = SubBox(20) |
| 415 | + |
| 416 | + # Ensure the subclass instances are not the same as the base class |
| 417 | + assert box is not subbox1 |
| 418 | + assert subbox1 is subbox2 |
| 419 | + assert subbox1 is not subbox3 |
| 420 | + |
| 421 | + def test_constructed_elsewhere(self): |
| 422 | + """ |
| 423 | + Tests that instances somehow constructed without the replaced new function |
| 424 | + are still initialized correctly (edge case). |
| 425 | + """ |
| 426 | + class Box: |
| 427 | + def __new__(cls, _: int): |
| 428 | + return super().__new__(cls) |
| 429 | + |
| 430 | + def __init__(self, value: int): |
| 431 | + self.value = value |
| 432 | + self.init_calls = getattr(self, 'init_calls', 0) + 1 |
| 433 | + |
| 434 | + # Store the original __new__ method |
| 435 | + original_new = Box.__new__ |
| 436 | + Box = _memoized_instances(Box) |
| 437 | + |
| 438 | + # Create a cached instance as normal |
| 439 | + box1 = Box(10) |
| 440 | + |
| 441 | + # Restore the original __new__ method and construct a new instance |
| 442 | + Box.__new__ = original_new |
| 443 | + box2 = Box(10) |
| 444 | + |
| 445 | + # Ensure the new instance is initialized correctly |
| 446 | + assert box2.value == 10 |
| 447 | + assert box2.init_calls == 1 |
| 448 | + |
| 449 | + # Ensure the instances are not the same |
| 450 | + assert box1 is not box2 |
0 commit comments