Skip to content

Commit 03e88bd

Browse files
authored
fix(classes): make RemoteClassWrapper and LRUCache picklable (#306)
* fix(AE-2745): make RemoteClassWrapper and LRUCache picklable

  asyncio.Lock and threading.RLock are not picklable. When ResourceManager._save_resources() calls cloudpickle.dump(), the serialization fails silently, causing duplicate endpoints on re-deploy.

  - Add __getstate__/__setstate__ to RemoteClassWrapper, excluding _init_lock and _stub and resetting _initialized on restore
  - Add __getstate__/__setstate__ to LRUCache, excluding _lock
  - Add 5 tests covering pickle roundtrip and state exclusion

* fix(AE-2745): repopulate serialization cache on unpickle, harden style

  - __setstate__ now calls get_or_cache_class_data() to repopulate the serialization cache after unpickle, preventing method_proxy from crashing on a cache miss in a new process
  - LRUCache.__getstate__ uses pop() instead of del for consistency with RemoteClassWrapper's defensive style
  - Add test_pickle_repopulates_serialization_cache covering the cache-miss scenario after a cloudpickle roundtrip

* fix(AE-2745): fix 3.11 CI, address review feedback

  - Revert __setstate__ cache repopulation: adding _SERIALIZED_CLASS_CACHE as a direct global reference forced cloudpickle to serialize the LRUCache instance as part of the class globals, hitting a weakref on 3.11. Not needed: cloudpickle preserves the cache in the class globals copy.
  - LRUCache.__getstate__ acquires _lock for thread-safe pickling
  - Use _class_type.__name__ instead of identity checks in tests, since cloudpickle rehydrates non-importable classes by value
  - Replace the cache repopulation test with a metadata preservation test
1 parent 02058db commit 03e88bd

4 files changed

Lines changed: 159 additions & 1 deletion

File tree

src/runpod_flash/core/utils/lru_cache.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,13 @@ def __getitem__(self, key: str) -> Dict[str, Any]:
7373
def __setitem__(self, key: str, value: Dict[str, Any]) -> None:
    """Support ``cache[key] = value`` by delegating to :meth:`set`."""
    self.set(key, value)
76+
77+
def __getstate__(self) -> Dict[str, Any]:
    """Return a picklable snapshot of the cache, without its lock.

    ``threading.RLock`` cannot be pickled, so ``_lock`` is left out of the
    returned state; ``__setstate__`` is expected to create a new one.

    The snapshot is taken while holding ``_lock``. Every dict-typed
    attribute (this includes ``OrderedDict``) is shallow-copied inside the
    critical section, because the pickler walks the returned state only
    *after* this method has returned and the lock has been released.
    Handing out the live mapping would let a concurrent writer mutate it
    mid-serialization.

    Returns:
        A new dict of instance attributes, excluding ``_lock``. The stored
        values themselves are shared with the live cache, not deep-copied.
    """
    # Local import so this method does not rely on a module-level ``copy``.
    import copy

    with self._lock:
        # NOTE(review): the name of the backing container is not visible
        # from here, so all dict-typed attributes are snapshotted.
        return {
            name: copy.copy(value) if isinstance(value, dict) else value
            for name, value in self.__dict__.items()
            if name != "_lock"
        }
82+
83+
def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore instance attributes from ``state`` and attach a new lock.

    Args:
        state: Attribute mapping to merge into the instance ``__dict__``.
    """
    vars(self).update(state)
    # Assigned last, so the instance always ends up with a fresh RLock.
    self._lock = threading.RLock()

src/runpod_flash/execute_class.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,19 @@ def __init__(self, *args, **kwargs):
225225
cls, args, kwargs, self._cache_key
226226
)
227227

228+
# Instance attributes that are dropped from the pickled state.
_UNPICKLABLE_ATTRS = frozenset({"_init_lock", "_stub"})

def __getstate__(self) -> dict:
    """Build the state dict handed to pickle.

    Attributes named in ``_UNPICKLABLE_ATTRS`` are omitted, and
    ``_initialized`` is always stored as ``False`` regardless of its
    current value on the live instance.

    Returns:
        A new dict; the instance ``__dict__`` itself is not modified.
    """
    picklable = {
        name: value
        for name, value in self.__dict__.items()
        if name not in self._UNPICKLABLE_ATTRS
    }
    picklable["_initialized"] = False
    return picklable
236+
237+
def __setstate__(self, state: dict) -> None:
    """Restore attributes from ``state`` and create a new ``asyncio.Lock``.

    Only ``_init_lock`` is recreated here; no ``_stub`` attribute is set
    by this method.

    Args:
        state: Attribute mapping to merge into the instance ``__dict__``.
    """
    vars(self).update(state)
    self._init_lock = asyncio.Lock()
240+
228241
async def _ensure_initialized(self):
229242
"""Ensure the remote instance is created exactly once, even under concurrent calls."""
230243
# Fast path: already initialized, no lock needed.

tests/unit/core/utils/test_lru_cache.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Tests for LRU cache implementation."""
22

3+
import threading
34
from concurrent.futures import ThreadPoolExecutor, as_completed
45

6+
import cloudpickle
57
import pytest
68

79
from runpod_flash.core.utils.lru_cache import LRUCache
@@ -296,3 +298,17 @@ def writer(start: int):
296298

297299
assert not errors
298300
assert len(cache) == 100
301+
302+
def test_pickle_roundtrip(self):
    """A populated LRUCache comes back intact from a cloudpickle round trip (AE-2745)."""
    entries = {"a": {"value": 1}, "b": {"value": 2}}
    original = LRUCache(max_size=5)
    for key, payload in entries.items():
        original.set(key, payload)

    clone = cloudpickle.loads(cloudpickle.dumps(original))

    for key, payload in entries.items():
        assert clone.get(key) == payload
    assert clone.max_size == 5
    # threading.RLock is a factory function, so compare against the type
    # of a freshly created lock.
    rlock_type = type(threading.RLock())
    assert type(clone._lock) is rlock_type

tests/unit/test_execute_class.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
import cloudpickle
1111
import pytest
1212
from runpod_flash.core.resources import ServerlessResource
13-
from runpod_flash.execute_class import create_remote_class, extract_class_code_simple
13+
from runpod_flash.execute_class import (
14+
_SERIALIZED_CLASS_CACHE,
15+
create_remote_class,
16+
extract_class_code_simple,
17+
)
1418
from runpod_flash.protos.remote_execution import FunctionRequest
1519

1620

@@ -704,3 +708,118 @@ def complex_method(
704708
obj = ReconstructedClass("test")
705709
assert obj.name == "test"
706710
assert obj.CLASS_VAR == "class_variable"
711+
712+
713+
class TestRemoteClassWrapperPickle:
    """Test pickle support for RemoteClassWrapper (AE-2745).

    asyncio.Lock is not picklable. RemoteClassWrapper must implement
    __getstate__/__setstate__ so ResourceManager._save_resources() can
    persist state via cloudpickle without raising.
    """

    def setup_method(self):
        """Start every test with an empty class cache and a fresh resource config."""
        _SERIALIZED_CLASS_CACHE.clear()
        self.resource_config = ServerlessResource(
            name="test-resource", image="test-image:latest", cpu=1, memory=512
        )

    @staticmethod
    def _roundtrip(obj):
        """Serialize ``obj`` with cloudpickle and load it straight back."""
        return cloudpickle.loads(cloudpickle.dumps(obj))

    def test_pickle_roundtrip(self):
        """RemoteClassWrapper instances must survive cloudpickle roundtrip."""

        class MyModel:
            def predict(self, x):
                return x

        RemoteWrapper = create_remote_class(
            MyModel, self.resource_config, ["numpy"], ["git"], True
        )
        instance = RemoteWrapper(42, name="test")

        restored = self._roundtrip(instance)

        # Compare by name: the restored class is not asserted to be the
        # same object as the local MyModel.
        assert restored._class_type.__name__ == "MyModel"
        assert restored._constructor_args == (42,)
        assert restored._constructor_kwargs == {"name": "test"}
        assert restored._dependencies == ["numpy"]
        assert restored._system_dependencies == ["git"]
        assert restored._instance_id == instance._instance_id
        assert not restored._initialized
        assert type(restored._init_lock) is type(asyncio.Lock())

    def test_pickle_excludes_lock_and_stub(self):
        """Pickle state must not contain non-picklable fields."""

        class MyModel:
            def predict(self, x):
                return x

        RemoteWrapper = create_remote_class(MyModel, self.resource_config, [], [], True)
        instance = RemoteWrapper()

        state = instance.__getstate__()

        assert "_init_lock" not in state
        assert "_stub" not in state
        assert state["_initialized"] is False

    def test_pickle_resets_initialized_flag(self):
        """Unpickled instance must re-initialize on next use."""

        class MyModel:
            def predict(self, x):
                return x

        RemoteWrapper = create_remote_class(MyModel, self.resource_config, [], [], True)
        instance = RemoteWrapper()
        # Fake an instance that has already been initialized.
        instance._initialized = True
        instance._stub = "fake_stub"

        restored = self._roundtrip(instance)

        assert not restored._initialized
        assert not hasattr(restored, "_stub")
        assert type(restored._init_lock) is type(asyncio.Lock())

    def test_pickle_inside_tuple_like_save_resources(self):
        """Simulate ResourceManager._save_resources() pickle pattern."""

        class MyModel:
            def predict(self, x):
                return x

        RemoteWrapper = create_remote_class(MyModel, self.resource_config, [], [], True)
        instance = RemoteWrapper()

        resources = {"uid1": instance}
        configs = {"uid1": "some_hash"}

        restored_resources, restored_configs = self._roundtrip((resources, configs))

        assert "uid1" in restored_resources
        assert not restored_resources["uid1"]._initialized
        # The config half of the payload must come back unchanged as well.
        assert restored_configs == configs

    def test_pickle_preserves_cache_key_and_class_data(self):
        """Unpickled instance retains cache_key and class metadata for cache lookup."""

        class MyModel:
            def predict(self, x):
                return x

        RemoteWrapper = create_remote_class(
            MyModel, self.resource_config, ["numpy"], ["git"], True
        )
        instance = RemoteWrapper(1, tag="v1")

        restored = self._roundtrip(instance)

        assert restored._cache_key == instance._cache_key
        assert restored._clean_class_code == instance._clean_class_code
        assert restored._class_type.__name__ == "MyModel"
        assert restored._constructor_args == (1,)
        assert restored._constructor_kwargs == {"tag": "v1"}

0 commit comments

Comments
 (0)