diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 4e5a453..1e449f3 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -25,9 +25,12 @@ jobs:
       - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-          pip install pytest
-          pip install .
+          pip install . --group dev

      - name: Run Pytest
        run: |
-          pytest tests/
\ No newline at end of file
+          pytest tests/
+
+      - name: Run Mypy
+        run: |
+          mypy oracletrace
\ No newline at end of file
diff --git a/oracletrace/cli.py b/oracletrace/cli.py
index 0126e7c..21c9773 100644
--- a/oracletrace/cli.py
+++ b/oracletrace/cli.py
@@ -2,15 +2,19 @@
 import sys
 import os
 import json
-import argparse
 import runpy
 import csv
-from .tracer import Tracer
-from .compare import compare_traces
+from .tracer import Tracer, TracerData
+from .compare import compare_traces, ComparisonData
+from typing import List, Dict, Any, Optional
+from re import Pattern
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+from dataclasses import asdict
 
 
-def main():
-    parser = argparse.ArgumentParser(
+def main() -> int:
+    parser: ArgumentParser = ArgumentParser(
         description="OracleTrace - Lightweight execution tracer for Python projects"
     )
     parser.add_argument("target", help="Python script to trace")
@@ -39,21 +43,21 @@ def main():
         default=5.0,
         help="Regression threshold percentage used with --fail-on-regression.",
     )
-    args = parser.parse_args()
+    args: Namespace = parser.parse_args()
 
-    target = args.target
+    target: str = args.target
     if not os.path.exists(target):
         print(f"Target not found: {target}")
         return 1
 
     target = os.path.abspath(target)
-    root = os.getcwd()
-    target_dir = os.path.dirname(target)
+    root: str = os.getcwd()
+    target_dir: str = os.path.dirname(target)
 
     # Setup paths so imports work correctly in the target script
     sys.path.insert(0, target_dir)
 
-    ignored_args = [] if args.ignore is None else args.ignore
-    ignore_patterns = []
+    ignored_args: List[str] = [] if args.ignore is None else args.ignore
+    ignore_patterns: List[Pattern] = []
 
     for pattern in ignored_args:
         try:
@@ -63,19 +67,19 @@ def main():
             return 1
 
     # Start tracing, run the script, then stop
-    tracer = Tracer(root, ignore_patterns=ignore_patterns)
+    tracer: Tracer = Tracer(root, ignore_patterns=ignore_patterns)
     tracer.start()
     try:
         runpy.run_path(target, run_name="__main__")
     finally:
         tracer.stop()
 
-    data = tracer.get_trace_data()
+    data: TracerData = tracer.get_trace_data()
 
     # Save json
     if args.json:
         with open(args.json, "w", encoding="utf-8") as f:
-            json.dump(data, f, indent=4)
+            json.dump(asdict(data), f, indent=4)
 
     # Display the analysis
     if args.top:
@@ -86,17 +90,17 @@ def main():
     # Export as csv
     if args.csv:
         with open(args.csv, "w", newline="", encoding="utf-8") as f:
-            writer = csv.DictWriter(f, fieldnames=["function", "total_time", "calls", "avg_time"])
+            writer: csv.DictWriter = csv.DictWriter(f, fieldnames=["function", "total_time", "calls", "avg_time"])
             writer.writeheader()
-            for fn in data["functions"]:
+            for fn in data.functions:
                 writer.writerow({
-                    "function": fn["name"],
-                    "total_time": fn["total_time"],
-                    "calls": fn["call_count"],
-                    "avg_time": fn["avg_time"],
+                    "function": fn.name,
+                    "total_time": fn.total_time,
+                    "calls": fn.call_count,
+                    "avg_time": fn.avg_time,
                 })
 
-    comparison_result = None
+    comparison_result: Optional[ComparisonData] = None
 
     # Compare jsons
     if args.compare:
@@ -105,11 +109,11 @@ def main():
             return 1
 
         with open(args.compare, "r", encoding="utf-8") as f:
-            old_data = json.load(f)
+            old_data: TracerData = TracerData.from_dict(json.load(f))
 
         comparison_result = compare_traces(old_data, data, threshold=args.threshold)
 
-        if args.fail_on_regression and comparison_result["has_regression"]:
+        if args.fail_on_regression and comparison_result.has_regression:
             print(
                 f"Build failed: performance regression above {args.threshold:.2f}% detected."
             )
