Skip to content

Commit 02b6df3

Browse files
committed
refactor: type-hinting & PEP 561
1 parent 881a8ca commit 02b6df3

5 files changed

Lines changed: 81 additions & 70 deletions

File tree

.github/workflows/tests.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@ jobs:
2525
- name: Install dependencies
2626
run: |
2727
python -m pip install --upgrade pip
28-
pip install pytest
29-
pip install .
28+
pip install . --group dev
3029
3130
- name: Run Pytest
3231
run: |
33-
pytest tests/
32+
pytest tests/
33+
34+
- name: Run Mypy
35+
run: |
36+
mypy oracletrace

oracletrace/cli.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
import sys
33
import os
44
import json
5-
import argparse
65
import runpy
76
import csv
87
from .tracer import Tracer
98
from .compare import compare_traces
9+
from typing import List, Dict, Any, Optional
10+
from re import Pattern
11+
from argparse import ArgumentParser, Namespace
12+
from pathlib import Path
1013

1114

12-
def main():
13-
parser = argparse.ArgumentParser(
15+
def main() -> int:
16+
parser: ArgumentParser = ArgumentParser(
1417
description="OracleTrace - Lightweight execution tracer for Python projects"
1518
)
1619
parser.add_argument("target", help="Python script to trace")
@@ -39,21 +42,21 @@ def main():
3942
default=5.0,
4043
help="Regression threshold percentage used with --fail-on-regression.",
4144
)
42-
args = parser.parse_args()
45+
args: Namespace = parser.parse_args()
4346

44-
target = args.target
47+
target: str = args.target
4548

4649
if not os.path.exists(target):
4750
print(f"Target not found: {target}")
4851
return 1
4952

5053
target = os.path.abspath(target)
51-
root = os.getcwd()
52-
target_dir = os.path.dirname(target)
54+
root: str = os.getcwd()
55+
target_dir: str = os.path.dirname(target)
5356
# Setup paths so imports work correctly in the target script
5457
sys.path.insert(0, target_dir)
55-
ignored_args = [] if args.ignore is None else args.ignore
56-
ignore_patterns = []
58+
ignored_args: List[str] = [] if args.ignore is None else args.ignore
59+
ignore_patterns: List[Pattern] = []
5760

5861
for pattern in ignored_args:
5962
try:
@@ -63,14 +66,14 @@ def main():
6366
return 1
6467

6568
# Start tracing, run the script, then stop
66-
tracer = Tracer(root, ignore_patterns=ignore_patterns)
69+
tracer: Tracer = Tracer(root, ignore_patterns=ignore_patterns)
6770
tracer.start()
6871
try:
6972
runpy.run_path(target, run_name="__main__")
7073
finally:
7174
tracer.stop()
7275

73-
data = tracer.get_trace_data()
76+
data: Dict[str, Any] = tracer.get_trace_data()
7477

7578
# Save json
7679
if args.json:
@@ -86,7 +89,7 @@ def main():
8689
# Export as csv
8790
if args.csv:
8891
with open(args.csv, "w", newline="", encoding="utf-8") as f:
89-
writer = csv.DictWriter(f, fieldnames=["function", "total_time", "calls", "avg_time"])
92+
writer: csv.DictWriter = csv.DictWriter(f, fieldnames=["function", "total_time", "calls", "avg_time"])
9093
writer.writeheader()
9194
for fn in data["functions"]:
9295
writer.writerow({
@@ -96,7 +99,7 @@ def main():
9699
"avg_time": fn["avg_time"],
97100
})
98101

99-
comparison_result = None
102+
comparison_result: Optional[Dict[str,Any]] = None
100103

101104
# Compare jsons
102105
if args.compare:
@@ -105,7 +108,7 @@ def main():
105108
return 1
106109

107110
with open(args.compare, "r", encoding="utf-8") as f:
108-
old_data = json.load(f)
111+
old_data: Dict[str, Any] = json.load(f)
109112

110113
comparison_result = compare_traces(old_data, data, threshold=args.threshold)
111114

oracletrace/compare.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
from rich import print
2+
from typing import Any, Dict, List, Set, Optional
23

34

4-
def compare_traces(old_data, new_data, threshold=5.0):
5-
old_funcs = {f["name"]: f for f in old_data["functions"]}
6-
new_funcs = {f["name"]: f for f in new_data["functions"]}
5+
def compare_traces(old_data: Dict[str, Any], new_data: Dict[str, Any], threshold: float = 5.0) -> Dict[str, Any]:
6+
old_funcs: Dict[str, Any] = {f["name"]: f for f in old_data["functions"]}
7+
new_funcs: Dict[str, Any] = {f["name"]: f for f in new_data["functions"]}
78

8-
regressions = []
9+
regressions: List[Dict[str, Any]] = []
910

1011
print("\n[bold cyan]Comparison Results:[/]\n")
1112

