Skip to content

Commit 279a8fc

Browse files
committed
feat: add --memory flag to codeflash compare for peak memory profiling
Adds a second profiling phase using pytest-memray that runs after timing benchmarks. Memory tables are suppressed when the delta is <1%.
1 parent 74c29b2 commit 279a8fc

8 files changed

Lines changed: 495 additions & 4 deletions

File tree

codeflash/benchmarking/compare.py

Lines changed: 150 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
if TYPE_CHECKING:
2626
from collections.abc import Callable
2727

28-
from codeflash.benchmarking.plugin.plugin import BenchmarkStats
28+
from codeflash.benchmarking.plugin.plugin import BenchmarkStats, MemoryStats
2929
from codeflash.models.function_types import FunctionToOptimize
3030
from codeflash.models.models import BenchmarkKey
3131

@@ -42,6 +42,8 @@ class CompareResult:
4242
head_stats: dict[BenchmarkKey, BenchmarkStats] = field(default_factory=dict)
4343
base_function_ns: dict[str, dict[BenchmarkKey, float]] = field(default_factory=dict)
4444
head_function_ns: dict[str, dict[BenchmarkKey, float]] = field(default_factory=dict)
45+
base_memory: dict[BenchmarkKey, MemoryStats] = field(default_factory=dict)
46+
head_memory: dict[BenchmarkKey, MemoryStats] = field(default_factory=dict)
4547

4648
def format_markdown(self) -> str:
4749
if not self.base_stats and not self.head_stats:
@@ -106,6 +108,29 @@ def sort_key(fn: str, _bm_key: BenchmarkKey = bm_key) -> float:
106108
f"| `{short_name}` | {fmt_us(b)} | {fmt_us(h)} | {md_bar(b, h)} | {md_speedup(b, h)} |"
107109
)
108110

111+
# Memory section (skip when delta is negligible)
112+
base_mem = self.base_memory.get(bm_key)
113+
head_mem = self.head_memory.get(bm_key)
114+
if has_meaningful_memory_change(base_mem, head_mem):
115+
lines.append("")
116+
lines.append("#### Memory")
117+
lines.append("")
118+
lines.append("| Ref | Peak Memory | Allocations | Delta |")
119+
lines.append("|:---|---:|---:|:---|")
120+
if base_mem:
121+
lines.append(
122+
f"| `{base_short}` (base) | {md_bytes(base_mem.peak_memory_bytes)}"
123+
f" | {base_mem.total_allocations:,} | |"
124+
)
125+
if head_mem:
126+
delta = md_memory_delta(
127+
base_mem.peak_memory_bytes if base_mem else None, head_mem.peak_memory_bytes
128+
)
129+
lines.append(
130+
f"| `{head_short}` (head) | {md_bytes(head_mem.peak_memory_bytes)}"
131+
f" | {head_mem.total_allocations:,} | {delta} |"
132+
)
133+
109134
sections.append("\n".join(lines))
110135

111136
sections.append("---\n*Generated by codeflash optimization agent*")
@@ -120,16 +145,23 @@ def compare_branches(
120145
tests_root: Path,
121146
functions: Optional[dict[Path, list[FunctionToOptimize]]] = None,
122147
timeout: int = 600,
148+
memory: bool = False,
123149
) -> CompareResult:
124150
"""Compare benchmark performance between two git refs.
125151
126152
If functions is None, auto-detects changed functions from git diff.
127153
Returns a CompareResult with timing data from both refs.
128154
"""
155+
import sys
156+
129157
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
130158
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
131159
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
132160

161+
if memory and sys.platform == "win32":
162+
logger.error("--memory requires memray which is not available on Windows")
163+
return CompareResult(base_ref=base_ref, head_ref=head_ref)
164+
133165
repo = git.Repo(project_root, search_parent_directories=True)
134166
repo_root = Path(repo.working_dir)
135167

