Skip to content

Commit 6371ca5

Browse files
addressing all Required items by claude code review
- Path Traversal Protection - Replace Bare Exception Handlers - Test Discovery Logging
1 parent db501da commit 6371ca5

7 files changed

Lines changed: 109 additions & 17 deletions

File tree

codeflash/benchmarking/trace_benchmarks.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from codeflash.cli_cmds.console import logger
1111
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
1212

13+
# Constant for subprocess drain timeout after killing
14+
SUBPROCESS_DRAIN_TIMEOUT_SECONDS = 5
15+
1316

1417
def trace_benchmarks_pytest(
1518
benchmarks_root: Path, tests_root: Path, project_root: Path, trace_file: Path, timeout: int = 300
@@ -48,7 +51,14 @@ def trace_benchmarks_pytest(
4851
except subprocess.TimeoutExpired:
4952
with contextlib.suppress(OSError):
5053
process.kill()
51-
stdout_content, stderr_content = process.communicate(timeout=5)
54+
try:
55+
# Try to drain remaining output after killing
56+
stdout_content, stderr_content = process.communicate(timeout=SUBPROCESS_DRAIN_TIMEOUT_SECONDS)
57+
except subprocess.TimeoutExpired:
58+
# Last resort: terminate and get partial output
59+
with contextlib.suppress(OSError):
60+
process.terminate()
61+
stdout_content, stderr_content = "", "Process killed after timeout"
5262
raise subprocess.TimeoutExpired(cmd_list, timeout, output=stdout_content, stderr=stderr_content) from None
5363
result = subprocess.CompletedProcess(cmd_list, returncode, stdout_content, stderr_content)
5464
else:

codeflash/code_utils/code_replacer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def replace_function_definitions_in_module(
447447

448448
new_code: str = replace_functions_and_add_imports(
449449
# adding the global assignments before replacing the code, not after
450-
# becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
450+
# because of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
451451
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
452452
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
453453
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,

codeflash/code_utils/git_worktree_utils.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
worktree_dirs = codeflash_cache_dir / "worktrees"
2222
patches_dir = codeflash_cache_dir / "patches"
2323

24+
# Constants for Windows retry logic
25+
MAX_WINDOWS_RETRIES = 3
26+
INITIAL_RETRY_DELAY_SECONDS = 0.5
27+
2428

2529
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
2630
repository = git.Repo(worktree_dir, search_parent_directories=True)
@@ -113,14 +117,17 @@ def remove_worktree(worktree_dir: Path) -> None:
113117
return
114118

115119
is_windows = sys.platform == "win32"
116-
max_retries = 3 if is_windows else 1
117-
retry_delay = 0.5 # Start with 500ms delay
120+
max_retries = MAX_WINDOWS_RETRIES if is_windows else 1
121+
retry_delay = INITIAL_RETRY_DELAY_SECONDS
118122

119123
# Try to get the repository and git root for worktree removal
120124
try:
121125
repository = git.Repo(worktree_dir, search_parent_directories=True)
122-
except Exception:
126+
except (git.exc.InvalidGitRepositoryError, OSError, PermissionError) as e:
123127
# If we can't access the repository, try manual cleanup
128+
# Log at debug level since this is expected in some edge cases
129+
from codeflash.cli_cmds.console import logger
130+
logger.debug(f"Could not access git repository at {worktree_dir}: {e}. Attempting manual cleanup.")
124131
_manual_cleanup_worktree_directory(worktree_dir)
125132
return
126133

@@ -140,17 +147,20 @@ def remove_worktree(worktree_dir: Path) -> None:
140147
else:
141148
# Last attempt failed or non-permission error
142149
break
143-
except Exception:
150+
except (OSError, PermissionError) as e:
151+
# Log unexpected errors for debugging
152+
from codeflash.cli_cmds.console import logger
153+
logger.debug(f"Worktree removal attempt {attempt + 1} failed with unexpected error: {e}")
144154
break
145155

146156
# Fallback: Try to remove worktree entry from git, then manually delete directory
147-
with contextlib.suppress(Exception):
157+
with contextlib.suppress(git.exc.GitCommandError, OSError, PermissionError):
148158
# Try to prune the worktree entry from git (this doesn't delete the directory)
149159
# Use git worktree prune to remove stale entries
150160
repository.git.worktree("prune")
151161

152162
# Manually remove the directory (always attempt, even if prune failed)
153-
with contextlib.suppress(Exception):
163+
with contextlib.suppress(OSError, PermissionError):
154164
_manual_cleanup_worktree_directory(worktree_dir)
155165

156166

@@ -178,8 +188,8 @@ def _manual_cleanup_worktree_directory(worktree_dir: Path) -> None:
178188

179189
# Attempt removal with retries on Windows
180190
is_windows = sys.platform == "win32"
181-
max_retries = 3 if is_windows else 1
182-
retry_delay = 0.5
191+
max_retries = MAX_WINDOWS_RETRIES if is_windows else 1
192+
retry_delay = INITIAL_RETRY_DELAY_SECONDS
183193

184194
for attempt in range(max_retries):
185195
attempt_num = attempt + 1

codeflash/discovery/discover_unit_tests.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,8 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
331331
# Be conservative except when an alias is used (which requires exact method matching)
332332
for target_func in fnames:
333333
if "." in target_func:
334+
# Split to extract class name; method name is intentionally discarded (leading underscore)
335+
# as we only need to check if the imported class matches the target function's class
334336
class_name, _method_name = target_func.split(".", 1)
335337
if aname == class_name and not alias.asname:
336338
self.found_any_target_function = True
@@ -604,23 +606,38 @@ def discover_tests_pytest(
604606
check=False,
605607
**run_kwargs,
606608
)
607-
except subprocess.TimeoutExpired:
609+
except subprocess.TimeoutExpired as e:
610+
logger.error(
611+
f"Test discovery subprocess timed out after {run_kwargs.get('timeout', 600)} seconds. "
612+
f"Command: {discovery_script}"
613+
)
608614
result = subprocess.CompletedProcess(args=[], returncode=-1, stdout="", stderr="Timeout")
609-
except Exception as e:
615+
except (OSError, subprocess.SubprocessError, ValueError) as e:
616+
logger.error(
617+
f"Test discovery subprocess failed with error: {e}. "
618+
f"Command: {discovery_script}, "
619+
f"Project root: {project_root}, Tests root: {tests_root}"
620+
)
610621
result = subprocess.CompletedProcess(args=[], returncode=-1, stdout="", stderr=str(e))
611622

612623
try:
613624
# Check if pickle file exists before trying to read it
614625
if not tmp_pickle_path.exists():
615626
tests, pytest_rootdir = [], None
616-
logger.warning(
627+
logger.error(
617628
f"Test discovery pickle file not found. "
618-
f"Subprocess return code: {result.returncode}, stdout: {result.stdout}, stderr: {result.stderr}"
629+
f"Subprocess return code: {result.returncode}, stdout: {result.stdout[:500]}, stderr: {result.stderr[:500]}"
619630
)
620631
exitcode = result.returncode if result.returncode != 0 else -1
621632
else:
622633
with tmp_pickle_path.open(mode="rb") as f:
623634
exitcode, tests, pytest_rootdir = pickle.load(f)
635+
# Log error if subprocess failed even though pickle file exists
636+
if exitcode != 0:
637+
logger.error(
638+
f"Test discovery subprocess returned non-zero exit code: {exitcode}. "
639+
f"Subprocess return code: {result.returncode}, stdout: {result.stdout[:500]}, stderr: {result.stderr[:500]}"
640+
)
624641
except Exception as e:
625642
tests, pytest_rootdir = [], None
626643
logger.exception(

codeflash/discovery/functions_to_optimize.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,12 +665,54 @@ def filter_functions(
665665
blocklist_funcs_removed_count: int = 0
666666
previous_checkpoint_functions_removed_count: int = 0
667667

668+
def _validate_path_no_traversal(path: Path | str) -> bool:
669+
"""Validate that a path does not contain path traversal components.
670+
671+
This prevents path traversal attacks by rejecting paths with '..' components.
672+
Paths passed to this function should be from trusted sources (git operations,
673+
file system discovery), but we validate defensively.
674+
675+
Args:
676+
path: Path to validate
677+
678+
Returns:
679+
True if path is safe (no traversal components), False otherwise
680+
"""
681+
path_str = str(path)
682+
# Check for path traversal attempts
683+
if ".." in path_str:
684+
return False
685+
# Check for absolute paths that might escape (additional safety check)
686+
# Note: We allow absolute paths as they're needed for worktree paths
687+
return True
688+
668689
def _resolve_path(path: Path | str) -> Path:
669690
# Use strict=False so we don't fail on paths that don't exist yet (e.g. worktree paths)
691+
# SECURITY: Validate path before resolution to prevent traversal attacks
692+
if not _validate_path_no_traversal(path):
693+
raise ValueError(f"Path contains traversal components: {path}")
670694
return Path(path).resolve(strict=False)
671695

672696
def _resolve_path_consistent(path: Path | str) -> Path:
673-
"""Resolve path consistently: use strict resolution if path exists, otherwise non-strict."""
697+
"""Resolve path consistently: use strict resolution if path exists, otherwise non-strict.
698+
699+
SECURITY: This function validates paths to prevent traversal attacks before resolution.
700+
Paths should come from trusted sources (git operations, file system discovery),
701+
but we validate defensively.
702+
703+
Args:
704+
path: Path to resolve (from trusted sources like git diff or file discovery)
705+
706+
Returns:
707+
Resolved absolute Path
708+
709+
Raises:
710+
ValueError: If path contains traversal components
711+
"""
712+
# SECURITY: Validate path before any resolution to prevent traversal attacks
713+
if not _validate_path_no_traversal(path):
714+
raise ValueError(f"Path contains traversal components: {path}")
715+
674716
path_obj = Path(path)
675717
if path_obj.exists():
676718
try:
@@ -691,6 +733,10 @@ def _resolve_path_consistent(path: Path | str) -> Path:
691733
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
692734
for file_path_path, functions in modified_functions.items():
693735
_functions = functions
736+
# SECURITY: Validate file path before processing to prevent traversal attacks
737+
if not _validate_path_no_traversal(file_path_path):
738+
logger.warning(f"Skipping file with traversal components: {file_path_path}")
739+
continue
694740
# Resolve file path to absolute path
695741
# Convert to Path if it's a string (e.g., from get_functions_within_git_diff)
696742
file_path_obj = Path(file_path_path)

codeflash/lsp/lsp_message.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
json_primitive_types = (str, float, int, bool)
1212
max_code_lines_before_collapse = 45
1313

14-
# \\u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
14+
# \\u241F is the message delimiter because it can be more than one message sent over the same message, so we need something to separate each message
1515
message_delimiter = "\\u241F"
1616

1717

codeflash/verification/test_runner.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
BEHAVIORAL_BLOCKLISTED_PLUGINS = ["benchmark", "codspeed", "xdist", "sugar"]
2020
BENCHMARKING_BLOCKLISTED_PLUGINS = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
2121

22+
# Constant for subprocess drain timeout after killing
23+
SUBPROCESS_DRAIN_TIMEOUT_SECONDS = 5
24+
2225

2326
def execute_test_subprocess(
2427
cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600
@@ -74,7 +77,13 @@ def execute_test_subprocess(
7477
process.kill()
7578

7679
# Drain remaining output after killing
77-
stdout_content, stderr_content = process.communicate(timeout=5)
80+
try:
81+
stdout_content, stderr_content = process.communicate(timeout=SUBPROCESS_DRAIN_TIMEOUT_SECONDS)
82+
except subprocess.TimeoutExpired:
83+
# Last resort: terminate and get partial output
84+
with contextlib.suppress(OSError):
85+
process.terminate()
86+
stdout_content, stderr_content = "", "Process killed after timeout"
7887
raise subprocess.TimeoutExpired(
7988
cmd_list, timeout, output=stdout_content, stderr=stderr_content
8089
) from None

0 commit comments

Comments
 (0)