12-
all_functions = set(old_funcs) | set(new_funcs)
13+
all_functions: Set[str] = set(old_funcs) | set(new_funcs)
1314

1415
for name in sorted(all_functions):
15-
old = old_funcs.get(name)
16-
new = new_funcs.get(name)
16+
old: Optional[None] = old_funcs.get(name)
17+
new: Optional[None] = new_funcs.get(name)
1718

1819
if not old:
1920
print(f"[green]+ {name} (new function)[/]")
@@ -23,16 +24,16 @@ def compare_traces(old_data, new_data, threshold=5.0):
2324
print(f"[red]- {name} (removed)[/]")
2425
continue
2526

26-
old_time = old["total_time"]
27-
new_time = new["total_time"]
27+
old_time: float = old["total_time"]
28+
new_time: float = new["total_time"]
2829

2930
if old_time == 0:
3031
continue
3132

32-
diff = new_time - old_time
33-
percent = (diff / old_time) * 100
33+
diff: float = new_time - old_time
34+
percent: float = (diff / old_time) * 100
3435

35-
color = "red" if percent > threshold else "green" if percent < -threshold else "yellow"
36+
color: str = "red" if percent > threshold else "green" if percent < -threshold else "yellow"
3637

3738
print(
3839
f"{name}\n"

oracletrace/py.typed

Whitespace-only changes.

oracletrace/tracer.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,40 @@
55
from rich.tree import Tree
66
from rich.table import Table
77
from rich import print
8+
from typing import List, Optional, Callable, DefaultDict, Any, Tuple, Dict
9+
from re import Pattern
10+
from pathlib import Path
11+
from types import FrameType
812

913

1014
class Tracer:
11-
def __init__(self, root_dir, ignore_patterns = None):
12-
self._root_path = os.path.abspath(root_dir)
13-
self._call_stack = []
14-
self._func_calls = defaultdict(int)
15-
self._func_time = defaultdict(float)
16-
self._call_map = defaultdict(lambda: defaultdict(int))
17-
self._original_profile_func = None
18-
self._enabled = False
19-
self._start_time = 0.0
20-
self._total_time = 0.0
21-
self._ignore_patterns = ignore_patterns
22-
23-
24-
def start(self):
15+
def __init__(self, root_dir: str, ignore_patterns: Optional[List[Pattern]] = None) -> None:
16+
self._root_path: str = os.path.abspath(root_dir)
17+
self._call_stack: List[Tuple[int, str, float]] = []
18+
self._func_calls: DefaultDict[str, int] = defaultdict(int)
19+
self._func_time: DefaultDict[str, float] = defaultdict(float)
20+
self._call_map: DefaultDict[str, DefaultDict[str, int]] = defaultdict(lambda: defaultdict(int))
21+
self._original_profile_func: Optional[Callable[[FrameType, str, Any], object]] = None
22+
self._enabled: bool = False
23+
self._start_time: float = 0.0
24+
self._total_time: float = 0.0
25+
self._ignore_patterns: Optional[List[Pattern]] = ignore_patterns
26+
27+
28+
def start(self) -> None:
2529
# Start Tracer
2630
self._enabled = True
2731
self._start_time = time.perf_counter()
2832
self._original_profile_func = sys.getprofile()
2933
sys.setprofile(self._trace)
3034

31-
def stop(self):
35+
def stop(self) -> None:
3236
# Stops Tracer
3337
self._enabled = False
3438
self._total_time = time.perf_counter() - self._start_time
3539
sys.setprofile(self._original_profile_func)
3640

37-
def _is_ignored(self, filename):
41+
def _is_ignored(self, filename: str) -> bool:
3842
# Return true if the filename should be ignored
3943
if not self._ignore_patterns:
4044
return False
@@ -45,46 +49,46 @@ def _is_ignored(self, filename):
4549

4650
return False
4751

48-
def _is_user_code(self, filename):
52+
def _is_user_code(self, filename: str) -> bool:
4953
# Filter out files not in the project root
50-
if not filename.startswith(self._root_path):
54+
if not filename.startswith(str(self._root_path)):
5155
return False
5256
# Filter out third-party libraries
5357
if "site-packages" in filename or "dist-packages" in filename:
5458
return False
5559
return True
5660

57-
def _get_key(self, frame):
58-
co_filename = frame.f_code.co_filename
61+
def _get_key(self, frame: FrameType) -> Optional[str]:
62+
co_filename: str = frame.f_code.co_filename
5963
# Ignore internal python frames (e.g. <string>)
6064
if co_filename.startswith("<"):
6165
return None
62-
filename = os.path.abspath(co_filename)
66+
filename: str = os.path.abspath(co_filename)
6367
# Check if the file belongs to the user's project
6468
if not self._is_user_code(filename):
6569
return None
6670
# Create a relative path key for readability
67-
rel_path = os.path.relpath(filename, self._root_path)
68-
qualname = getattr(frame.f_code, "co_qualname", frame.f_code.co_name)
69-
key = f"{rel_path}:{qualname}"
71+
rel_path: str = os.path.relpath(filename, self._root_path)
72+
qualname: str = getattr(frame.f_code, "co_qualname", frame.f_code.co_name)
73+
key: str = f"{rel_path}:{qualname}"
7074

7175
# Check if the file should be ignored based on inputted ignoring pattern
7276
if self._is_ignored(key):
7377
return None
7478

7579
return key
7680

77-
def _trace(self, frame, event, arg):
81+
def _trace(self, frame: FrameType, event: str, _: Any) -> None:
7882
try:
7983
if not self._enabled:
8084
return
8185

8286
if event == "call":
83-
key = self._get_key(frame)
87+
key: Optional[str] = self._get_key(frame)
8488
if not key:
8589
return
8690

87-
caller = self._call_stack[-1][1] if self._call_stack else "<module>"
91+
caller: str = self._call_stack[-1][1] if self._call_stack else "<module>"
8892
self._call_map[caller][key] += 1
8993
self._func_calls[key] += 1
9094
self._call_stack.append((id(frame), key, time.perf_counter()))
@@ -99,8 +103,8 @@ def _trace(self, frame, event, arg):
99103
self._func_time[key] += time.perf_counter() - start
100104
else:
101105
# Stack unwinding (handle exceptions or missed returns)
102-
fid = id(frame)
103-
found = False
106+
fid: int = id(frame)
107+
found: bool = False
104108
# Search for the frame in the stack from top to bottom
105109
for i in range(len(self._call_stack) - 1, -1, -1):
106110
if self._call_stack[i][0] == fid:
@@ -117,15 +121,15 @@ def _trace(self, frame, event, arg):
117121
except Exception as e:
118122
print(f"[bold red]Error in oracletrace tracer: {e}[/]", file=sys.stderr)
119123

120-
def show_results(self, _top):
124+
def show_results(self, _top: Optional[int]) -> None:
121125
if not self._func_calls:
122126
print("[yellow]No calls traced.[/]")
123127
return
124128

125129
# Summary table
126130
print("[bold green]Summary:[/]")
127131
print(f"[bold cyan]Total execution time: {self._total_time:.4f}s[/]")
128-
table = Table(title="Top functions by Total Time")
132+
table: Table = Table(title="Top functions by Total Time")
129133
table.add_column("Function", justify="left", style="cyan", no_wrap=True)
130134
table.add_column("Total Time (s)", justify="right", style="magenta")
131135
table.add_column("Calls", justify="right", style="green")
@@ -145,11 +149,11 @@ def show_results(self, _top):
145149

146150
print("\n[bold green]Logic Flow:[/]")
147151

148-
tree = Tree("[bold yellow]<module>[/]")
152+
tree: Tree = Tree("[bold yellow]<module>[/]")
149153

150154
# Recursively build the execution tree
151-
def add_nodes(parent_node, parent_key, current_path):
152-
children = self._call_map.get(parent_key, {})
155+
def add_nodes(parent_node: Tree, parent_key: str, current_path: set[str]) -> None:
156+
children: DefaultDict[str,int] = self._call_map.get(parent_key, defaultdict(int))
153157
# Sort children by total execution time
154158
sorted_children = sorted(
155159
children.items(),
@@ -171,8 +175,8 @@ def add_nodes(parent_node, parent_key, current_path):
171175
add_nodes(tree, "<module>", {"<module>"})
172176
print(tree)
173177

174-
def get_trace_data(self):
175-
functions = []
178+
def get_trace_data(self) -> Dict[str, Any]:
179+
functions: List[Dict[str, Any]] = []
176180

177181
for key, total_time in self._func_time.items():
178182
calls = self._func_calls[key]
@@ -198,10 +202,10 @@ def get_trace_data(self):
198202
}
199203

200204

201-
_tracer_instance = None
205+
_tracer_instance: Optional[Tracer] = None
202206

203207

204-
def start_trace(root_dir):
208+
def start_trace(root_dir: str) -> None:
205209
# Starts tracer instance
206210
global _tracer_instance
207211
if _tracer_instance is not None:
@@ -211,7 +215,7 @@ def start_trace(root_dir):
211215
_tracer_instance.start()
212216

213217

214-
def stop_trace():
218+
def stop_trace() -> Optional[Dict[str,Any]]:
215219
# Stops tracer instance
216220
global _tracer_instance
217221
if _tracer_instance:
@@ -224,10 +228,10 @@ def stop_trace():
224228
return None
225229

226230

227-
def show_results():
231+
def show_results(top: Optional[int]) -> None:
228232
# Show results from global tracer instance
229233
global _tracer_instance
230234
if _tracer_instance:
231-
_tracer_instance.show_results()
235+
_tracer_instance.show_results(top)
232236
else:
233237
print("[yellow]Tracer was not started.[/]")

0 commit comments

Comments
 (0)