@@ -182,12 +214,17 @@ def compare_branches(
182214
head_worktree = worktree_dirs / f"compare-head-{timestamp}"
183215
base_trace_db = worktree_dirs / f"trace-base-{timestamp}.db"
184216
head_trace_db = worktree_dirs / f"trace-head-{timestamp}.db"
217+
base_memray_dir = worktree_dirs / f"memray-base-{timestamp}"
218+
head_memray_dir = worktree_dirs / f"memray-head-{timestamp}"
219+
memray_prefix = "cf-mem"
185220

186221
result = CompareResult(base_ref=base_ref, head_ref=head_ref)
187222

188223
from rich.console import Group
189224

190225
step_labels = ["Creating worktrees", f"Benchmarking base ({base_short})", f"Benchmarking head ({head_short})"]
226+
if memory:
227+
step_labels.extend([f"Memory profiling base ({base_short})", f"Memory profiling head ({head_short})"])
191228

192229
def build_steps(current_step: int) -> Group:
193230
lines: list[Text] = []
@@ -260,6 +297,18 @@ def build_panel(current_step: int) -> Panel:
260297
trace_fn=trace_benchmarks_pytest,
261298
)
262299

300+
# Steps 4-5: Memory profiling (reuses existing worktrees)
301+
if memory:
302+
from codeflash.benchmarking.trace_benchmarks import memory_benchmarks_pytest
303+
304+
live.update(build_panel(3))
305+
wt_base_benchmarks = base_worktree / benchmarks_root.relative_to(repo_root)
306+
memory_benchmarks_pytest(wt_base_benchmarks, base_worktree, base_memray_dir, memray_prefix, timeout)
307+
308+
live.update(build_panel(4))
309+
wt_head_benchmarks = head_worktree / benchmarks_root.relative_to(repo_root)
310+
memory_benchmarks_pytest(wt_head_benchmarks, head_worktree, head_memray_dir, memray_prefix, timeout)
311+
263312
# Load results
264313
if base_trace_db.exists():
265314
result.base_stats = CodeFlashBenchmarkPlugin.get_benchmark_timings(base_trace_db)
@@ -269,6 +318,14 @@ def build_panel(current_step: int) -> Panel:
269318
result.head_stats = CodeFlashBenchmarkPlugin.get_benchmark_timings(head_trace_db)
270319
result.head_function_ns = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(head_trace_db)
271320

321+
if memory:
322+
from codeflash.benchmarking.plugin.plugin import MemoryStats
323+
324+
if base_memray_dir.exists():
325+
result.base_memory = MemoryStats.parse_memray_results(base_memray_dir, memray_prefix)
326+
if head_memray_dir.exists():
327+
result.head_memory = MemoryStats.parse_memray_results(head_memray_dir, memray_prefix)
328+
272329
# Render comparison
273330
render_comparison(result)
274331

@@ -282,10 +339,16 @@ def build_panel(current_step: int) -> Panel:
282339
remove_worktree(base_worktree)
283340
remove_worktree(head_worktree)
284341
repo.git.worktree("prune")
285-
# Cleanup trace DBs
342+
# Cleanup trace DBs and memray dirs
286343
for db in [base_trace_db, head_trace_db]:
287344
if db.exists():
288345
db.unlink()
346+
if memory:
347+
import shutil
348+
349+
for memray_dir in [base_memray_dir, head_memray_dir]:
350+
if memray_dir.exists():
351+
shutil.rmtree(memray_dir)
289352

290353
return result
291354

@@ -543,6 +606,31 @@ def sort_key(fn: str, _bm_key: BenchmarkKey = bm_key) -> float:
543606

544607
console.print(t2, justify="center")
545608

609+
# Table 3: Memory (skip when delta is negligible)
610+
base_mem = result.base_memory.get(bm_key)
611+
head_mem = result.head_memory.get(bm_key)
612+
if has_meaningful_memory_change(base_mem, head_mem):
613+
console.print()
614+
t3 = Table(title="Memory (peak per test)", border_style="magenta", show_lines=True, expand=False)
615+
t3.add_column("Ref", style="bold cyan")
616+
t3.add_column("Peak Memory", justify="right")
617+
t3.add_column("Allocations", justify="right")
618+
t3.add_column("Delta", justify="right")
619+
620+
if base_mem:
621+
t3.add_row(
622+
f"{base_short} (base)", fmt_bytes(base_mem.peak_memory_bytes), f"{base_mem.total_allocations:,}", ""
623+
)
624+
if head_mem:
625+
delta = fmt_memory_delta(base_mem.peak_memory_bytes if base_mem else None, head_mem.peak_memory_bytes)
626+
t3.add_row(
627+
f"{head_short} (head)",
628+
fmt_bytes(head_mem.peak_memory_bytes),
629+
f"{head_mem.total_allocations:,}",
630+
delta,
631+
)
632+
console.print(t3, justify="center")
633+
546634
console.print()
547635

