Skip to content

Commit 2a12d86

Browse files
feat: Implement intelligent context compression engine
- Add `compression_level` field to `FileSnapshot` in `models.py`. - Create `CompressionStrategy` interface and implementations (`FullStrategy`, `SkeletonStrategy`, `SignatureStrategy`) in `codesage/snapshot/strategies.py`. - Refactor `SnapshotCompressor` in `codesage/snapshot/compressor.py` to use a greedy budgeting algorithm based on file risk. - Update `ContextBuilder` in `codesage/llm/context_builder.py` to utilize the new compression strategies. - Add unit tests for strategies and token budgeting logic.
1 parent c4c8044 commit 2a12d86

4 files changed

Lines changed: 430 additions & 217 deletions

File tree

codesage/llm/context_builder.py

Lines changed: 25 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import List, Dict, Any, Optional
33

44
from codesage.snapshot.models import ProjectSnapshot, FileSnapshot
5+
from codesage.snapshot.strategies import CompressionStrategyFactory
56

67
class ContextBuilder:
78
def __init__(self, model_name: str = "gpt-4", max_tokens: int = 8000, reserve_tokens: int = 1000):
@@ -22,8 +23,7 @@ def fit_to_window(self,
2223
snapshot: ProjectSnapshot) -> str:
2324
"""
2425
Builds a context string that fits within the token window.
25-
Prioritizes primary files (full content), then reference files (summaries/interfaces),
26-
then truncates if necessary.
26+
Uses the compression_level specified in FileSnapshot to determine content.
2727
"""
2828
available_tokens = self.max_tokens - self.reserve_tokens
2929

