Skip to content

Commit 267d3eb

Browse files
committed
refactor: type-hinting & PEP 561
1 parent 46f223f commit 267d3eb

5 files changed

Lines changed: 81 additions & 74 deletions

File tree

.github/workflows/tests.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@ jobs:
2525
- name: Install dependencies
2626
run: |
2727
python -m pip install --upgrade pip
28-
pip install pytest
29-
pip install .
28+
pip install . --group dev
3029
3130
- name: Run Pytest
3231
run: |
33-
pytest tests/
32+
pytest tests/
33+
34+
- name: Run Mypy
35+
run: |
36+
mypy oracletrace

oracletrace/cli.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
import sys
33
import os
44
import json
5-
import argparse
65
import runpy
76
import csv
87
from .tracer import Tracer
98
from .compare import compare_traces
10-
11-
12-
def main():
13-
parser = argparse.ArgumentParser(
9+
from typing import List, Dict, Any, Optional
10+
from re import Pattern
11+
from argparse import ArgumentParser, Namespace
12+
from pathlib import Path
13+
def main() -> int:
14+
parser: ArgumentParser = ArgumentParser(
1415
description="OracleTrace - Lightweight execution tracer for Python projects"
1516
)
1617
parser.add_argument("target", help="Python script to trace")
@@ -39,21 +40,21 @@ def main():
3940
default=5.0,
4041
help="Regression threshold percentage used with --fail-on-regression.",
4142
)
42-
args = parser.parse_args()
43+
args: Namespace = parser.parse_args()
4344

44-
target = args.target
45+
target: str = args.target
4546

4647
if not os.path.exists(target):
4748
print(f"Target not found: {target}")
4849
return 1
4950

5051
target = os.path.abspath(target)
51-
root = os.getcwd()
52-
target_dir = os.path.dirname(target)
52+
root: str = os.getcwd()
53+
target_dir: str = os.path.dirname(target)
5354
# Setup paths so imports work correctly in the target script
5455
sys.path.insert(0, target_dir)
55-
ignored_args = [] if args.ignore is None else args.ignore
56-
ignore_patterns = []
56+
ignored_args: List[str] = [] if args.ignore is None else args.ignore
57+
ignore_patterns: List[Pattern] = []
5758

5859
for pattern in ignored_args:
5960
try:
@@ -63,14 +64,14 @@ def main():
6364
return 1
6465

6566
# Start tracing, run the script, then stop
66-
tracer = Tracer(root, ignore_patterns=ignore_patterns)
67+
tracer: Tracer = Tracer(root, ignore_patterns=ignore_patterns)
6768
tracer.start()
6869
try:
6970
runpy.run_path(target, run_name="__main__")
7071
finally:
7172
tracer.stop()
7273

73-
data = tracer.get_trace_data()
74+
data: Dict[str, Any] = tracer.get_trace_data()
7475

7576
# Save json
7677
if args.json:
@@ -86,7 +87,7 @@ def main():
8687
# Export as csv
8788
if args.csv:
8889
with open(args.csv, "w", newline="", encoding="utf-8") as f:
89-
writer = csv.DictWriter(f, fieldnames=["function", "total_time", "calls", "avg_time"])
90+
writer: csv.DictWriter = csv.DictWriter(f, fieldnames=["function", "total_time", "calls", "avg_time"])
9091
writer.writeheader()
9192
for fn in data["functions"]:
9293
writer.writerow({
@@ -96,7 +97,7 @@ def main():
9697
"avg_time": fn["avg_time"],
9798
})
9899

99-
comparison_result = None
100+
comparison_result: Optional[Dict[str,Any]] = None
100101

101102
# Compare jsons
102103
if args.compare:
@@ -105,7 +106,7 @@ def main():
105106
return 1
106107

107108
with open(args.compare, "r", encoding="utf-8") as f:
108-
old_data = json.load(f)
109+
old_data: Dict[str, Any] = json.load(f)
109110

110111
comparison_result = compare_traces(old_data, data, threshold=args.threshold)
111112

oracletrace/compare.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
from rich import print
2+
from typing import Any, Dict, List, Set, Optional
23

34

4-
def compare_traces(old_data, new_data, threshold=5.0):
5-
old_funcs = {f["name"]: f for f in old_data["functions"]}
6-
new_funcs = {f["name"]: f for f in new_data["functions"]}
5+
def compare_traces(old_data: Dict[str, Any], new_data: Dict[str, Any], threshold: float = 5.0) -> Dict[str, Any]:
6+
old_funcs: Dict[str, Any] = {f["name"]: f for f in old_data["functions"]}
7+
new_funcs: Dict[str, Any] = {f["name"]: f for f in new_data["functions"]}
78

8-
regressions = []
9+
regressions: List[Dict[str, Any]] = []
910

1011
print("\n[bold cyan]Comparison Results:[/]\n")
1112