548636

@@ -641,3 +729,63 @@ def md_bar(before: Optional[float], after: Optional[float], width: int = 10) ->
641729
filled = min(filled, width)
642730
bar = "\u2588" * filled + "\u2591" * (width - filled)
643731
return f"`{bar}` {pct:+.0f}%"
732+
733+
734+
def fmt_bytes(b: Optional[int]) -> str:
    """Render a byte count as a short human-readable string.

    Args:
        b: Number of bytes, or ``None`` when no measurement is available.

    Returns:
        ``"-"`` for ``None``; otherwise the value scaled to the largest binary
        unit (GiB, MiB, KiB) it reaches, with one decimal place, or a plain
        ``"<n> B"`` count when it is below 1 KiB.
    """
    if b is None:
        return "-"
    # Walk the units from largest to smallest and stop at the first one reached.
    for shift, unit in ((30, "GiB"), (20, "MiB"), (10, "KiB")):
        if b >= 1 << shift:
            return f"{b / (1 << shift):,.1f} {unit}"
    return f"{b:,} B"
744+
745+
746+
def fmt_memory_delta(before: Optional[int], after: Optional[int]) -> str:
    """Format the relative change in memory between two measurements.

    Args:
        before: Baseline value, or ``None`` when missing.
        after: New value, or ``None`` when missing.

    Returns:
        ``"-"`` when either value is missing or the baseline is zero (no
        percentage can be computed). Otherwise the percentage change rendered
        through ``_GREEN_TPL`` for a strict decrease and ``_RED_TPL`` for
        everything else, including no change.
    """
    missing = before is None or after is None
    if missing or before == 0:
        return "-"
    change_pct = ((after - before) / before) * 100
    template = _GREEN_TPL if change_pct < 0 else _RED_TPL
    return template % change_pct
753+
754+
755+
def md_bytes(b: Optional[int]) -> str:
    """Render a byte count for the markdown report.

    Args:
        b: Number of bytes, or ``None`` when no measurement is available.

    Returns:
        ``"-"`` for ``None``; otherwise the value expressed in GiB, MiB or KiB
        with one decimal place, or as a plain byte count below 1 KiB.
    """
    if b is None:
        return "-"

    kib, mib, gib = 1 << 10, 1 << 20, 1 << 30
    if b >= gib:
        text = f"{b / gib:,.1f} GiB"
    elif b >= mib:
        text = f"{b / mib:,.1f} MiB"
    elif b >= kib:
        text = f"{b / kib:,.1f} KiB"
    else:
        text = f"{b:,} B"
    return text
765+
766+
767+
def md_memory_delta(before: Optional[int], after: Optional[int]) -> str:
    """Format the relative memory change for the markdown report.

    Args:
        before: Baseline value, or ``None`` when missing.
        after: New value, or ``None`` when missing.

    Returns:
        ``"-"`` when either value is missing or the baseline is zero.
        Otherwise a coloured circle emoji (green when the value did not grow,
        red when it did) followed by the signed, whole-number percentage.
    """
    if before is None or after is None or before == 0:
        return "-"
    change_pct = ((after - before) / before) * 100
    if change_pct <= 0:
        marker = "\U0001f7e2"
    else:
        marker = "\U0001f534"
    return f"{marker} {change_pct:+.0f}%"
