Skip to content

Commit 566a255

Browse files
committed
feat: Implement token budgeting and context optimization to maximize ROI from token budget using tiktoken.
1 parent f48d4ae commit 566a255

6 files changed

Lines changed: 424 additions & 58 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies = [
1212
"tree-sitter==0.21.3",
1313
"tree-sitter-languages>=1.10.0",
1414
"GitPython>=3.1.0",
15+
"tiktoken>=0.7.0",
1516
]
1617

1718
[project.scripts]

src/knowcode/cli.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,12 @@ def query(query_type: str, target: str, store: str, as_json: bool) -> None:
188188
help="Path to knowledge store file or directory",
189189
)
190190
@click.option(
191-
"--max-chars", "-m",
191+
"--max-tokens", "-m",
192192
type=int,
193-
default=8000,
194-
help="Maximum characters in context (default: 8000)",
193+
default=2000,
194+
help="Maximum tokens in context (default: 2000)",
195195
)
196-
def context(target: str, store: str, max_chars: int) -> None:
196+
def context(target: str, store: str, max_tokens: int) -> None:
197197
"""Generate context bundle for an entity.
198198
199199
TARGET: Entity ID or search pattern
@@ -204,7 +204,7 @@ def context(target: str, store: str, max_chars: int) -> None:
204204
click.echo("Error: Knowledge store not found. Run 'knowcode analyze' first.", err=True)
205205
sys.exit(1)
206206

207-
synthesizer = ContextSynthesizer(knowledge, max_chars=max_chars)
207+
synthesizer = ContextSynthesizer(knowledge, max_tokens=max_tokens)
208208

209209
# Try exact match first
210210
entity = knowledge.get_entity(target)
@@ -222,7 +222,7 @@ def context(target: str, store: str, max_chars: int) -> None:
222222
bundle = synthesizer.synthesize(entity.id)
223223
if bundle:
224224
click.echo(bundle.context_text)
225-
click.echo(f"\n--- {bundle.total_chars} chars, {len(bundle.included_entities)} entities ---", err=True)
225+
click.echo(f"\n--- {bundle.total_chars} chars, {bundle.total_tokens} tokens, {len(bundle.included_entities)} entities ---", err=True)
226226
if bundle.truncated:
227227
click.echo("(truncated)", err=True)
228228

src/knowcode/context_synthesizer.py

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from knowcode.knowledge_store import KnowledgeStore
77
from knowcode.models import Entity, EntityKind
8+
from knowcode.token_counter import TokenCounter
89

910

1011
@dataclass
@@ -15,27 +16,31 @@ class ContextBundle:
1516
context_text: str
1617
included_entities: list[str]
1718
total_chars: int
19+
total_tokens: int
1820
truncated: bool
1921

2022

2123
class ContextSynthesizer:
2224
"""Synthesizes context bundles for entities."""
2325

24-
DEFAULT_MAX_CHARS = 8000 # Rough proxy for ~2K tokens
26+
DEFAULT_MAX_TOKENS = 2000
2527

2628
def __init__(
2729
self,
2830
store: KnowledgeStore,
29-
max_chars: int = DEFAULT_MAX_CHARS,
31+
max_tokens: int = DEFAULT_MAX_TOKENS,
32+
model: str = "gpt-4",
3033
) -> None:
3134
"""Initialize context synthesizer.
3235
3336
Args:
3437
store: Knowledge store to query.
35-
max_chars: Maximum characters in context bundle.
38+
max_tokens: Maximum tokens in context bundle.
39+
model: Model name for token counting.
3640
"""
3741
self.store = store
38-
self.max_chars = max_chars
42+
self.max_tokens = max_tokens
43+
self.tokenizer = TokenCounter(model)
3944

4045
def synthesize(self, entity_id: str) -> Optional[ContextBundle]:
4146
"""Synthesize context bundle for an entity.
@@ -52,80 +57,114 @@ def synthesize(self, entity_id: str) -> Optional[ContextBundle]:
5257

