Skip to content

Commit d07e30e

Browse files
committed
Fixed Broken MethodMonitor
1 parent 06a1cbf commit d07e30e

2 files changed

Lines changed: 273 additions & 133 deletions

File tree

classmods/_method_monitor.py

Lines changed: 129 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
1-
from typing import Any, Dict, List, Tuple, Type, Callable
1+
from typing import Any, Dict, List, Tuple, Type, Callable, Optional
22
from functools import wraps
3+
import inspect
34

45
class MethodMonitor:
5-
# Dictionary to store monitors for each (class, method) pair
6-
monitors_registery: Dict[Tuple[Type, str], List['MethodMonitor']] = {}
6+
"""
7+
Monitor calls to a specific method on a class.
8+
9+
Multiple monitors may be attached to the same (class, method) pair.
10+
All active monitors are executed AFTER the original method call.
11+
"""
12+
13+
_registry: Dict[Tuple[Type, str], List["MethodMonitor"]] = {}
714

815
def __init__(
9-
self,
10-
target: Type,
11-
callable: Callable[..., None],
12-
monitor_args: tuple = (),
13-
monitor_kwargs: dict = {},
16+
self,
17+
target: Type,
18+
monitor_callable: Callable[..., None],
19+
monitor_args: Optional[Tuple] = None,
20+
monitor_kwargs: Optional[Dict[str, Any]] = None,
1421
*,
15-
target_method: str = '__init__',
22+
target_method: str | Callable = "__init__",
1623
active: bool = True,
17-
) -> None:
24+
) -> None:
1825
"""
1926
A class to monitor method calls of a target class, triggering a handler function after the method is called.
2027
@@ -24,12 +31,12 @@ def __init__(
2431
2532
Args:
2633
target (Type): The target class whose method will be monitored.
27-
callable (MonitorCallable): A callable to execute when the target method is called. -
34+
monitor_callable (MonitorCallable): A callable to execute when the target method is called. -
2835
Signature: monitor_callable(instance: object, *monitor_args, **monitor_kwargs). -
2936
**warning**: sends `None` as the first arg to `MonitorCallable` if target method is `StaticMethod` !!
30-
monitor_args (tuple): Positional arguments to pass to `callable` (default: empty tuple).
31-
monitor_kwargs (dict): Keyword arguments to pass to `callable` (default: empty dict).
32-
target_method (str): Name of the method to monitor (default: '__init__').
37+
monitor_args (Optional[Tuple]): Positional arguments to pass to `callable` (default: empty tuple).
38+
monitor_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to `callable` (default: empty dict).
39+
target_method (str | Callable): Name of the method to monitor or the method itself (default: '__init__').
3340
active (bool): Whether the monitor active initially (default: True).
3441
3542
Example:
@@ -43,61 +50,90 @@ def __init__(
4350
>>> obj.my_method() # Also calls `my_handler(obj)`
4451
"""
4552
self._target = target
46-
self._monitor_callable = callable
47-
self._monitor_args = monitor_args
48-
self._monitor_kwargs = monitor_kwargs
49-
self._target_method = target_method
53+
self._monitor_callable = monitor_callable
54+
self._monitor_args = monitor_args or ()
55+
self._monitor_kwargs = monitor_kwargs or {}
56+
self._target_method = target_method if isinstance(target_method, str) else target_method.__name__
5057
self._active = active
5158

52-
# Add this Monitor to the list of Monitors for each (class, method)
53-
key = self._create_registery_key()
54-
if key not in self.monitors_registery:
55-
self.monitors_registery[key] = []
56-
self._wrap_class_method(target, self._target_method)
57-
58-
self.monitors_registery[key].append(self)
59+
key = (self._target, self._target_method)
5960

61+
if key not in self._registry:
62+
self._wrap_method()
63+
self._registry[key] = []
6064

61-
def _create_registery_key(self) -> Tuple[Type, str]:
62-
return (self._target, self._target_method)
63-
64-
def _create_original_name(self, method_name: str) -> str:
65-
return f'__original_{method_name}'
65+
self._registry[key].append(self)
6666