773+
774+
775+
def has_meaningful_memory_change(
    base_mem: Optional[MemoryStats], head_mem: Optional[MemoryStats], threshold_pct: float = 1.0
) -> bool:
    """Return True if peak memory or allocation count changed by more than threshold_pct.

    Args:
        base_mem: Memory stats for the base ref, or ``None`` when not measured.
        head_mem: Memory stats for the head ref, or ``None`` when not measured.
        threshold_pct: Relative change, in percent, above which a difference
            counts as meaningful.

    Returns:
        ``True`` when exactly one side has data, or when either the peak
        memory or the allocation count differs by more than ``threshold_pct``.
        ``False`` when neither side has data or both peaks are zero.
    """
    if base_mem is None or head_mem is None:
        # One-sided data is still worth showing; nothing on both sides is not.
        return base_mem is not None or head_mem is not None
    if base_mem.peak_memory_bytes == 0 and head_mem.peak_memory_bytes == 0:
        return False

    def exceeds_threshold(base_value: int, head_value: int) -> bool:
        """Tell whether head_value departs from base_value by more than the threshold."""
        if base_value == 0:
            # No percentage exists against a zero baseline, but going from
            # nothing to something is a real change and must not be hidden.
            return head_value != 0
        return abs((head_value - base_value) / base_value) * 100 > threshold_pct

    return exceeds_threshold(base_mem.peak_memory_bytes, head_mem.peak_memory_bytes) or exceeds_threshold(
        base_mem.total_allocations, head_mem.total_allocations
    )

codeflash/benchmarking/plugin/plugin.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,51 @@ def from_per_iteration_times(times_ns: list[float], iterations: int) -> Benchmar
6868
)
6969

7070

71+
@dataclass
class MemoryStats:
    """Peak memory and allocation count recorded for one benchmark test."""

    # Values copied from the memray capture metadata (``peak_memory`` and
    # ``total_allocations``) of the test's ``.bin`` file.
    peak_memory_bytes: int
    total_allocations: int

    @staticmethod
    def parse_memray_results(bin_dir: Path, bin_prefix: str) -> dict:
        """Collect memory stats from the memray capture files in ``bin_dir``.

        Every file matching ``{bin_prefix}-*.bin`` is opened with memray's
        ``FileReader``. The benchmark key is derived from the file name and
        the stats are read from the capture metadata.

        Args:
            bin_dir: Directory containing the memray ``.bin`` files.
            bin_prefix: File-name prefix the capture files were written with.

        Returns:
            Mapping of ``BenchmarkKey`` to ``MemoryStats``. Files that raise
            ``OSError`` while being read are skipped.

        Raises:
            ImportError: If memray is not installed.
        """
        from codeflash.models.models import BenchmarkKey

        try:
            from memray import FileReader
        except ImportError as e:
            msg = "memray is required for --memory profiling. Install with: uv add memray pytest-memray"
            raise ImportError(msg) from e

        results: dict[BenchmarkKey, MemoryStats] = {}
        for bin_file in sorted(bin_dir.glob(f"{bin_prefix}-*.bin")):
            # pytest-memray names: {prefix}-{nodeid with :: and os.sep replaced by -}.bin
            nodeid_part = bin_file.stem[len(bin_prefix) + 1 :]  # strip "{prefix}-"
            # Node IDs look like: tests-benchmarks-test_file.py-test_func_name
            # Split on ".py-" to separate the module path from the function name.
            parts = nodeid_part.split(".py-", 1)
            if len(parts) != 2:
                # No ".py-" marker: treat the last "-" segment as the function name.
                parts = nodeid_part.rsplit("-", 1)
            module_part = parts[0].replace("-", ".")
            function_name = parts[-1]

            try:
                reader = FileReader(str(bin_file))
                try:
                    meta = reader.metadata
                    bm_key = BenchmarkKey(module_path=module_part, function_name=function_name)
                    results[bm_key] = MemoryStats(
                        peak_memory_bytes=meta.peak_memory, total_allocations=meta.total_allocations
                    )
                finally:
                    # Release the capture file even when reading its metadata fails.
                    reader.close()
            except OSError:
                continue
        return results
