Skip to content

Commit 218bed4

Browse files
committed
Post-rebase fixes
line_profiler/_line_profiler.pyx _SysMonitoringState.line_tracing_events, .line_tracing_event_set New class attributes _LineProfilerManager _base_callback() - Updated call signature - Now using `._call_callback()` to handle stored callbacks _call_callback() New helper method for calling the stored callback which: - Fetches changes to the `sys.monitoring` events and callables to `self.mon_state` - Restores the global event list and tool-id lock to ensure continued profiling if necessary tests/test_sys_trace.py suspend_tracing Updated implementation to also suspend `sys.monitoring`-based line profiling _test_helper_callback_preservation() Updated to skip the legacy-tracing check when using `sys.monitoring`-based line profiling test_callback_wrapping() Updated implementation to reflect non-interference with legacy trace functions when using `sys.monitoring`-based line profiling test_wrapping_thread_local_callbacks() - Updated parametrization to reflect non-interference with legacy trace functions when using `sys.monitoring`-based line profiling - Updated assertion to give more informative errors
1 parent 77d82ae commit 218bed4

2 files changed

Lines changed: 91 additions & 45 deletions

File tree

line_profiler/_line_profiler.pyx

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,18 @@ cdef class _SysMonitoringState:
314314
cdef dict callbacks # type: dict[int, Callable | None]
315315
cdef int events
316316

317+
if CAN_USE_SYS_MONITORING:
318+
line_tracing_event_set = ( # type: ClassVar[FrozenSet[int]]
319+
frozenset({sys.monitoring.events.LINE,
320+
sys.monitoring.events.PY_RETURN,
321+
sys.monitoring.events.PY_YIELD}))
322+
line_tracing_events = (sys.monitoring.events.LINE
323+
| sys.monitoring.events.PY_RETURN
324+
| sys.monitoring.events.PY_YIELD)
325+
else:
326+
line_tracing_event_set = frozenset({})
327+
line_tracing_events = 0
328+
317329
def __init__(self, name=None, callbacks=None, events=0):
318330
self.name = name
319331
self.callbacks = callbacks or {}
@@ -342,11 +354,7 @@ cdef class _SysMonitoringState:
342354
self.events = mon.get_events(mon.PROFILER_ID)
343355
mon.free_tool_id(mon.PROFILER_ID)
344356
mon.use_tool_id(mon.PROFILER_ID, 'line_profiler')
345-
events = (self.events
346-
| mon.events.LINE
347-
| mon.events.PY_RETURN
348-
| mon.events.PY_YIELD)
349-
mon.set_events(mon.PROFILER_ID, events)
357+
mon.set_events(mon.PROFILER_ID, self.events | self.line_tracing_events)
350358