diff --git a/oracletrace/compare.py b/oracletrace/compare.py
index 6cbecf3..7b7315e 100644
--- a/oracletrace/compare.py
+++ b/oracletrace/compare.py
@@ -1,19 +1,33 @@
+from .tracer import TracerData, FunctionData
 from rich import print
+from typing import Any, Dict, List, Set, Optional
+from dataclasses import dataclass
 
+@dataclass
+class RegressionData:
+    name: str
+    old_time: float
+    new_time: float
+    percent: float
 
-def compare_traces(old_data, new_data, threshold=5.0):
-    old_funcs = {f["name"]: f for f in old_data["functions"]}
-    new_funcs = {f["name"]: f for f in new_data["functions"]}
+@dataclass
+class ComparisonData:
+    regressions: List[RegressionData]
+    has_regression: bool
 
-    regressions = []
+def compare_traces(old_data: TracerData, new_data: TracerData, threshold: float = 5.0) -> ComparisonData:
+    old_funcs: Dict[str, FunctionData] = {f.name: f for f in old_data.functions}
+    new_funcs: Dict[str, FunctionData] = {f.name: f for f in new_data.functions}
+
+    regressions: List[RegressionData] = []
 
     print("\n[bold cyan]Comparison Results:[/]\n")
 
-    all_functions = set(old_funcs) | set(new_funcs)
+    all_functions: Set[str] = set(old_funcs) | set(new_funcs)
 
     for name in sorted(all_functions):
-        old = old_funcs.get(name)
-        new = new_funcs.get(name)
+        old: Optional[FunctionData] = old_funcs.get(name)
+        new: Optional[FunctionData] = new_funcs.get(name)
 
         if not old:
             print(f"[green]+ {name} (new function)[/]")
@@ -23,16 +37,16 @@ def compare_traces(old_data, new_data, threshold=5.0):
             print(f"[red]- {name} (removed)[/]")
             continue
 
-        old_time = old["total_time"]
-        new_time = new["total_time"]
+        old_time: float = old.total_time
+        new_time: float = new.total_time
 
         if old_time == 0:
             continue
 
-        diff = new_time - old_time
-        percent = (diff / old_time) * 100
+        diff: float = new_time - old_time
+        percent: float = (diff / old_time) * 100
 
