Skip to content

Commit 21fb4ae

Browse files
committed
refactor: type-hinting & PEP 561
1 parent 881a8ca commit 21fb4ae

6 files changed

Lines changed: 202 additions & 130 deletions

File tree

.github/workflows/tests.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@ jobs:
2525
- name: Install dependencies
2626
run: |
2727
python -m pip install --upgrade pip
28-
pip install pytest
29-
pip install .
28+
pip install . --group dev
3029
3130
- name: Run Pytest
3231
run: |
33-
pytest tests/
32+
pytest tests/
33+
34+
- name: Run Mypy
35+
run: |
36+
mypy oracletrace

oracletrace/cli.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,19 @@
22
import sys
33
import os
44
import json
5-
import argparse
65
import runpy
76
import csv
8-
from .tracer import Tracer
9-
from .compare import compare_traces
7+
from .tracer import Tracer, TracerData
8+
from .compare import compare_traces, ComparisonData
9+
from typing import List, Dict, Any, Optional
10+
from re import Pattern
11+
from argparse import ArgumentParser, Namespace
12+
from pathlib import Path
13+
from dataclasses import asdict
1014

1115

12-
def main():
13-
parser = argparse.ArgumentParser(
16+
def main() -> int:
17+
parser: ArgumentParser = ArgumentParser(
1418
description="OracleTrace - Lightweight execution tracer for Python projects"
1519
)
1620
parser.add_argument("target", help="Python script to trace")
@@ -39,21 +43,21 @@ def main():
3943
default=5.0,
4044
help="Regression threshold percentage used with --fail-on-regression.",
4145
)
42-
args = parser.parse_args()
46+
args: Namespace = parser.parse_args()
4347

44-
target = args.target
48+
target: str = args.target
4549

4650
if not os.path.exists(target):
4751
print(f"Target not found: {target}")
4852
return 1
4953

5054
target = os.path.abspath(target)
51-
root = os.getcwd()
52-
target_dir = os.path.dirname(target)
55+
root: str = os.getcwd()
56+
target_dir: str = os.path.dirname(target)
5357
# Setup paths so imports work correctly in the target script
5458
sys.path.insert(0, target_dir)
55-
ignored_args = [] if args.ignore is None else args.ignore
56-
ignore_patterns = []
59+
ignored_args: List[str] = [] if args.ignore is None else args.ignore
60+
ignore_patterns: List[Pattern] = []
5761

5862
for pattern in ignored_args:
5963
try:
@@ -63,19 +67,19 @@ def main():
6367
return 1
6468

6569
# Start tracing, run the script, then stop
66-
tracer = Tracer(root, ignore_patterns=ignore_patterns)
70+
tracer: Tracer = Tracer(root, ignore_patterns=ignore_patterns)
6771
tracer.start()
6872
try:
6973
runpy.run_path(target, run_name="__main__")
7074
finally:
7175
tracer.stop()
7276

73-
data = tracer.get_trace_data()
77+
data: TracerData = tracer.get_trace_data()
7478

7579
# Save json
7680
if args.json:
7781
with open(args.json, "w", encoding="utf-8") as f:
78-
json.dump(data, f, indent=4)
82+
json.dump(asdict(data), f, indent=4)
7983

8084
# Display the analysis
8185
if args.top:
@@ -86,17 +90,17 @@ def main():
8690
# Export as csv
8791
if args.csv:
8892
with open(args.csv, "w", newline="", encoding="utf-8") as f:
89-
writer = csv.DictWriter(f, fieldnames=["function", "total_time", "calls", "avg_time"])
93+
writer: csv.DictWriter = csv.DictWriter(f, fieldnames=["function", "total_time", "calls", "avg_time"])
9094
writer.writeheader()
91-
for fn in data["functions"]:
95+
for fn in data.functions:
9296
writer.writerow({
93-
"function": fn["name"],
94-
"total_time": fn["total_time"],
95-
"calls": fn["call_count"],
96-
"avg_time": fn["avg_time"],
97+
"function": fn.name,
98+
"total_time": fn.total_time,
99+
"calls": fn.call_count,
100+
"avg_time": fn.avg_time,
97101
})
98102