351359
# Register tracebacks
352360
for event_id, callback in [
@@ -435,7 +443,7 @@ cdef class _LineProfilerManager:
435443
:py:func:`sys.monitoring.register_callback`.
436444
"""
437445
self._base_callback(
438-
sys.monitoring.events.LINE, code, lineno, (lineno,))
446+
1, sys.monitoring.events.LINE, code, lineno, (lineno,))
439447

440448
@cython.profile(False)
441449
cpdef handle_return_event(
@@ -467,47 +475,69 @@ cdef class _LineProfilerManager:
467475
This is deliberately made a non-traceable C method so that
468476
we don't fall info infinite recursion.
469477
"""
470-
self._base_callback(event_id,
471-
code,
472-
PyCode_Addr2Line(<PyCodeObject*>code, offset),
473-
(offset, retval))
478+
cdef int lineno = PyCode_Addr2Line(<PyCodeObject*>code, offset)
479+
self._base_callback(0, event_id, code, lineno, (offset, retval))
474480

475481
cdef void _base_callback(
476-
self, int event_id, object code, int lineno, object other_args):
482+
self, int is_line_event, int event_id,
483+
object code, int lineno, object other_args):
477484
"""
478485
Base for the various callbacks passed to
479-
:py:func:`sys.monitoring.register_callback`. Also takes care of
480-
the restoration of callbacks should they be unset.
486+
:py:func:`sys.monitoring.register_callback`.
481487
482488
Note:
483489
This is deliberately made a non-traceable C method so that
484490
we don't fall info infinite recursion.
485491
"""
486-
cdef object callback_before, callback_after, callback_wrapped
487-
cdef dict callbacks = self.mon_state.callbacks
488-
mon = sys.monitoring
489-
inner_trace_callback((event_id == mon.events.LINE),
490-
self.active_instances,
491-
code,
492-
lineno)
493-
if not self.wrap_trace:
494-
return
492+
inner_trace_callback(
493+
is_line_event, self.active_instances, code, lineno)
494+
if self._wrap_trace:
495+
self._call_callback(event_id, code, other_args)
496+
497+
cdef void _call_callback(
498+
self, int event_id, object code, object other_args):
499+
"""
500+
Call the stored callback in ``self.mon_state``. Also takes care
501+
of the restoration of :py:mod:`sys.monitoring` callbacks,
502+
tool-id lock, and events should they be unset.
495503
496-
# Call wrapped callback
497-
callback_wrapped = callbacks.get(event_id)
498-
if callback_wrapped is None:
504+
Note:
505+
This is deliberately made a non-traceable C method so that
506+
we don't fall info infinite recursion.
507+
"""
508+
mon = sys.monitoring
509+
cdef object callback # type: Callable | None
510+
cdef object callback_after # type: Callable | None
511+
cdef object callback_wrapped # type: Callable | None
512+
cdef int events_needed
513+
cdef int ev_id
514+
cdef int prof_id = mon.PROFILER_ID
515+
cdef _SysMonitoringState state = self.mon_state
516+
cdef dict callbacks_before = {}
517+
# Call wrapped callback where suitable
518+
callback_wrapped = state.callbacks.get(event_id)
519+
if callback_wrapped is None or not (event_id & state.events):
499520
return
500-
callback_before = get_current_callback(event_id)
521+
for ev_id in state.line_tracing_event_set:
522+
callbacks_before[ev_id] = get_current_callback(ev_id)
523+
events_needed = state.line_tracing_events
501524
try:
502525
callback_wrapped(code, *other_args)
503526
finally:
504-
callback_after = get_current_callback(event_id)
505-
if callback_after is None:
506-
# The wrapped callback has unset itself;
507-
# remove from `.mon_state.callbacks`
508-
callbacks[event_id] = None
509-
if callback_before is not callback_after:
510-
mon.register_callback(mon.PROFILER_ID, callback_before)
527+
state.events = mon.get_events(prof_id)
528+
register = mon.register_callback
529+
# If the wrapped callback has changed:
530+
for ev_id, callback in callbacks_before.items():
531+
# - Restore the `sys.monitoring` callback
532+
callback_after = register(prof_id, ev_id, callback)
533+
# - Remember the updated callback in `state.callbacks`
534+
if callback is not callback_after:
535+
state.callbacks[ev_id] = callback_after
536+
# Reset the tool ID lock if released
537+
if not mon.get_tool(prof_id):
538+
mon.use_tool_id(prof_id, 'line_profiler')
539+
# Restore the `sys.monitoring` events if unset
540+
mon.set_events(prof_id, state.events | events_needed)
511541

512542
cpdef _handle_enable_event(self, prof):
513543
cdef TraceCallback* legacy_callback

tests/test_sys_trace.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@
2626
from ast import literal_eval
2727
from contextlib import nullcontext
2828
from io import StringIO
29-
from types import FrameType
29+
from types import FrameType, ModuleType
3030
from typing import Any, Optional, Union, Callable, List, Literal
3131
from line_profiler import LineProfiler
3232

3333

3434
# Common utilities
3535

