Skip to content

Commit 186d04e

Browse files
committed
fix(profiling): make annotate_pipeline transient by returning restore() callable
Class-level patches are saved before application and the function now returns a restore() callable that undoes them, making the patches explicitly transient. This addresses reviewer feedback that class-level patches without cleanup are non-transient.
1 parent 1b6c59c commit 186d04e

1 file changed

Lines changed: 31 additions & 9 deletions

File tree

examples/profiling/profiling_utils.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,18 @@ def annotate_pipeline(pipe):
2929
"""Apply profiler annotations to key pipeline methods.
3030
3131
Monkey-patches bound methods so they appear as named spans in the trace.
32-
Non-invasive no source modifications required.
32+
Non-invasive -- no source modifications required.
3333
3434
Sub-component methods are patched at the **class level** (via
35-
setattr(type(component), ...)) rather than on the instance. This
35+
`setattr(type(component), ...)`) rather than on the instance. This
3636
ensures Python's descriptor protocol re-binds the wrapper to whichever
3737
instance accesses it, so shallow-copied components (e.g. the duplicated
3838
audio_scheduler inside LTX2) call their own logic rather than the
3939
original instance's.
40+
41+
Returns:
42+
A restore callable that undoes all patches and restores the
43+
original method definitions, making the annotation transient.
4044
"""
4145
annotations = [
4246
("transformer", "forward", "transformer_forward"),
@@ -45,6 +49,8 @@ def annotate_pipeline(pipe):
4549
("scheduler", "step", "scheduler_step"),
4650
]
4751

52+
saved = [] # (target, method_name, original_value, is_class_patch)
53+
4854
# Annotate sub-component methods
4955
for component_name, method_name, label in annotations:
5056
component = getattr(pipe, component_name, None)
@@ -54,20 +60,36 @@ def annotate_pipeline(pipe):
5460
if method is None:
5561
continue
5662
if inspect.ismethod(method):
57-
# Wrap the underlying function and patch at the class level so
58-
# that the descriptor protocol correctly rebinds the wrapper to
59-
# whichever instance accesses it. This prevents instance-
60-
# isolation bugs when a component is shallow-copied after
61-
# annotation (e.g. audio_scheduler = copy.copy(self.scheduler)
62-
# in the LTX2 pipeline).
63-
setattr(type(component), method_name, annotate(method.__func__, label))
63+
# Patch at the class level so the descriptor protocol correctly
64+
# re-binds the wrapper to whichever instance accesses it. This
65+
# prevents instance-isolation bugs when a component is
66+
# shallow-copied. The original class attribute is saved so the
67+
# patch can be reversed when restore() is called.
68+
cls = type(component)
69+
original = cls.__dict__.get(method_name)
70+
setattr(cls, method_name, annotate(method.__func__, label))
71+
saved.append((cls, method_name, original, True))
6472
else:
73+
original = component.__dict__.get(method_name)
6574
setattr(component, method_name, annotate(method, label))
75+
saved.append((component, method_name, original, False))
6676

6777
# Annotate pipeline-level methods
6878
if hasattr(pipe, "encode_prompt"):
79+
original = pipe.__dict__.get("encode_prompt")
6980
pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")
81+
saved.append((pipe, "encode_prompt", original, False))
82+
83+
def restore():
84+
"""Undo all patches applied by annotate_pipeline, restoring originals."""
85+
for target, name, original, is_class in saved:
86+
if original is None:
87+
if name in vars(target):
88+
delattr(target, name)
89+
else:
90+
setattr(target, name, original)
7091

92+
return restore
7193

7294
def flush():
7395
gc.collect()

0 commit comments

Comments
 (0)