@@ -42,49 +42,39 @@ def fit_to_window(self,
4242
context_parts.append(project_context)
4343
current_tokens += tokens
4444

45-
# 2. Add Primary Files
46-
for file in primary_files:
45+
# Combine primary and reference files for processing
46+
# Note: In the new logic, the SnapshotCompressor should have already assigned appropriate levels
47+
# based on global budget. However, ContextBuilder might receive raw snapshots.
48+
# Here we assume we respect the file.compression_level if set.
49+
50+
all_files = primary_files + reference_files
51+
52+
for file in all_files:
4753
content = self._read_file(file.path)
4854
if not content: continue
4955

50-
file_block = f"<file path=\"{file.path}\">\n{content}\n</file>\n"
56+
# Apply compression strategy
57+
strategy = CompressionStrategyFactory.get_strategy(getattr(file, "compression_level", "full"))
58+
processed_content = strategy.compress(content, file.path, file.language)
59+
60+
# Decorate
61+
file_block = f"<file path=\"{file.path}\">\n{processed_content}\n</file>\n"
5162
tokens = self.count_tokens(file_block)
5263

5364
if current_tokens + tokens <= available_tokens:
5465
context_parts.append(file_block)
5566
current_tokens += tokens
5667
else:
57-
# Compression needed
58-
# We try to keep imports and signatures
59-
compressed = self._compress_file(file, content)
60-
tokens = self.count_tokens(compressed)
61-
if current_tokens + tokens <= available_tokens:
62-
context_parts.append(compressed)
63-
current_tokens += tokens
68+
# If even the compressed content doesn't fit, we might need to truncate
69+
# Or stop adding files.
70+
remaining = available_tokens - current_tokens
71+
if remaining > 50:
72+
truncated = processed_content[:(remaining * 3)] + "\n...(truncated due to context limit)"
73+
context_parts.append(f"<file path=\"{file.path}\">\n{truncated}\n</file>\n")
74+
current_tokens += remaining # Approximate
75+
break
6476
else:
65-
# Even compressed is too large, hard truncate
66-
remaining = available_tokens - current_tokens
67-
if remaining > 20: # Ensure at least some chars
68-
chars_limit = remaining * 4
69-
if chars_limit > len(compressed):
70-
chars_limit = len(compressed)
71-
72-
truncated = compressed[:chars_limit] + "\n...(truncated)"
73-
context_parts.append(truncated)
74-
current_tokens += remaining # Stop here
75-
break
76-
else:
77-
break # No space even for truncated
78-
79-
# 3. Add Reference Files (Summaries) if space permits
80-
for file in reference_files:
81-
if current_tokens >= available_tokens: break
82-
83-
summary = self._summarize_file(file)
84-
tokens = self.count_tokens(summary)
85-
if current_tokens + tokens <= available_tokens:
86-
context_parts.append(summary)
87-
current_tokens += tokens
77+
break
8878

8979
return "\n".join(context_parts)
9080

@@ -94,76 +84,3 @@ def _read_file(self, path: str) -> str:
9484
return f.read()
9585
except Exception:
9686
return ""
97-
98-
def _compress_file(self, file_snapshot: FileSnapshot, content: str) -> str:
99-
"""
100-
Retains imports, structs/classes/interfaces, and function signatures.
101-
Removes function bodies.
102-
"""
103-
if not file_snapshot.symbols:
104-
# Fallback: keep first 50 lines
105-
lines = content.splitlines()
106-
return f"<file path=\"{file_snapshot.path}\" compressed=\"true\">\n" + "\n".join(lines[:50]) + "\n... (bodies omitted)\n</file>\n"
107-
108-
lines = content.splitlines()
109-
110-
# Intervals to exclude (function bodies)
111-
exclude_intervals = []
112-
113-
funcs = file_snapshot.symbols.get("functions", [])
114-
115-
for f in funcs:
116-
start = f.get("start_line", 0)
117-
end = f.get("end_line", 0)
118-
if end > start:
119-
# To preserve closing brace if it is on end_line, we exclude up to end_line - 1?
120-
# It depends on where end_line points. Tree-sitter end_point is row/col.
121-
# If end_line is the line index (0-based) where function ends.
122-
# Usually closing brace is on end_line.
123-
124-
# Check if end_line contains ONLY brace.
125-
# If we exclude start+1 to end-1, we keep start and end line.
126-
127-
exclude_start = start + 1
128-
exclude_end = end - 1
129-
130-
if exclude_end >= exclude_start:
131-
exclude_intervals.append((exclude_start, exclude_end)) # inclusive
132-
133-
# Sort intervals
134-
exclude_intervals.sort()
135-
136-
compressed_lines = []
137-
skipping = False
138-
139-
for i, line in enumerate(lines):
140-
is_excluded = False
141-
for start_idx, end_idx in exclude_intervals:
142-
if start_idx <= i <= end_idx: # Excluding body
143-
is_excluded = True
144-
break
145-
146-
if is_excluded:
147-
if not skipping:
148-
compressed_lines.append(" ... (body omitted)")
149-
skipping = True
150-
else:
151-
compressed_lines.append(line)
152-
skipping = False
153-
154-
return f"<file path=\"{file_snapshot.path}\" compressed=\"true\">\n" + "\n".join(compressed_lines) + "\n</file>\n"
155-
156-
def _summarize_file(self, file_snapshot: FileSnapshot) -> str:
157-
lines = [f"File: {file_snapshot.path}"]
158-
if file_snapshot.symbols:
159-
if "functions" in file_snapshot.symbols:
160-
funcs = file_snapshot.symbols["functions"]
161-
lines.append("Functions: " + ", ".join([f['name'] for f in funcs]))
162-
if "structs" in file_snapshot.symbols:
163-
structs = file_snapshot.symbols["structs"]
164-
lines.append("Structs: " + ", ".join([s['name'] for s in structs]))
165-
if "external_commands" in file_snapshot.symbols:
166-
cmds = file_snapshot.symbols["external_commands"]
167-
lines.append("External Commands: " + ", ".join(cmds))
168-
169-
return "\n".join(lines) + "\n"

codesage/snapshot/compressor.py

Lines changed: 100 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,112 @@
1-
import fnmatch
2-
import json
3-
import hashlib
4-
from typing import Any, Dict, List
5-
1+
from typing import Any, Dict, List, Optional
2+
import os
3+
import tiktoken
64
from codesage.snapshot.models import ProjectSnapshot, FileSnapshot
7-
from codesage.analyzers.ast_models import ASTNode
8-
5+
from codesage.snapshot.strategies import CompressionStrategyFactory, FullStrategy
96

107
class SnapshotCompressor:
11-
"""Compresses a ProjectSnapshot to reduce its size."""
8+
"""Compresses a ProjectSnapshot to reduce its token usage for LLM context."""
129

13-
def __init__(self, config: Dict[str, Any]):
14-
self.config = config.get("compression", {})
15-
self.exclude_patterns = self.config.get("exclude_patterns", [])
16-
self.trimming_threshold = self.config.get("trimming_threshold", 1000)
10+
def __init__(self, config: Dict[str, Any] = None):
11+
self.config = config or {}
12+
# Default budget if not specified
13+
self.token_budget = self.config.get("token_budget", 8000)
14+
self.model_name = self.config.get("model_name", "gpt-4")
1715

18-
def compress(self, snapshot: ProjectSnapshot) -> ProjectSnapshot:
19-
"""
20-
Compresses the snapshot by applying various techniques.
16+
try:
17+
self.encoding = tiktoken.encoding_for_model(self.model_name)
18+
except KeyError:
19+
self.encoding = tiktoken.get_encoding("cl100k_base")
20+
21+
def compress_project(self, snapshot: ProjectSnapshot, project_root: str) -> ProjectSnapshot:
2122
"""
22-
compressed_snapshot = snapshot.model_copy(deep=True)
23+
Compresses the snapshot by assigning compression levels to files based on risk and budget.
2324
24-
if self.exclude_patterns:
25-
compressed_snapshot.files = self._exclude_files(
26-
compressed_snapshot.files, self.exclude_patterns
27-
)
25+
Args:
26+
snapshot: The project snapshot.
27+
project_root: The root directory of the project (to read file contents).
2828
29-
compressed_snapshot.files = self._deduplicate_ast_nodes(compressed_snapshot.files)
30-
compressed_snapshot.files = self._trim_large_asts(
31-
compressed_snapshot.files, self.trimming_threshold
29+
Returns:
30+
The modified project snapshot with updated compression_level fields.
31+
"""
32+
# 1. Sort files by Risk Score (Desc)
33+
# Assuming risk.risk_score exists. If not, default to 0.
34+
sorted_files = sorted(
35+
snapshot.files,
36+
key=lambda f: f.risk.risk_score if f.risk else 0.0,
37+
reverse=True
3238
)
3339

34-
return compressed_snapshot
35-
36-
def _exclude_files(
37-
self, files: List[FileSnapshot], patterns: List[str]
38-
) -> List[FileSnapshot]:
39-
"""Filters out files that match the exclude patterns."""
40-
return [
41-
file
42-
for file in files
43-
if not any(fnmatch.fnmatch(file.path, pattern) for pattern in patterns)
44-
]
45-
46-
def _deduplicate_ast_nodes(
47-
self, files: List[FileSnapshot]
48-
) -> List[FileSnapshot]:
49-
"""
50-
Deduplicates AST nodes by replacing identical subtrees with a reference.
51-
This is a simplified implementation. A real one would need a more robust
52-
hashing and reference mechanism.
53-
"""
54-
node_cache = {}
55-
for file in files:
56-
if file.ast_summary: # Assuming ast_summary holds the AST
57-
self._traverse_and_deduplicate(file.ast_summary, node_cache)
58-
return files
59-
60-
def _traverse_and_deduplicate(self, node: ASTNode, cache: Dict[str, ASTNode]):
61-
"""Recursively traverses the AST and deduplicates nodes."""
62-
if not isinstance(node, ASTNode):
63-
return
64-
65-
node_hash = self._hash_node(node)
66-
if node_hash in cache:
67-
# Replace with a reference to the cached node
68-
# This is a conceptual implementation. In practice, you might
69-
# store the canonical node in a separate structure and use IDs.
70-
node = cache[node_hash]
71-
return
72-
73-
cache[node_hash] = node
74-
for i, child in enumerate(node.children):
75-
node.children[i] = self._traverse_and_deduplicate(child, cache)
76-
77-
def _hash_node(self, node: ASTNode) -> str:
78-
"""Creates a stable hash for an AST node."""
79-
# A simple hash based on type and value. A real implementation
80-
# should be more robust, considering children as well.
81-
hasher = hashlib.md5()
82-
hasher.update(node.node_type.encode())
83-
if node.value:
84-
hasher.update(str(node.value).encode())
85-
return hasher.hexdigest()
86-
87-
def _trim_large_asts(
88-
self, files: List[FileSnapshot], threshold: int
89-
) -> List[FileSnapshot]:
90-
"""Trims the AST of very large files to save space."""
91-
for file in files:
92-
if file.lines > threshold and file.ast_summary:
93-
self._traverse_and_trim(file.ast_summary)
94-
return files
95-
96-
def _traverse_and_trim(self, node: ASTNode):
40+
# 2. Initial pass: Estimate costs for different levels
41+
file_costs = {} # {file_path: {level: token_count}}
42+
43+
# We need to read files.
44+
for file in sorted_files:
45+
file_path = os.path.join(project_root, file.path)
46+
try:
47+
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
48+
content = f.read()
49+
except Exception:
50+
content = "" # Should we handle missing files?
51+
52+
# Calculate costs for all strategies
53+
costs = {}
54+
for level in ["full", "skeleton", "signature"]:
55+
strategy = CompressionStrategyFactory.get_strategy(level)
56+
compressed_content = strategy.compress(content, file.path, file.language)
57+
costs[level] = len(self.encoding.encode(compressed_content))
58+
59+
file_costs[file.path] = costs
60+
61+
# 3. Budget allocation loop
62+
# Start with minimal cost (all signature)
63+
current_total_tokens = sum(file_costs[f.path]["signature"] for f in sorted_files)
64+
65+
# Assign initial level
66+
for file in snapshot.files:
67+
file.compression_level = "signature"
68+
69+
# If we have budget left, upgrade files based on risk
70+
# We iterate sorted_files (highest risk first)
71+
72+
# Upgrades: signature -> skeleton -> full
73+
74+
# Pass 1: Upgrade to Skeleton
75+
for file in sorted_files:
76+
costs = file_costs[file.path]
77+
cost_increase = costs["skeleton"] - costs["signature"]
78+
79+
if current_total_tokens + cost_increase <= self.token_budget:
80+
file.compression_level = "skeleton"
81+
current_total_tokens += cost_increase
82+
else:
83+
# If we can't upgrade this file, maybe we can upgrade smaller files?
84+
# Greedy approach says: prioritize high risk.
85+
# If high risk file is huge, it might consume all budget.
86+
# Standard Knapsack problem.
87+
# For now, simple greedy: iterate by risk. If fits, upgrade.
88+
pass
89+
90+
# Pass 2: Upgrade to Full
91+
for file in sorted_files:
92+
if file.compression_level == "skeleton":
93+
costs = file_costs[file.path]
94+
cost_increase = costs["full"] - costs["skeleton"]
95+
96+
if current_total_tokens + cost_increase <= self.token_budget:
97+
file.compression_level = "full"
98+
current_total_tokens += cost_increase
99+
100+
return snapshot
101+
102+
def select_strategy(self, file_risk: float, is_focal_file: bool) -> str:
97103
"""
98-
Recursively traverses the AST and removes non-essential nodes,
99-
like the bodies of functions.
104+
Determines the ideal strategy based on risk, ignoring budget.
105+
Used as a heuristic or upper bound.
100106
"""
101-
if not isinstance(node, ASTNode):
102-
return
103-
104-
# For function nodes, keep the signature but remove the body
105-
if node.node_type == "function":
106-
node.children = [] # A simple way to trim the function body
107-
return
108-
109-
for child in node.children:
110-
self._traverse_and_trim(child)
111-
112-
113-
def calculate_compression_ratio(
114-
self, original: ProjectSnapshot, compressed: ProjectSnapshot
115-
) -> float:
116-
"""Calculates the compression ratio."""
117-
original_size = len(json.dumps(original.model_dump(mode='json')))
118-
compressed_size = len(json.dumps(compressed.model_dump(mode='json')))
119-
if original_size == 0:
120-
return 0.0
121-
return (original_size - compressed_size) / original_size
107+
if is_focal_file or file_risk >= 0.7: # High risk
108+
return "full"
109+
elif file_risk >= 0.3: # Medium risk
110+
return "skeleton"
111+
else:
112+
return "signature"

codesage/snapshot/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class FileSnapshot(BaseModel):
101101
symbols: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of symbols defined in the file.")
102102
risk: Optional[FileRisk] = Field(None, description="Risk assessment for the file.")
103103
issues: List[Issue] = Field(default_factory=list, description="A list of issues identified in the file.")
104+
compression_level: Literal["full", "skeleton", "signature"] = Field("full", description="The compression level applied to the file.")
104105

105106
# Old fields for compatibility
106107
hash: Optional[str] = Field(None, description="The SHA256 hash of the file content.")

0 commit comments

Comments
 (0)