114+
115+
71116
class CodeFlashBenchmarkPlugin:
72117
def __init__(self) -> None:
73118
self._trace_path = None
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Subprocess entry point for memory profiling benchmarks via pytest-memray.
2+
3+
Runs pytest with --memray --native to profile peak memory per test function.
4+
The codeflash-benchmark plugin is left active (without --codeflash-trace) so it
5+
provides a no-op ``benchmark`` fixture for tests that depend on it.
6+
"""
7+
8+
import sys
9+
from pathlib import Path
10+
11+
# Positional arguments supplied by the parent process that launches this script.
benchmarks_root = sys.argv[1]
memray_bin_dir = sys.argv[2]
memray_bin_prefix = sys.argv[3]

if __name__ == "__main__":
    import pytest

    # Make sure the output directory for the memray capture files exists.
    Path(memray_bin_dir).mkdir(parents=True, exist_ok=True)

    pytest_args = [
        benchmarks_root,
        "--memray",
        "--native",
        f"--memray-bin-path={memray_bin_dir}",
        f"--memray-bin-prefix={memray_bin_prefix}",
        "--hide-memray-summary",
    ]
    # Turn off each of these plugins for the profiling run.
    for plugin_flag in ("no:benchmark", "no:codspeed", "no:cov", "no:profiling"):
        pytest_args.extend(("-p", plugin_flag))
    # -s disables output capture; the empty addopts override clears configured addopts.
    pytest_args.extend(("-s", "-o", "addopts="))

    sys.exit(pytest.main(pytest_args))

codeflash/benchmarking/trace_benchmarks.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,39 @@ def trace_benchmarks_pytest(
4646
error_section = combined_output
4747
logger.warning(f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}")
4848
logger.debug(f"Full pytest output:\n{combined_output}")
49+
50+
51+
def memory_benchmarks_pytest(
    benchmarks_root: Path, project_root: Path, memray_bin_dir: Path, memray_bin_prefix: str, timeout: int = 300
) -> None:
    """Run the memory-profiling benchmark script in a subprocess.

    Args:
        benchmarks_root: Benchmarks directory passed to the script.
        project_root: Working directory for the subprocess; also used to build its environment.
        memray_bin_dir: Directory passed to the script for the memray capture files.
        memray_bin_prefix: File-name prefix passed to the script for the capture files.
        timeout: Timeout passed through to the subprocess run arguments.

    A non-zero exit code is not raised; the relevant part of the pytest output
    is logged as a warning and the full output at debug level.
    """
    env = make_env_with_project_root(project_root)
    run_args = get_cross_platform_subprocess_run_args(
        cwd=project_root, env=env, timeout=timeout, check=False, text=True, capture_output=True
    )
    script_path = Path(__file__).parent / "pytest_new_process_memory_benchmarks.py"
    command = [SAFE_SYS_EXECUTABLE, script_path, benchmarks_root, memray_bin_dir, memray_bin_prefix]
    result = subprocess.run(command, **run_args)  # noqa: PLW1510
    if result.returncode == 0:
        return

    combined_output = result.stdout
    if result.stderr:
        combined_output = combined_output + "\n" + result.stderr if combined_output else result.stderr

    # Try the ERRORS section first, then FAILURES; fall back to the whole output.
    section_patterns = (
        ("ERROR collecting", r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)"),
        ("FAILURES", r"={3,}\s*FAILURES\s*={3,}\n([\s\S]*?)(?:={3,}|$)"),
    )
    error_section = combined_output
    for marker, error_pattern in section_patterns:
        if marker in combined_output:
            match = re.search(error_pattern, combined_output)
            if match:
                error_section = match.group(1)
            break

    logger.warning(f"Error collecting memory benchmarks - Pytest Exit code: {result.returncode}, {error_section}")
    logger.debug(f"Full pytest output:\n{combined_output}")

codeflash/cli_cmds/cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,9 @@ def _build_parser() -> ArgumentParser:
392392
)
393393
compare_parser.add_argument("--timeout", type=int, default=600, help="Benchmark timeout in seconds (default: 600)")
394394
compare_parser.add_argument("--output", "-o", type=str, help="Write markdown report to file")
395+
compare_parser.add_argument(
396+
"--memory", action="store_true", help="Profile peak memory usage per benchmark (requires memray, Linux/macOS)"
397+
)
395398
compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml")
396399

397400
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")

codeflash/cli_cmds/cmd_compare.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def run_compare(args: Namespace) -> None:
7373
tests_root=tests_root,
7474
functions=functions,
7575
timeout=args.timeout,
76+
memory=getattr(args, "memory", False),
7677
)
7778

7879
if not result.base_stats and not result.head_stats:

0 commit comments

Comments
 (0)