99-
comparison_result = None
103+
comparison_result: Optional[ComparisonData] = None
100104

101105
# Compare jsons
102106
if args.compare:
@@ -105,11 +109,11 @@ def main():
105109
return 1
106110

107111
with open(args.compare, "r", encoding="utf-8") as f:
108-
old_data = json.load(f)
112+
old_data: TracerData = TracerData.from_dict(json.load(f))
109113

110114
comparison_result = compare_traces(old_data, data, threshold=args.threshold)
111115

112-
if args.fail_on_regression and comparison_result["has_regression"]:
116+
if args.fail_on_regression and comparison_result.has_regression:
113117
print(
114118
f"Build failed: performance regression above {args.threshold:.2f}% detected."
115119
)

oracletrace/compare.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,33 @@
1+
from .tracer import TracerData, FunctionData
12
from rich import print
3+
from typing import Any, Dict, List, Set, Optional
4+
from dataclasses import dataclass
25

6+
@dataclass
7+
class RegressionData:
8+
name: str
9+
old_time: float
10+
new_time: float
11+
percent: float
312

4-
def compare_traces(old_data, new_data, threshold=5.0):
5-
old_funcs = {f["name"]: f for f in old_data["functions"]}
6-
new_funcs = {f["name"]: f for f in new_data["functions"]}
13+
@dataclass
14+
class ComparisonData:
15+
regressions: List[RegressionData]
16+
has_regression: bool
717

8-
regressions = []
18+
def compare_traces(old_data: TracerData, new_data: TracerData, threshold: float = 5.0) -> ComparisonData:
19+
old_funcs: Dict[str, FunctionData] = {f.name: f for f in old_data.functions}
20+
new_funcs: Dict[str, FunctionData] = {f.name: f for f in new_data.functions}
21+
22+
regressions: List[RegressionData] = []
923

1024
print("\n[bold cyan]Comparison Results:[/]\n")
1125

12-
all_functions = set(old_funcs) | set(new_funcs)
26+
all_functions: Set[str] = set(old_funcs) | set(new_funcs)
1327

1428
for name in sorted(all_functions):
15-
old = old_funcs.get(name)
16-
new = new_funcs.get(name)
29+
old: Optional[FunctionData] = old_funcs.get(name)
30+
new: Optional[FunctionData] = new_funcs.get(name)
1731

1832
if not old:
1933
print(f"[green]+ {name} (new function)[/]")
@@ -23,16 +37,16 @@ def compare_traces(old_data, new_data, threshold=5.0):
2337
print(f"[red]- {name} (removed)[/]")
2438
continue
2539

26-
old_time = old["total_time"]
27-
new_time = new["total_time"]
40+
old_time: float = old.total_time
41+
new_time: float = new.total_time
2842

2943
if old_time == 0:
3044
continue
3145

32-
diff = new_time - old_time
33-
percent = (diff / old_time) * 100
46+
diff: float = new_time - old_time
47+
percent: float = (diff / old_time) * 100
3448

35-
color = "red" if percent > threshold else "green" if percent < -threshold else "yellow"
49+
color: str = "red" if percent > threshold else "green" if percent < -threshold else "yellow"
3650

3751
print(
3852
f"{name}\n"
@@ -42,15 +56,15 @@ def compare_traces(old_data, new_data, threshold=5.0):
4256

4357
if percent > threshold:
4458
regressions.append(
45-
{
46-
"name": name,
47-
"old_time": old_time,
48-
"new_time": new_time,
49-
"percent": percent,
50-
}
59+
RegressionData(
60+
name = name,
61+
new_time = new_time,
62+
old_time = old_time,
63+
percent = percent
64+
)
5165
)
5266

53-
return {
54-
"regressions": regressions,
55-
"has_regression": len(regressions) > 0,
56-
}
67+
return ComparisonData(
68+
regressions = regressions,
69+
has_regression = len(regressions) > 0
70+
)

oracletrace/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)