6767
@staticmethod
68-
def _is_static_method(method: staticmethod|classmethod|Callable[[Any] ,Any]) -> bool:
69-
return isinstance(method, (staticmethod, classmethod))
70-
71-
def _wrap_class_method(self, target: Type, method_name: str) -> None:
72-
"""Wrap the target method to call all Monitors."""
73-
original_name = self._create_original_name(method_name)
74-
75-
if not hasattr(target, method_name):
76-
raise ValueError(f"The target class {target.__name__} does not have a method '{method_name}'.")
77-
78-
# Save the original method if not already saved
79-
if not hasattr(target, original_name):
80-
setattr(target, original_name, getattr(target, method_name))
81-
82-
original_method = getattr(target, original_name)
83-
84-
@wraps(original_method)
85-
def new_method(*args, **kwargs) -> Any:
86-
output = original_method(*args, **kwargs)
87-
88-
key = self._create_registery_key()
89-
for monitor in MethodMonitor.monitors_registery.get(key, []):
90-
if monitor.is_active():
91-
monitor._monitor_callable(
92-
args[0] if self._is_static_method(original_method) else None,
93-
*monitor._monitor_args,
94-
**monitor._monitor_kwargs,
95-
)
96-
97-
return output
98-
99-
setattr(target, method_name, new_method)
100-
68+
def _get_method_type(target_class, method_name) -> str:
69+
"""Return 'instance', 'class', or 'static'."""
70+
for cls in target_class.__mro__:
71+
if method_name in cls.__dict__:
72+
attr = cls.__dict__[method_name]
73+
if isinstance(attr, staticmethod):
74+
return "static"
75+
elif isinstance(attr, classmethod):
76+
return "class"
77+
else:
78+
return "instance"
79+
return "instance"
80+
81+
def _wrap_method(self) -> None:
82+
if not hasattr(self._target, self._target_method):
83+
raise AttributeError(
84+
f"{self._target.__name__} has no method '{self._target_method}'"
85+
)
86+
87+
# Get the original descriptor
88+
original_attr = None
89+
for cls in self._target.__mro__:
90+
if self._target_method in cls.__dict__:
91+
original_attr = cls.__dict__[self._target_method]
92+
break
93+
94+
if original_attr is None:
95+
raise AttributeError(f"Cannot find method {self._target_method}")
96+
97+
method_type = self._get_method_type(self._target, self._target_method)
98+
99+
# Extract the original function for wrapping
100+
if isinstance(original_attr, (staticmethod, classmethod)):
101+
original_func = original_attr.__func__
102+
else:
103+
original_func = original_attr
104+
105+
@wraps(original_func)
106+
def wrapper(*args, **kwargs):
107+
result = original_func(*args, **kwargs)
108+
109+
for monitor in self._registry.get((self._target, self._target_method), []):
110+
if not monitor._active:
111+
continue
112+
113+
if method_type == "static":
114+
first_arg = None
115+
else:
116+
first_arg = args[0]
117+
118+
monitor._monitor_callable(
119+
first_arg, *monitor._monitor_args, **monitor._monitor_kwargs
120+
)
121+
122+
return result
123+
124+
# Reapply descriptor to preserve type
125+
if method_type == "static":
126+
wrapped = staticmethod(wrapper)
127+
elif method_type == "class":
128+
wrapped = classmethod(wrapper)
129+
else:
130+
wrapped = wrapper
131+
132+
# Store reference to original for unwrapping
133+
wrapped.__methodmonitor_original__ = original_func # type: ignore[attr-defined]
134+
wrapped.__methodmonitor_wrapped__ = True # type: ignore[attr-defined]
135+
136+
setattr(self._target, self._target_method, wrapped)
101137

102138
def activate(self) -> None:
103139
"""Activate the monitor."""
@@ -108,31 +144,39 @@ def deactivate(self) -> None:
108144
self._active = False
109145

110146
def remove(self) -> None:
111-
"""Remove the handler and restore the original method if no monitors are left."""
112-
key = self._create_registery_key()
113-
if key in self.monitors_registery:
114-
self.monitors_registery[key].remove(self)
115-
if not self.monitors_registery[key]:
116-
# Restore the original method
117-
original_name = self._create_original_name(self._target_method)
118-
if hasattr(self._target, original_name):
119-
setattr(self._target, self._target_method, getattr(self._target, original_name))
120-
delattr(self._target, original_name)
121-
122-
del self.monitors_registery[key]
123-
147+
"""Remove the monitor and restore original method if no monitors left."""
148+
key = (self._target, self._target_method)
149+
monitors = self._registry.get(key)
150+
if not monitors:
151+
return
152+
153+
if self in monitors:
154+
monitors.remove(self)
155+
156+
if not monitors:
157+
wrapped = getattr(self._target, self._target_method)
158+
original = getattr(wrapped, "__methodmonitor_original__", None)
159+
if original:
160+
# Reapply original descriptor type
161+
method_type = self._get_method_type(self._target, self._target_method)
162+
if method_type == "static":
163+
original = staticmethod(original)
164+
elif method_type == "class":
165+
original = classmethod(original)
166+
setattr(self._target, self._target_method, original)
167+
168+
del self._registry[key]
124169

125170
def is_active(self) -> bool:
126-
return bool(self._active)
171+
return self._active
127172

128173
def __bool__(self) -> bool:
129-
return self.is_active()
130-
131-
def __str__(self) -> str:
132-
return f'<MethodMonitor of: {self._target} (method={self._target_method})>'
174+
return self._active
133175

134176
def __repr__(self) -> str:
135-
return f'MethodMonitor({self._target}, {self._monitor_callable}, target_method={self._target_method}, monitor_args={self._monitor_args}, monitor_kwargs={self._monitor_kwargs})'
136-
137-
def __del__(self) -> None:
138-
self.remove()
177+
return (
178+
f"MethodMonitor("
179+
f"target={self._target.__name__}, "
180+
f"method={self._target_method}, "
181+
f"active={self._active})"
182+
)

0 commit comments

Comments
 (0)