3636
DEBUG = False
37+
USE_SYS_MONITORING = isinstance(getattr(sys, 'monitoring', None), ModuleType)
3738

3839
Event = Literal['call', 'line', 'return', 'exception', 'opcode']
3940
TracingFunc = Callable[[FrameType, Event, Any], Union['TracingFunc', None]]
@@ -163,14 +164,25 @@ def baz(n: int) -> int:
163164
class suspend_tracing:
164165
def __init__(self):
165166
self.callback = None
167+
self.events = 0
166168

167169
def __enter__(self):
168-
self.callback = sys.gettrace()
169-
sys.settrace(None)
170+
if USE_SYS_MONITORING:
171+
mod = sys.monitoring
172+
self.events = mod.get_events(mod.PROFILER_ID)
173+
mod.set_events(mod.PROFILER_ID, mod.events.NO_EVENTS)
174+
else:
175+
self.callback = sys.gettrace()
176+
sys.settrace(None)
170177

171178
def __exit__(self, *_, **__):
172-
sys.settrace(self.callback)
173-
self.callback = None
179+
if USE_SYS_MONITORING:
180+
mod = sys.monitoring
181+
mod.set_events(mod.PROFILER_ID, self.events)
182+
self.events = 0
183+
else:
184+
sys.settrace(self.callback)
185+
self.callback = None
174186

175187

176188
def get_incr_logger(logs: List[str], func: Literal[foo, bar, baz] = foo, *,
@@ -270,8 +282,9 @@ def _test_helper_callback_preservation(
270282
assert sys.gettrace() is callback, f'can\'t set trace to {callback!r}'
271283
profile = LineProfiler(wrap_trace=False)
272284
profile.enable_by_count()
273-
assert profile in sys.gettrace().active_instances, (
274-
'can\'t set trace to the profiler')
285+
if not USE_SYS_MONITORING:
286+
assert profile in sys.gettrace().active_instances, (
287+
'can\'t set trace to the profiler')
275288
profile.disable_by_count()
276289
assert sys.gettrace() is callback, f'trace not restored to {callback!r}'
277290
sys.settrace(None)
@@ -311,7 +324,7 @@ def test_callback_wrapping(
311324
else:
312325
foo_like = foo
313326
trace_preserved = True
314-
if trace_preserved:
327+
if trace_preserved or USE_SYS_MONITORING:
315328
exp_logs = [f'foo: spam = {spam}' for spam in range(1, 6)]
316329
else:
317330
exp_logs = []
@@ -558,9 +571,12 @@ def test_wrapping_thread_local_callbacks(label: str,
558571
[(True, True, 100, {0: 2, # Both calls are traced
559572
5: 0, # Tracing suspended
560573
7: 100}), # Tracing restored (both calls)
561-
# If `set_frame_local_trace` is false, tracing is suspended for the
562-
# rest of the frame
563-
(True, False, 100, {0: 2, 5: 0, 7: 0}),
574+
# If `set_frame_local_trace` is false:
575+
# - When using legacy tracing, tracing is suspended for the rest of
576+
# the frame
577+
# - Else, tracing is unaffected
578+
(True, False, 100,
579+
{0: 2, 5: 0, 7: 100 if USE_SYS_MONITORING else 0}),
564580
# Calling a function always triggers `<trace>.__call__()`
565581
(False, True, 100, {0: 1, # Only one of the calls is traced
566582
2: 100}), # 100 hits on the line in the loop
@@ -621,4 +637,4 @@ def func_break_in_middle(n):
621637
all_nhits = {lineno - body_start_line: _nhits
622638
for (lineno, _nhits, _) in entries}
623639
all_nhits = {lineno: all_nhits.get(lineno, 0) for lineno in nhits}
624-
assert all_nhits == nhits
640+
assert all_nhits == nhits, f'expected {nhits=}, got {all_nhits=}'

0 commit comments

Comments
 (0)