-        color = "red" if percent > threshold else "green" if percent < -threshold else "yellow"
+        color: str = "red" if percent > threshold else "green" if percent < -threshold else "yellow"
 
         print(
             f"{name}\n"
@@ -42,15 +56,15 @@ def compare_traces(old_data, new_data, threshold=5.0):
 
         if percent > threshold:
             regressions.append(
-                {
-                    "name": name,
-                    "old_time": old_time,
-                    "new_time": new_time,
-                    "percent": percent,
-                }
+                RegressionData(
+                    name = name,
+                    new_time = new_time,
+                    old_time = old_time,
+                    percent = percent
+                )
             )
 
-    return {
-        "regressions": regressions,
-        "has_regression": len(regressions) > 0,
-    }
+    return ComparisonData(
+        regressions = regressions,
+        has_regression = len(regressions) > 0
+    )
diff --git a/oracletrace/py.typed b/oracletrace/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/oracletrace/tracer.py b/oracletrace/tracer.py
index cb7285b..31fcbe1 100644
--- a/oracletrace/tracer.py
+++ b/oracletrace/tracer.py
@@ -5,36 +5,66 @@
 from rich.tree import Tree
 from rich.table import Table
 from rich import print
-
+from typing import List, Optional, Callable, DefaultDict, Any, Tuple, Dict
+from re import Pattern
+from pathlib import Path
+from types import FrameType
+from dataclasses import dataclass
+
+@dataclass
+class TracerMetadata:
+    root_path: str
+    total_functions: int
+    total_execution_time: float
+
+@dataclass
+class FunctionData:
+    name: str
+    total_time: float
+    call_count: int
+    avg_time: float
+    callees: List[str]
+
+@dataclass
+class TracerData:
+    metadata: TracerMetadata
+    functions: List[FunctionData]
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "TracerData":
+        return cls(
+            metadata=TracerMetadata(**data["metadata"]),
+            functions=[FunctionData(**f) for f in data["functions"]],
+        )
 
 class Tracer:
-    def __init__(self, root_dir, ignore_patterns = None):
-        self._root_path = os.path.abspath(root_dir)
-        self._call_stack = []
-        self._func_calls = defaultdict(int)
-        self._func_time = defaultdict(float)
-        self._call_map = defaultdict(lambda: defaultdict(int))
-        self._original_profile_func = None
-        self._enabled = False
-        self._start_time = 0.0
-        self._total_time = 0.0
-        self._ignore_patterns = ignore_patterns
-
-
-    def start(self):
+    def __init__(self, root_dir: str, ignore_patterns: Optional[List[Pattern]] = None) -> None:
+        self._root_path: str = os.path.abspath(root_dir)
+        self._call_stack: List[Tuple[int, str, float]] = []
+        self._func_calls: DefaultDict[str, int] = defaultdict(int)
+        self._func_time: DefaultDict[str, float] = defaultdict(float)
+        self._call_map: DefaultDict[str, DefaultDict[str, int]] = defaultdict(lambda: defaultdict(int))
+        self._original_profile_func: Optional[Callable[[FrameType, str, Any], object]] = None
+        self._enabled: bool = False
+        self._start_time: float = 0.0
+        self._total_time: float = 0.0
+        self._ignore_patterns: Optional[List[Pattern]] = ignore_patterns
+
+
+    def start(self) -> None:
         # Start Tracer
         self._enabled = True
         self._start_time = time.perf_counter()
         self._original_profile_func = sys.getprofile()
         sys.setprofile(self._trace)
 
-    def stop(self):
+    def stop(self) -> None:
         # Stops Tracer
         self._enabled = False
         self._total_time = time.perf_counter() - self._start_time
         sys.setprofile(self._original_profile_func)
 
-    def _is_ignored(self, filename):
+    def _is_ignored(self, filename: str) -> bool:
         # Return true if the filename should be ignored
         if not self._ignore_patterns:
             return False
@@ -45,28 +75,28 @@ def _is_ignored(self, filename):
         return False
 
-    def _is_user_code(self, filename):
+    def _is_user_code(self, filename: str) -> bool:
         # Filter out files not in the project root
-        if not filename.startswith(self._root_path):
+        if not filename.startswith(str(self._root_path)):
             return False
         # Filter out third-party libraries
         if "site-packages" in filename or "dist-packages" in filename:
             return False
         return True
 
-    def _get_key(self, frame):
-        co_filename = frame.f_code.co_filename
+    def _get_key(self, frame: FrameType) -> Optional[str]:
+        co_filename: str = frame.f_code.co_filename
         # Ignore internal python frames (e.g. <string>)
         if co_filename.startswith("<"):
             return None
-        filename = os.path.abspath(co_filename)
+        filename: str = os.path.abspath(co_filename)
         # Check if the file belongs to the user's project
         if not self._is_user_code(filename):
             return None
 
         # Create a relative path key for readability
-        rel_path = os.path.relpath(filename, self._root_path)
-        qualname = getattr(frame.f_code, "co_qualname", frame.f_code.co_name)
-        key = f"{rel_path}:{qualname}"
+        rel_path: str = os.path.relpath(filename, self._root_path)
+        qualname: str = getattr(frame.f_code, "co_qualname", frame.f_code.co_name)
+        key: str = f"{rel_path}:{qualname}"
 
         # Check if the file should be ignored based on inputted ignoring pattern
         if self._is_ignored(key):
@@ -74,17 +104,17 @@ def _get_key(self, frame):
 
         return key
 
-    def _trace(self, frame, event, arg):
+    def _trace(self, frame: FrameType, event: str, _: Any) -> None:
         try:
             if not self._enabled:
                 return
 
             if event == "call":
-                key = self._get_key(frame)
+                key: Optional[str] = self._get_key(frame)
                 if not key:
                     return
 
-                caller = self._call_stack[-1][1] if self._call_stack else ""
+                caller: str = self._call_stack[-1][1] if self._call_stack else ""
                 self._call_map[caller][key] += 1
                 self._func_calls[key] += 1
                 self._call_stack.append((id(frame), key, time.perf_counter()))
@@ -99,8 +129,8 @@ def _trace(self, frame, event, arg):
                     self._func_time[key] += time.perf_counter() - start
                 else:
                     # Stack unwinding (handle exceptions or missed returns)
-                    fid = id(frame)
-                    found = False
+                    fid: int = id(frame)
+                    found: bool = False
                     # Search for the frame in the stack from top to bottom
                     for i in range(len(self._call_stack) - 1, -1, -1):
                         if self._call_stack[i][0] == fid:
@@ -117,7 +147,7 @@ def _trace(self, frame, event, arg):
         except Exception as e:
             print(f"[bold red]Error in oracletrace tracer: {e}[/]", file=sys.stderr)
 
-    def show_results(self, _top):
+    def show_results(self, _top: Optional[int]) -> None:
         if not self._func_calls:
             print("[yellow]No calls traced.[/]")
             return
@@ -125,7 +155,7 @@ def show_results(self, _top):
         # Summary table
         print("[bold green]Summary:[/]")
         print(f"[bold cyan]Total execution time: {self._total_time:.4f}s[/]")
-        table = Table(title="Top functions by Total Time")
+        table: Table = Table(title="Top functions by Total Time")
         table.add_column("Function", justify="left", style="cyan", no_wrap=True)
         table.add_column("Total Time (s)", justify="right", style="magenta")
         table.add_column("Calls", justify="right", style="green")
@@ -145,11 +175,11 @@ def show_results(self, _top):
 
         print("\n[bold green]Logic Flow:[/]")
 
-        tree = Tree("[bold yellow][/]")
+        tree: Tree = Tree("[bold yellow][/]")
 
         # Recursively build the execution tree
-        def add_nodes(parent_node, parent_key, current_path):
-            children = self._call_map.get(parent_key, {})
+        def add_nodes(parent_node: Tree, parent_key: str, current_path: set[str]) -> None:
+            children: DefaultDict[str,int] = self._call_map.get(parent_key, defaultdict(int))
             # Sort children by total execution time
             sorted_children = sorted(
                 children.items(),
@@ -171,37 +201,39 @@ def add_nodes(parent_node, parent_key, current_path):
         add_nodes(tree, "", {""})
         print(tree)
 
-    def get_trace_data(self):
-        functions = []
+    def get_trace_data(self) -> TracerData:
+        functions: List[FunctionData] = []
         for key, total_time in self._func_time.items():
             calls = self._func_calls[key]
             avg_time = total_time / calls if calls else 0
             functions.append(
-                {
-                    "name": key,
-                    "total_time": total_time,
-                    "call_count": calls,
-                    "avg_time": avg_time,
-                    "callees": list(self._call_map.get(key, {}).keys()),
-                }
+                FunctionData(
+                    name = key,
+                    total_time = total_time,
+                    call_count = calls,
+                    avg_time = avg_time,
+                    callees = list(self._call_map.get(key, {}).keys()),
+                )
             )
 
-        return {
-            "metadata": {
-                "root_path": self._root_path,
-                "total_functions": len(functions),
-                "total_execution_time": self._total_time,
-            },
-            "functions": functions,
-        }
+        metadata: TracerMetadata = TracerMetadata(
+            root_path = self._root_path,
+            total_functions = len(functions),
+            total_execution_time = self._total_time,
+        )
+
+        return TracerData(
+            metadata=metadata,
+            functions=functions,
+        )
 
 
-_tracer_instance = None
+_tracer_instance: Optional[Tracer] = None
 
 
-def start_trace(root_dir):
+def start_trace(root_dir: str) -> None:
     # Starts tracer instance
     global _tracer_instance
     if _tracer_instance is not None:
@@ -211,7 +243,7 @@ def start_trace(root_dir):
     _tracer_instance.start()
 
 
-def stop_trace():
+def stop_trace() -> Optional[TracerData]:
     # Stops tracer instance
     global _tracer_instance
     if _tracer_instance:
@@ -224,10 +256,10 @@ def stop_trace():
     return None
 
 
-def show_results():
+def show_results(top: Optional[int]) -> None:
     # Show results from global tracer instance
     global _tracer_instance
     if _tracer_instance:
-        _tracer_instance.show_results()
+        _tracer_instance.show_results(top)
     else:
         print("[yellow]Tracer was not started.[/]")
diff --git a/pyproject.toml b/pyproject.toml
index 5d39e6c..66d9b50 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,9 @@ license = { file = "LICENSE" }
 requires-python = ">=3.10"
 dependencies = ["rich"]
 
+[dependency-groups]
+dev = ["mypy", "pytest"]
+
 keywords = [
     "python profiler",
     "performance regression",
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 8eb6fb9..8d57ed0 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -2,7 +2,9 @@
 import importlib
 import sys
 from pathlib import Path
-
+from oracletrace.tracer import TracerData, FunctionData, TracerMetadata
+from oracletrace.compare import ComparisonData
+from dataclasses import asdict
 
 import pytest
 
@@ -15,24 +17,41 @@
 
 
 @pytest.fixture
-def trace_data():
-    return {
-        "functions": [
-            {
-                "name": "foo",
-                "total_time": 1.5,
-                "call_count": 3,
-                "avg_time": 0.5,
-            },
-            {
-                "name": "bar",
-                "total_time": 2.0,
-                "call_count": 2,
-                "avg_time": 1.0,
-            },
+def trace_data() -> TracerData:
+    return TracerData(
+        metadata = TracerMetadata(
+            total_execution_time = 3.0,
+            total_functions = 2,
+            root_path = str(REPO_ROOT)
+        ),
+        functions = [
+            FunctionData(
+                name = "foo",
+                total_time = 1.5,
+                call_count = 3,
+                avg_time = 0.5,
+                callees=[]
+            ),
+            FunctionData(
+                name = "bar",
+                total_time = 2.0,
+                call_count = 2,
+                avg_time = 1.0,
+                callees=[]
+            )
         ]
-    }
-
+    )
+
+@pytest.fixture
+def empty_trace_data() -> TracerData:
+    return TracerData(
+        metadata = TracerMetadata(
+            total_execution_time = 0.0,
+            total_functions = 0,
+            root_path = str(REPO_ROOT)
+        ),
+        functions = []
+    )
 
 class FakeTracer:
     def __init__(self, root, ignore_patterns, data):
@@ -125,7 +144,7 @@ def fake_run_path(path, run_name):
     assert run_path_calls == [(str(target.resolve()), "__main__")]
 
     loaded_json = json.loads(json_output.read_text(encoding="utf-8"))
-    assert loaded_json == trace_data
+    assert TracerData.from_dict(loaded_json) == trace_data
 
     csv_text = csv_output.read_text(encoding="utf-8")
     assert "function,total_time,calls,avg_time" in csv_text
@@ -173,20 +192,20 @@ def test_main_returns_1_when_compare_file_not_found(monkeypatch, tmp_path, trace
     assert "Compare file not found:" in captured.out
 
 
-def test_main_fails_with_exit_2_on_regression(monkeypatch, tmp_path, trace_data, capsys):
+def test_main_fails_with_exit_2_on_regression(monkeypatch, tmp_path, empty_trace_data, capsys):
     target = tmp_path / "target.py"
     target.write_text("print('hello')\n", encoding="utf-8")
     compare_file = tmp_path / "baseline.json"
-    compare_file.write_text(json.dumps({"functions": []}), encoding="utf-8")
+    compare_file.write_text(json.dumps(asdict(empty_trace_data)), encoding="utf-8")
 
-    monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, trace_data))
+    monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, empty_trace_data))
     monkeypatch.setattr(cli.runpy, "run_path", lambda *args, **kwargs: None)
 
     compare_calls = []
 
     def fake_compare_traces(old_data, new_data, threshold):
         compare_calls.append((old_data, new_data, threshold))
-        return {"has_regression": True}
+        return ComparisonData(regressions=[], has_regression=True)
 
     monkeypatch.setattr(cli, "compare_traces", fake_compare_traces)
 
@@ -213,11 +232,11 @@ def test_main_returns_0_when_no_regression(monkeypatch, tmp_path, trace_data):
     target = tmp_path / "target.py"
     target.write_text("print('hello')\n", encoding="utf-8")
     compare_file = tmp_path / "baseline.json"
-    compare_file.write_text(json.dumps({"functions": []}), encoding="utf-8")
+    compare_file.write_text(json.dumps(asdict(trace_data)), encoding="utf-8")
 
     monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, trace_data))
     monkeypatch.setattr(cli.runpy, "run_path", lambda *args, **kwargs: None)
-    monkeypatch.setattr(cli, "compare_traces", lambda old_data, new_data, threshold: {"has_regression": False})
+    monkeypatch.setattr(cli, "compare_traces", lambda old_data, new_data, threshold: ComparisonData(regressions=[], has_regression=False))
 
     exit_code = _run_cli(
         monkeypatch,