5358
sections: list[str] = []
5459
included: list[str] = [entity_id]
55-
truncated = False
56-
57-
# Section 1: Entity header
58-
sections.append(self._format_entity_header(entity))
59-
60-
# Section 2: Docstring/description
60+
61+
# We build sections in priority order but display them in logical order usually.
62+
# However, for simplicity, we'll append and check budget.
63+
64+
# Priority 1: Entity Core (Header, Signature, Description)
65+
header = self._format_entity_header(entity)
66+
current_tokens = self.tokenizer.count_tokens(header)
67+
sections.append(header)
68+
69+
desc = ""
6170
if entity.docstring:
62-
sections.append(f"## Description\n\n{entity.docstring}")
63-
64-
# Section 3: Signature (for functions/methods)
71+
desc = f"## Description\n\n{entity.docstring}"
72+
73+
sig = ""
6574
if entity.signature:
66-
sections.append(f"## Signature\n\n```python\n{entity.signature}\n```")
67-
68-
# Section 4: Source code (if available and fits)
75+
sig = f"## Signature\n\n```python\n{entity.signature}\n```"
76+
77+
# Add high priority sections if they fit
78+
if desc:
79+
t = self.tokenizer.count_tokens(desc)
80+
if current_tokens + t < self.max_tokens:
81+
sections.append(desc)
82+
current_tokens += t
83+
84+
if sig:
85+
t = self.tokenizer.count_tokens(sig)
86+
if current_tokens + t < self.max_tokens:
87+
sections.append(sig)
88+
current_tokens += t
89+
90+
# Priority 2: Source Code (Huge consumer, often truncated)
6991
if entity.source_code:
70-
code_section = f"## Source Code\n\n```python\n{entity.source_code}\n```"
71-
if self._would_fit(sections, code_section):
72-
sections.append(code_section)
73-
74-
# Section 5: Parent context
92+
code_header = "## Source Code\n\n```python\n"
93+
code_footer = "\n```"
94+
overhead = self.tokenizer.count_tokens(code_header + code_footer)
95+
remaining = self.max_tokens - current_tokens - overhead
96+
97+
if remaining > 100: # Only add if we have decent space
98+
code_body = entity.source_code
99+
code_tokens = self.tokenizer.count_tokens(code_body)
100+
101+
if code_tokens > remaining:
102+
code_body = self.tokenizer.truncate(code_body, remaining) + "\n# ... (truncated)"
103+
# We technically truncated the content
104+
# But we will rely on full budget exhaustion check often
105+
106+
sections.append(f"{code_header}{code_body}{code_footer}")
107+
current_tokens += self.tokenizer.count_tokens(sections[-1])
108+
else:
109+
# Skipped source code due to budget
110+
# We consider this truncation/loss of info
111+
pass
112+
113+
# Priority 3: Parent Context
75114
parent = self.store.get_parent(entity_id)
76115
if parent:
77116
parent_section = self._format_parent_context(parent)
78-
if self._would_fit(sections, parent_section):
117+
t = self.tokenizer.count_tokens(parent_section)
118+
if current_tokens + t < self.max_tokens:
79119
sections.append(parent_section)
80120
included.append(parent.id)
81-
82-
# Section 6: Callers (who calls this?)
121+
current_tokens += t
122+
123+
# Priority 4: Relationships (Callers, Callees, Children)
124+
# We add them greedily until budget exhaust
125+
126+
# Unified list of potential sections
127+
rel_sections = []
128+
83129
callers = self.store.get_callers(entity_id)
84130
if callers:
85-
callers_section = self._format_callers(callers)
86-
if self._would_fit(sections, callers_section):
87-
sections.append(callers_section)
88-
included.extend(c.id for c in callers)
131+
rel_sections.append((self._format_callers(callers), [c.id for c in callers]))
89132

90-
# Section 7: Callees (what does this call?)
91133
callees = self.store.get_callees(entity_id)
92134
if callees:
93-
callees_section = self._format_callees(callees)
94-
if self._would_fit(sections, callees_section):
95-
sections.append(callees_section)
96-
included.extend(c.id for c in callees)
97-
98-
# Section 8: Children (for classes/modules)
135+
rel_sections.append((self._format_callees(callees), [c.id for c in callees]))
136+
99137
if entity.kind in {EntityKind.CLASS, EntityKind.MODULE, EntityKind.DOCUMENT}:
100138
children = self.store.get_children(entity_id)
101139
if children:
102-
children_section = self._format_children(children)
103-
if self._would_fit(sections, children_section):
104-
sections.append(children_section)
105-
included.extend(c.id for c in children)
140+
rel_sections.append((self._format_children(children), [c.id for c in children]))
141+
142+
is_truncated = False
143+
144+
for text, ids in rel_sections:
145+
t = self.tokenizer.count_tokens(text)
146+
if current_tokens + t < self.max_tokens:
147+
sections.append(text)
148+
included.extend(ids)
149+
current_tokens += t
150+
else:
151+
is_truncated = True
106152

107-
# Build final context
108153
context_text = "\n\n---\n\n".join(sections)
109-
110-
# Final truncation if still too long
111-
if len(context_text) > self.max_chars:
112-
context_text = context_text[: self.max_chars - 20] + "\n\n[TRUNCATED]"
113-
truncated = True
154+
155+
# Check if we skipped source code but had it
156+
if entity.source_code and "## Source Code" not in context_text:
157+
is_truncated = True
114158

115159
return ContextBundle(
116160
target_entity=entity,
117161
context_text=context_text,
118162
included_entities=included,
119163
total_chars=len(context_text),
120-
truncated=truncated,
164+
total_tokens=current_tokens,
165+
truncated=is_truncated or (current_tokens >= self.max_tokens),
121166
)
122167