12-
all_functions = set(old_funcs) | set(new_funcs)
13+
all_functions: Set[str] = set(old_funcs) | set(new_funcs)
1314

1415
for name in sorted(all_functions):
15-
old = old_funcs.get(name)
16-
new = new_funcs.get(name)
16+
old: Optional[None] = old_funcs.get(name)
17+
new: Optional[None] = new_funcs.get(name)
1718

1819
if not old:
1920
print(f"[green]+ {name} (new function)[/]")
@@ -23,16 +24,16 @@ def compare_traces(old_data, new_data, threshold=5.0):
2324
print(f"[red]- {name} (removed)[/]")
2425
continue
2526

26-
old_time = old["total_time"]
27-
new_time = new["total_time"]
27+
old_time: float = old["total_time"]
28+
new_time: float = new["total_time"]
2829

2930
if old_time == 0:
3031
continue
3132

32-
diff = new_time - old_time
33-
percent = (diff / old_time) * 100
33+
diff: float = new_time - old_time
34+
percent: float = (diff / old_time) * 100
3435

35-
color = "red" if percent > threshold else "green" if percent < -threshold else "yellow"
36+
color: str = "red" if percent > threshold else "green" if percent < -threshold else "yellow"
3637

3738
print(
3839
f"{name}\n"

oracletrace/py.typed

Whitespace-only changes.

oracletrace/tracer.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,38 @@
55
from rich.tree import Tree
66
from rich.table import Table
77
from rich import print
8-
9-
8+
from typing import List, Optional, Callable, DefaultDict, Any, Tuple, Dict
9+
from re import Pattern
10+
from pathlib import Path
11+
from types import FrameType
1012
class Tracer:
11-
def __init__(self, root_dir, ignore_patterns = None):
12-
self._root_path = os.path.abspath(root_dir)
13-
self._call_stack = []
14-
self._func_calls = defaultdict(int)
15-
self._func_time = defaultdict(float)
16-
self._call_map = defaultdict(lambda: defaultdict(int))
17-
self._original_profile_func = None
18-
self._enabled = False
19-
self._start_time = 0.0
20-
self._total_time = 0.0
21-
self._ignore_patterns = ignore_patterns
22-
23-
24-
def start(self):
13+
def __init__(self, root_dir: str, ignore_patterns: Optional[List[Pattern]] = None) -> None:
14+
self._root_path: str = os.path.abspath(root_dir)
15+
self._call_stack: List[Tuple[int, str, float]] = []
16+
self._func_calls: DefaultDict[str, int] = defaultdict(int)
17+
self._func_time: DefaultDict[str, float] = defaultdict(float)
18+
self._call_map: DefaultDict[str, DefaultDict[str, int]] = defaultdict(lambda: defaultdict(int))
19+
self._original_profile_func: Optional[Callable[[FrameType, str, Any], object]] = None
20+
self._enabled: bool = False
21+
self._start_time: float = 0.0
22+
self._total_time: float = 0.0
23+
self._ignore_patterns: Optional[List[Pattern]] = ignore_patterns
24+
25+
26+
def start(self) -> None:
2527
# Start Tracer
2628
self._enabled = True
2729
self._start_time = time.perf_counter()
2830
self._original_profile_func = sys.getprofile()
2931
sys.setprofile(self._trace)
3032

31-
def stop(self):
33+
def stop(self) -> None:
3234
# Stops Tracer
3335
self._enabled = False
3436
self._total_time = time.perf_counter() - self._start_time
3537
sys.setprofile(self._original_profile_func)
3638

37-
def _is_ignored(self, filename):
39+
def _is_ignored(self, filename: str) -> bool:
3840
# Return true if the filename should be ignored
3941
if not self._ignore_patterns:
4042
return False
@@ -45,46 +47,46 @@ def _is_ignored(self, filename):
4547

4648
return False
4749

48-
def _is_user_code(self, filename):
50+
def _is_user_code(self, filename: str) -> bool:
4951
# Filter out files not in the project root
50-
if not filename.startswith(self._root_path):
52+
if not filename.startswith(str(self._root_path)):
5153
return False
5254
# Filter out third-party libraries
5355
if "site-packages" in filename or "dist-packages" in filename:
5456
return False
5557
return True
5658

57-
def _get_key(self, frame):
58-
co_filename = frame.f_code.co_filename
59+
def _get_key(self, frame: FrameType) -> Optional[str]:
60+
co_filename: str = frame.f_code.co_filename
5961
# Ignore internal python frames (e.g. <string>)
6062
if co_filename.startswith("<"):
6163
return None
62-
filename = os.path.abspath(co_filename)
64+
filename: str = os.path.abspath(co_filename)
6365
# Check if the file belongs to the user's project
6466
if not self._is_user_code(filename):
6567
return None
6668
# Create a relative path key for readability
67-
rel_path = os.path.relpath(filename, self._root_path)
68-
qualname = getattr(frame.f_code, "co_qualname", frame.f_code.co_name)
69-
key = f"{rel_path}:{qualname}"
69+
rel_path: str = os.path.relpath(filename, self._root_path)
70+
qualname: str = getattr(frame.f_code, "co_qualname", frame.f_code.co_name)
71+
key: str = f"{rel_path}:{qualname}"
7072

7173
# Check if the file should be ignored based on inputted ignoring pattern
7274
if self._is_ignored(key):
7375
return None
7476

7577
return key
7678

77-
def _trace(self, frame, event, arg):
79+
def _trace(self, frame: FrameType, event: str, _: Any) -> None:
7880
try:
7981
if not self._enabled:
8082
return
8183

8284
if event == "call":
83-
key = self._get_key(frame)
85+
key: Optional[str] = self._get_key(frame)
8486
if not key:
8587
return
8688

87-
caller = self._call_stack[-1][1] if self._call_stack else "<module>"
89+
caller: str = self._call_stack[-1][1] if self._call_stack else "<module>"
8890
self._call_map[caller][key] += 1
8991
self._func_calls[key] += 1
9092
self._call_stack.append((id(frame), key, time.perf_counter()))
@@ -99,8 +101,8 @@ def _trace(self, frame, event, arg):
99101
self._func_time[key] += time.perf_counter() - start
100102
else:
101103
# Stack unwinding (handle exceptions or missed returns)
102-
fid = id(frame)
103-
found = False
104+
fid: int = id(frame)
105+
found: bool = False
104106
# Search for the frame in the stack from top to bottom
105107
for i in range(len(self._call_stack) - 1, -1, -1):
106108
if self._call_stack[i][0] == fid:
@@ -117,15 +119,15 @@ def _trace(self, frame, event, arg):
117119
except Exception as e:
118120
print(f"[bold red]Error in oracletrace tracer: {e}[/]", file=sys.stderr)
119121

120-
def show_results(self, _top):
122+
def show_results(self, _top: Optional[int]) -> None:
121123
if not self._func_calls:
122124
print("[yellow]No calls traced.[/]")
123125
return
124126

125127
# Summary table
126128
print("[bold green]Summary:[/]")
127129
print(f"[bold cyan]Total execution time: {self._total_time:.4f}s[/]")
128-
table = Table(title="Top functions by Total Time")
130+
table: Table = Table(title="Top functions by Total Time")
129131
table.add_column("Function", justify="left", style="cyan", no_wrap=True)
130132
table.add_column("Total Time (s)", justify="right", style="magenta")
131133
table.add_column("Calls", justify="right", style="green")
@@ -145,11 +147,11 @@ def show_results(self, _top):
145147

146148
print("\n[bold green]Logic Flow:[/]")
147149

148-
tree = Tree("[bold yellow]<module>[/]")
150+
tree: Tree = Tree("[bold yellow]<module>[/]")
149151

150152
# Recursively build the execution tree
151-
def add_nodes(parent_node, parent_key, current_path):
152-
children = self._call_map.get(parent_key, {})
153+
def add_nodes(parent_node: Tree, parent_key: str, current_path: set[str]) -> None:
154+
children: DefaultDict[str,int] = self._call_map.get(parent_key, defaultdict(int))
153155
# Sort children by total execution time
154156
sorted_children = sorted(
155157
children.items(),
@@ -171,8 +173,8 @@ def add_nodes(parent_node, parent_key, current_path):
171173
add_nodes(tree, "<module>", {"<module>"})
172174
print(tree)
173175

174-
def get_trace_data(self):
175-
functions = []
176+
def get_trace_data(self) -> Dict[str, Any]:
177+
functions: List[Dict[str, Any]] = []
176178

177179
for key, total_time in self._func_time.items():
178180
calls = self._func_calls[key]
@@ -198,10 +200,10 @@ def get_trace_data(self):
198200
}
199201

200202

201-
_tracer_instance = None
203+
_tracer_instance: Optional[Tracer] = None
202204

203205

204-
def start_trace(root_dir):
206+
def start_trace(root_dir: str) -> None:
205207
# Starts tracer instance
206208
global _tracer_instance
207209
if _tracer_instance is not None:
@@ -211,7 +213,7 @@ def start_trace(root_dir):
211213
_tracer_instance.start()
212214

213215

214-
def stop_trace():
216+
def stop_trace() -> Optional[Dict[str,Any]]:
215217
# Stops tracer instance
216218
global _tracer_instance
217219
if _tracer_instance:
@@ -224,10 +226,10 @@ def stop_trace():
224226
return None
225227

226228

227-
def show_results():
229+
def show_results(top: Optional[int]) -> None:
228230
# Show results from global tracer instance
229231
global _tracer_instance
230232
if _tracer_instance:
231-
_tracer_instance.show_results()
233+
_tracer_instance.show_results(top)
232234
else:
233235
print("[yellow]Tracer was not started.[/]")

0 commit comments

Comments
 (0)