|
1 | | -import fnmatch |
2 | | -import json |
3 | | -import hashlib |
4 | | -from typing import Any, Dict, List |
5 | | - |
| 1 | +from typing import Any, Dict, List, Optional |
| 2 | +import os |
| 3 | +import tiktoken |
6 | 4 | from codesage.snapshot.models import ProjectSnapshot, FileSnapshot |
7 | | -from codesage.analyzers.ast_models import ASTNode |
8 | | - |
| 5 | +from codesage.snapshot.strategies import CompressionStrategyFactory, FullStrategy |
9 | 6 |
|
class SnapshotCompressor:
    """Compresses a ProjectSnapshot to reduce its token usage for LLM context.

    Every file in the snapshot is tagged with one of three compression
    levels -- "signature", "skeleton" or "full". All files start at
    "signature"; files are then upgraded greedily, highest risk score first,
    for as long as the running token total stays within ``token_budget``.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Read the budget settings and pick a tokenizer.

        Args:
            config: Optional settings mapping. Recognised keys are
                ``"token_budget"`` (default 8000) and ``"model_name"``
                (default ``"gpt-4"``). ``None`` is treated as an empty dict.
        """
        self.config = config or {}
        # Default budget if not specified
        self.token_budget = self.config.get("token_budget", 8000)
        self.model_name = self.config.get("model_name", "gpt-4")

        try:
            self.encoding = tiktoken.encoding_for_model(self.model_name)
        except KeyError:
            # No encoding is registered for this model name; fall back to a
            # fixed encoding so that token counting still works.
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def compress_project(
        self, snapshot: "ProjectSnapshot", project_root: str
    ) -> "ProjectSnapshot":
        """
        Compresses the snapshot by assigning compression levels to files based on risk and budget.

        The snapshot is modified in place: only the ``compression_level``
        attribute of each entry in ``snapshot.files`` is written.

        Args:
            snapshot: The project snapshot.
            project_root: The root directory of the project (to read file contents).

        Returns:
            The same snapshot object, with updated compression_level fields.
        """
        # 1. Sort files by risk score (descending); a file without risk data
        #    is ranked as 0.0.
        ranked_files = sorted(
            snapshot.files,
            key=lambda f: f.risk.risk_score if f.risk else 0.0,
            reverse=True,
        )

        # 2. Estimate the token cost of every file at every level:
        #    {file_path: {level: token_count}}
        file_costs = {
            file.path: self._estimate_costs(file, project_root)
            for file in ranked_files
        }

        # 3. Budget allocation. The baseline is "everything at signature
        #    level"; it is not reduced further even if it already exceeds
        #    the budget.
        used_tokens = sum(file_costs[f.path]["signature"] for f in ranked_files)
        for file in snapshot.files:
            file.compression_level = "signature"

        # Simple greedy allocation (not an optimal knapsack): walk the files
        # by risk and upgrade each one that still fits. A large high-risk
        # file that does not fit is skipped, and lower-risk files may still
        # be upgraded after it.
        used_tokens = self._upgrade_files(
            ranked_files, file_costs, used_tokens, "signature", "skeleton"
        )
        # Only files that reached "skeleton" are candidates for "full".
        self._upgrade_files(
            ranked_files, file_costs, used_tokens, "skeleton", "full"
        )

        return snapshot

    def _estimate_costs(
        self, file: "FileSnapshot", project_root: str
    ) -> Dict[str, int]:
        """Return the token count of ``file`` at each compression level."""
        file_path = os.path.join(project_root, file.path)
        try:
            with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()
        except (OSError, ValueError):
            # Best effort: a missing or unreadable file (or an invalid path)
            # is costed as if it were empty instead of aborting the run.
            content = ""

        costs = {}
        for level in ("full", "skeleton", "signature"):
            strategy = CompressionStrategyFactory.get_strategy(level)
            compressed_content = strategy.compress(content, file.path, file.language)
            costs[level] = len(self.encoding.encode(compressed_content))
        return costs

    def _upgrade_files(
        self,
        files: List["FileSnapshot"],
        file_costs: Dict[str, Dict[str, int]],
        used_tokens: int,
        from_level: str,
        to_level: str,
    ) -> int:
        """Upgrade files from ``from_level`` to ``to_level`` while the budget allows.

        Files are visited in the given order. Returns the running token
        total after all upgrades that fit have been applied.
        """
        for file in files:
            if file.compression_level != from_level:
                continue
            costs = file_costs[file.path]
            cost_increase = costs[to_level] - costs[from_level]
            if used_tokens + cost_increase <= self.token_budget:
                file.compression_level = to_level
                used_tokens += cost_increase
        return used_tokens

    def select_strategy(self, file_risk: float, is_focal_file: bool) -> str:
        """
        Determines the ideal strategy based on risk, ignoring budget.
        Used as a heuristic or upper bound.

        Args:
            file_risk: Risk score of the file.
            is_focal_file: Whether the file is the focal file; a focal file
                always gets "full" regardless of its risk score.

        Returns:
            "full" for a focal file or a risk of at least 0.7, "skeleton"
            for a risk of at least 0.3, and "signature" otherwise.
        """
        if is_focal_file or file_risk >= 0.7:  # High risk
            return "full"
        if file_risk >= 0.3:  # Medium risk
            return "skeleton"
        return "signature"
0 commit comments