123-
def _would_fit(self, current_sections: list[str], new_section: str) -> bool:
124-
"""Check if adding a section would stay within budget."""
125-
current_len = sum(len(s) for s in current_sections)
126-
new_len = current_len + len(new_section) + 10 # +10 for separators
127-
return new_len < self.max_chars
128-
129168
def _format_entity_header(self, entity: Entity) -> str:
130169
"""Format entity header."""
131170
lines = [

src/knowcode/token_counter.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Token counting utility using tiktoken."""
2+
3+
from functools import lru_cache
4+
from typing import Optional
5+
6+
import tiktoken
7+
8+
9+
class TokenCounter:
10+
"""Token counter utility."""
11+
12+
DEFAULT_MODEL = "gpt-4"
13+
14+
def __init__(self, model: str = DEFAULT_MODEL) -> None:
15+
"""Initialize token counter.
16+
17+
Args:
18+
model: Model name to use for encoding.
19+
"""
20+
self.model = model
21+
try:
22+
self.encoding = tiktoken.encoding_for_model(model)
23+
except KeyError:
24+
# Fallback to cl100k_base (used by gpt-4, gpt-3.5-turbo)
25+
self.encoding = tiktoken.get_encoding("cl100k_base")
26+
27+
def count_tokens(self, text: str) -> int:
28+
"""Count tokens in text.
29+
30+
Args:
31+
text: Text to count tokens for.
32+
33+
Returns:
34+
Number of tokens.
35+
"""
36+
if not text:
37+
return 0
38+
return len(self.encoding.encode(text))
39+
40+
def truncate(self, text: str, max_tokens: int) -> str:
41+
"""Truncate text to max_tokens.
42+
43+
Args:
44+
text: Text to truncate.
45+
max_tokens: Maximum tokens allowed.
46+
47+
Returns:
48+
Truncated text.
49+
"""
50+
if not text:
51+
return ""
52+
53+
tokens = self.encoding.encode(text)
54+
if len(tokens) <= max_tokens:
55+
return text
56+
57+
truncated_tokens = tokens[:max_tokens]
58+
return self.encoding.decode(truncated_tokens)

tests/test_token_optimization.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Tests for Token Counter and Context Synthesizer."""
2+
3+
import pytest
4+
from unittest.mock import MagicMock
5+
from knowcode.token_counter import TokenCounter
6+
from knowcode.context_synthesizer import ContextSynthesizer, ContextBundle
7+
from knowcode.models import Entity, EntityKind, Location
8+
9+
def test_token_counter():
10+
"""Test functionality of TokenCounter."""
11+
counter = TokenCounter()
12+
13+
text = "Hello world"
14+
tokens = counter.count_tokens(text)
15+
assert tokens > 0
16+
17+
truncated = counter.truncate(text, max_tokens=1)
18+
assert counter.count_tokens(truncated) == 1
19+
assert truncated != text
20+
21+
def test_context_synthesizer_budget():
22+
"""Test standard budgeting logic."""
23+
store = MagicMock()
24+
25+
# Create a mock entity with huge source code
26+
large_code = "print('hello')\n" * 1000
27+
entity = Entity(
28+
id="test::Foo",
29+
kind=EntityKind.CLASS,
30+
name="Foo",
31+
qualified_name="Foo",
32+
location=Location("test.py", 1, 1000),
33+
source_code=large_code
34+
)
35+
store.get_entity.return_value = entity
36+
store.get_parent.return_value = None
37+
store.get_callers.return_value = []
38+
store.get_callees.return_value = []
39+
store.get_children.return_value = []
40+
41+
# Low budget
42+
synthesizer = ContextSynthesizer(store, max_tokens=50)
43+
bundle = synthesizer.synthesize("test::Foo")
44+
45+
assert bundle is not None
46+
assert bundle.total_tokens <= 50
47+
assert bundle.truncated is True
48+
# The text itself might not say 'truncated' if we omitted the whole section
49+
# assert "truncated" in bundle.context_text
50+
51+
def test_context_synthesizer_priority():
52+
"""Test that header is preserved even if code is truncated."""
53+
store = MagicMock()
54+
55+
entity = Entity(
56+
id="test::Bar",
57+
kind=EntityKind.FUNCTION,
58+
name="bar",
59+
qualified_name="bar",
60+
location=Location("test.py", 1, 10),
61+
source_code="def bar():\n pass # very long code...",
62+
docstring="Checks that header is kept."
63+
)
64+
store.get_entity.return_value = entity
65+
store.get_parent.return_value = None
66+
store.get_callers.return_value = []
67+
store.get_callees.return_value = []
68+
store.get_children.return_value = []
69+
70+
synthesizer = ContextSynthesizer(store, max_tokens=100)
71+
bundle = synthesizer.synthesize("test::Bar")
72+
73+
assert bundle is not None
74+
# Ensure header info is present
75+
assert "# Function: `bar`" in bundle.context_text
76+
assert "**File**: `test.py`" in bundle.context_text

0 commit comments

Comments
 (0)