Skip to content

Latest commit

 

History

History
558 lines (433 loc) · 15.8 KB

File metadata and controls

558 lines (433 loc) · 15.8 KB

Context Management - Memory, State, and Token Budgets

The Problem

LLMs are stateless. Each request is independent:

  • No memory of previous messages
  • Token limits (4K, 32K, 128K depending on model)
  • Costs scale with context size

Your agent needs:

  • Conversation history (remember what was said)
  • Tool results (remember what was done)
  • Token management (stay within limits)
  • Optional: Long-term memory (across sessions)

Message Structure

Basic Message Format

@dataclass
class Message:
    """One entry in the conversation history exchanged with the model."""
    role: str       # Speaker: "system" | "user" | "assistant" | "tool"
    content: str    # Message text
    name: str = None           # Optional: tool name, used with the "tool" role
    tool_calls: list = None    # Optional: tool calls attached to an "assistant" message
    tool_call_id: str = None   # Optional: on a "tool" message, the id of the call it answers

Example Conversation

# Example history for one tool-using exchange, in chronological order.
messages = [
    # System prompt (optional, sets behavior)
    {
        "role": "system",
        "content": "You are a helpful coding assistant."
    },
    
    # User request
    {
        "role": "user",
        "content": "Read the file test.py"
    },
    
    # Assistant decides to call a tool: content is empty, the request
    # itself is carried in "tool_calls"
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_1",
                "function": {
                    "name": "read_file",
                    # Arguments are passed as a JSON-encoded string
                    "arguments": '{"path": "test.py"}'
                }
            }
        ]
    },
    
    # Tool result, tied to the request above through the matching
    # "tool_call_id" ("call_1")
    {
        "role": "tool",
        "content": "print('Hello world')",
        "tool_call_id": "call_1"
    },
    
    # Assistant final response
    {
        "role": "assistant",
        "content": "The file contains a simple print statement."
    }
]

Context Manager Implementation

# src/context.py
from dataclasses import dataclass, field
from typing import List, Dict, Any
import json

@dataclass
class Message:
    """A single chat message held in the conversation context.

    Attributes:
        role: Speaker of the message ("system", "user", "assistant", "tool").
        content: Message text.
        name: Optional name; emitted by ``to_dict`` only when set.
        tool_calls: Optional tool calls attached to the message.
        tool_call_id: Optional id linking a tool result to its call.
    """

    role: str
    content: str
    name: str = None
    tool_calls: List[Dict] = None
    tool_call_id: str = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to Ollama format.

        ``role`` and ``content`` are always present; the optional fields are
        included only when they hold a truthy value.
        """
        d = {"role": self.role, "content": self.content}
        if self.name:
            d["name"] = self.name
        if self.tool_calls:
            d["tool_calls"] = self.tool_calls
        if self.tool_call_id:
            d["tool_call_id"] = self.tool_call_id
        return d

    def estimate_tokens(self) -> int:
        """Rough token estimate (~4 chars per token).

        Counts the characters of ``content`` plus, when present, the JSON
        serialization of ``tool_calls``.
        """
        # content can be None (e.g. a message restored from JSON with a null
        # content); treat that as empty text instead of raising TypeError.
        total = len(self.content or "")
        if self.tool_calls:
            total += len(json.dumps(self.tool_calls))
        return total // 4


class ContextManager:
    """Manage conversation history and token budget.

    Messages are kept in insertion order in ``self.messages``. Every call to
    ``add_message`` re-checks the budget and evicts the oldest non-system
    messages until the estimated size is at most 80% of ``max_tokens``.
    """

    def __init__(self, max_tokens: int = 4096, system_prompt: str = None):
        """Create an empty context.

        Args:
            max_tokens: Token budget used by ``_enforce_budget``.
            system_prompt: Optional text stored as the first message with
                role "system".
        """
        self.max_tokens = max_tokens
        self.messages: List[Message] = []

        # Add system prompt if provided
        if system_prompt:
            self.add_message(Message(role="system", content=system_prompt))

    def add_message(self, message: Message):
        """Append a message, then prune old messages if over budget."""
        self.messages.append(message)
        self._enforce_budget()

    def add_user_message(self, content: str):
        """Convenience method for user messages."""
        self.add_message(Message(role="user", content=content))

    def add_assistant_message(self, content: str, tool_calls: List[Dict] = None):
        """Convenience method for assistant messages."""
        self.add_message(Message(
            role="assistant",
            content=content,
            tool_calls=tool_calls
        ))

    def add_tool_result(self, tool_call_id: str, content: str):
        """Convenience method for tool results."""
        self.add_message(Message(
            role="tool",
            content=content,
            tool_call_id=tool_call_id
        ))

    def get_messages(self) -> List[Dict[str, Any]]:
        """Get messages in Ollama format."""
        return [msg.to_dict() for msg in self.messages]

    def estimate_tokens(self) -> int:
        """Estimate total tokens in context (sum of per-message estimates)."""
        return sum(msg.estimate_tokens() for msg in self.messages)

    def _enforce_budget(self):
        """Evict the oldest non-system messages until within 80% of budget.

        A system message at index 0 is never evicted. If it alone exceeds the
        budget the loop stops with only that message left, so the context can
        stay over budget in that case.

        NOTE: eviction works on single messages, so an assistant message
        carrying ``tool_calls`` may be removed while its tool results remain.
        """
        budget = self.max_tokens * 0.8
        while self.estimate_tokens() > budget:
            # Index of the oldest message that may be removed: skip a pinned
            # system prompt at position 0.
            pinned = bool(self.messages) and self.messages[0].role == "system"
            first_evictable = 1 if pinned else 0
            if len(self.messages) <= first_evictable:
                break  # Nothing left to remove
            self.messages.pop(first_evictable)

    def clear(self):
        """Clear all messages except a system prompt at index 0."""
        if self.messages and self.messages[0].role == "system":
            self.messages = [self.messages[0]]
        else:
            self.messages = []

    def save(self, path: str):
        """Save the conversation to ``path`` as a JSON list of message dicts."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump([msg.to_dict() for msg in self.messages], f, indent=2)

    @classmethod
    def load(cls, path: str) -> "ContextManager":
        """Load a conversation written by ``save``.

        The instance is created with ``cls()`` (default arguments) and the
        stored messages are appended directly, so no budget pruning happens
        while loading.

        Raises:
            KeyError: If a stored message lacks "role" or "content".
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        ctx = cls()
        for msg_dict in data:
            msg = Message(
                role=msg_dict["role"],
                content=msg_dict["content"],
                name=msg_dict.get("name"),
                tool_calls=msg_dict.get("tool_calls"),
                tool_call_id=msg_dict.get("tool_call_id")
            )
            ctx.messages.append(msg)

        return ctx

Token Budget Strategies

1. Sliding Window (Simplest)

Keep last N messages, drop oldest.

class SlidingWindowContext(ContextManager):
    """Context that keeps a leading system prompt plus the last N messages.

    Args:
        max_messages: Number of most recent non-system messages to keep.
        system_prompt: Optional system prompt, forwarded to ContextManager.
    """

    def __init__(self, max_messages: int = 20, system_prompt: str = None):
        # Assigned before super().__init__(): adding the system prompt there
        # goes through add_message(), which already calls _enforce_budget().
        self.max_messages = max_messages
        super().__init__(system_prompt=system_prompt)

    def _enforce_budget(self):
        """Trim the history to the system prompt + last ``max_messages``."""
        has_system = bool(self.messages) and self.messages[0].role == "system"
        pinned = self.messages[:1] if has_system else []
        rest = self.messages[1:] if has_system else self.messages

        # Count only non-system messages so the window size is the same with
        # and without a system prompt.
        if len(rest) > self.max_messages:
            # Slice from an explicit start index: rest[-0:] would return the
            # whole list when max_messages is 0.
            start = len(rest) - max(self.max_messages, 0)
            self.messages = pinned + rest[start:]

Pros: Simple, predictable
Cons: Loses old context, may forget important info


2. Summarization

Compress old messages into a summary.

class SummarizingContext(ContextManager):
    """Context that compresses the older half of the history into a summary.

    Args:
        llm: Object whose ``generate(messages)`` is used to write summaries.
        max_tokens: Token budget; summarization starts above 80% of it.
    """

    # Prefix of generated summary messages; also used to tell a summary apart
    # from a real system prompt sitting at index 0.
    _SUMMARY_PREFIX = "Previous conversation summary:\n"

    def __init__(self, llm, max_tokens: int = 4096):
        super().__init__(max_tokens=max_tokens)
        self.llm = llm

    def _enforce_budget(self):
        """Replace the oldest half of the conversation with an LLM summary.

        A system prompt at index 0 is kept as-is rather than being folded into
        the summary. An earlier summary is treated as ordinary history and may
        be summarized again.
        """
        if self.estimate_tokens() <= self.max_tokens * 0.8:
            return

        pinned = []
        history = self.messages
        first = history[0] if history else None
        if (
            first is not None
            and first.role == "system"
            and not (first.content or "").startswith(self._SUMMARY_PREFIX)
        ):
            pinned, history = history[:1], history[1:]

        # Summarize oldest half of conversation
        split_point = len(history) // 2
        old_messages = history[:split_point]
        recent_messages = history[split_point:]

        # With fewer than two messages there is nothing to compress; calling
        # the LLM would only add a summary message and grow the context.
        if not old_messages:
            return

        # Ask LLM to summarize
        summary = self._summarize(old_messages)

        # Replace old messages with summary
        self.messages = pinned + [
            Message(role="system", content=f"Previous conversation summary:\n{summary}")
        ] + recent_messages

    def _summarize(self, messages: List[Message]) -> str:
        """Use LLM to summarize messages.

        Sends one user message containing a "role: content" line per input
        message and returns ``response["message"]["content"]``.
        """
        transcript = "".join(f"{msg.role}: {msg.content}\n" for msg in messages)
        prompt = "Summarize the following conversation in 2-3 sentences:\n\n" + transcript

        response = self.llm.generate([{"role": "user", "content": prompt}])
        return response["message"]["content"]

Pros: Preserves important info
Cons: Costs extra LLM calls, may lose details


3. Semantic Memory (Advanced)

Store embeddings of past messages, retrieve relevant ones.

from sentence_transformers import SentenceTransformer
import numpy as np

class SemanticMemoryContext(ContextManager):
    """Context with an embedding store for recalling evicted messages.

    Every message with non-empty content is embedded and appended to
    ``self.memory``; the live context is halved when it exceeds 80% of
    ``max_tokens``.
    """

    def __init__(self, max_tokens: int = 4096):
        super().__init__(max_tokens=max_tokens)
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')  # Lightweight model
        self.memory: List[tuple[Message, np.ndarray]] = []  # (message, embedding)

    def add_message(self, message: Message):
        """Add a message to the context and, if it has content, to memory."""
        # Add to current context (this may prune older messages)
        super().add_message(message)

        # Store in semantic memory
        if message.content:
            embedding = self.encoder.encode(message.content)
            self.memory.append((message, embedding))

    def retrieve_relevant(self, query: str, top_k: int = 3) -> List[Message]:
        """Retrieve the ``top_k`` stored messages most similar to ``query``.

        Similarity is the cosine between the query embedding and each stored
        embedding. Returns an empty list when memory is empty or ``top_k`` is
        not positive.
        """
        if not self.memory or top_k <= 0:
            return []

        query_embedding = self.encoder.encode(query)
        query_norm = np.linalg.norm(query_embedding)  # loop-invariant

        # Calculate cosine similarity
        similarities = []
        for msg, emb in self.memory:
            denom = query_norm * np.linalg.norm(emb)
            # A zero-norm vector has no direction; score it 0 instead of
            # dividing by zero and sorting NaNs.
            sim = float(np.dot(query_embedding, emb) / denom) if denom else 0.0
            similarities.append((sim, msg))

        # Return top K most similar
        similarities.sort(reverse=True, key=lambda x: x[0])
        return [msg for _, msg in similarities[:top_k]]

    def _enforce_budget(self):
        """Drop the older half of the context, keeping a leading system prompt."""
        if self.estimate_tokens() > self.max_tokens * 0.8:
            has_system = bool(self.messages) and self.messages[0].role == "system"
            pinned = self.messages[:1] if has_system else []
            rest = self.messages[1:] if has_system else self.messages

            # Messages with content are recorded in self.memory by
            # add_message, so the older half is only removed from the context.
            split = len(rest) // 2
            self.messages = pinned + rest[split:]

Pros: Can recall distant relevant info
Cons: Complex, requires embedding model


Conversation Persistence

Save/Load Sessions

# src/agent.py
class Agent:
    """Agent whose conversation context can be stored on and restored from disk."""

    def __init__(self, session_dir: str = "~/.agent/sessions"):
        """Resolve ``session_dir`` (expanding ``~``), create it, start a fresh context."""
        expanded = Path(session_dir).expanduser()
        self.session_dir = expanded
        expanded.mkdir(parents=True, exist_ok=True)
        self.context = ContextManager()

    def save_session(self, session_id: str):
        """Write the current conversation to ``<session_dir>/<session_id>.json``."""
        target = self.session_dir / f"{session_id}.json"
        self.context.save(str(target))
        print(f"Session saved: {session_id}")

    def load_session(self, session_id: str):
        """Replace the current context with a stored one, if its file exists."""
        source = self.session_dir / f"{session_id}.json"
        if not source.exists():
            print(f"Session not found: {session_id}")
            return
        self.context = ContextManager.load(str(source))
        print(f"Session loaded: {session_id}")

    def list_sessions(self):
        """Return the ids (file stems) of all ``*.json`` files in the session dir."""
        return [entry.stem for entry in self.session_dir.glob("*.json")]

# Usage
# Create an agent with the default session directory.
agent = Agent()
# NOTE(review): Agent.run is not defined in the snippet above — confirm where it lives.
agent.run("Create a file test.py")
# Persist the conversation under the id "coding-session-1".
agent.save_session("coding-session-1")

# Later: restore the saved conversation, then continue it.
agent.load_session("coding-session-1")
agent.run("What did we work on?")

Token Estimation

Accurate Token Counting

# Use tiktoken for accurate counting (OpenAI's tokenizer)
import tiktoken

def count_tokens(text: str, model: str = "gpt-4") -> int:
    """Count the tokens tiktoken produces for ``text``.

    Args:
        text: Text to tokenize.
        model: Model name used to pick the tiktoken encoding.

    Returns:
        Number of tokens. The count is exact only for models tiktoken knows;
        for any other model name the "cl100k_base" encoding is used and the
        result is an approximation.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken has no mapping for this model name (e.g. a local model);
        # fall back to a general-purpose encoding instead of failing.
        encoding = tiktoken.get_encoding("cl100k_base")
    # disallowed_special=() makes strings such as "<|endoftext|>" count as
    # ordinary text instead of raising ValueError.
    return len(encoding.encode(text, disallowed_special=()))

# For Ollama models, approximate:
def estimate_tokens(text: str, chars_per_token: int = 4) -> int:
    """Rough token estimate: ``len(text)`` divided by ``chars_per_token``.

    Args:
        text: Text to measure.
        chars_per_token: Assumed average characters per token (default ~4).

    Returns:
        The estimate, rounded down.

    Raises:
        ValueError: If ``chars_per_token`` is not positive.
    """
    if chars_per_token <= 0:
        raise ValueError("chars_per_token must be positive")
    return len(text) // chars_per_token

Context Window Sizes

| Model | Context Window | Notes |
|---|---|---|
| llama3.3 | 128K tokens | ~512KB text |
| qwen2.5-coder | 32K tokens | ~128KB text |
| codellama | 16K tokens | ~64KB text |

Safe usage: Stay under 80% of the context window to leave room for the model's response.


Optimization Tips

1. Compress System Prompts

Verbose:

# Verbose variant: the instructions written out as full sentences.
system_prompt = """
You are a helpful coding assistant. You should help users with programming tasks.
When asked to write code, make sure to include comments and explanations.
Always test your code before presenting it.
"""

Concise:

# Concise variant: the instructions condensed into one short string.
system_prompt = "Helpful coding assistant. Include comments, test code."

2. Truncate Tool Results

def add_tool_result(self, tool_call_id: str, content: str, max_len: int = 1000):
    """Record a tool result, cutting it to ``max_len`` characters if longer.

    When the text is cut, a trailing note states the original length.
    """
    original_length = len(content)
    if original_length > max_len:
        notice = f"\n... (truncated, {original_length} total chars)"
        content = content[:max_len] + notice

    tool_message = Message(role="tool", content=content, tool_call_id=tool_call_id)
    self.add_message(tool_message)

3. Remove Redundant Messages

def _deduplicate(self):
    """Remove repeated messages, keeping the first occurrence of each.

    Two messages count as duplicates only when role, content, tool_call_id
    and tool_calls all match. Comparing role and content alone would merge
    distinct assistant tool-call messages (their content is typically empty)
    and tool results that share the same text but answer different calls.
    """
    seen = set()
    unique = []

    for msg in self.messages:
        # repr() turns the (unhashable) tool_calls list into a hashable key part.
        key = (msg.role, msg.content, msg.tool_call_id, repr(msg.tool_calls))
        if key not in seen:
            seen.add(key)
            unique.append(msg)

    self.messages = unique

Multi-Turn Loop with Context

# src/agent.py
class Agent:
    """Interactive agent loop on top of a conversation context.

    The methods use ``self.context``, ``self.llm``, ``self.tools`` and
    ``self._execute_tool``, none of which are set up in this snippet.
    """

    def run_conversation(self):
        """Interactive multi-turn conversation.

        Commands: ``exit`` quits, ``/clear`` resets the context, ``/tokens``
        prints the estimated context size. Any other non-empty line is sent
        to the model. End-of-input or Ctrl-C at the prompt ends the loop.
        """
        print("Agent ready. Type 'exit' to quit, '/clear' to reset context.")

        while True:
            try:
                user_input = input("\nYou: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Closed stdin or Ctrl-C: leave the loop like "exit" does
                # instead of ending with a traceback.
                print()
                break

            # Nothing typed: do not add an empty message or call the model.
            if not user_input:
                continue

            command = user_input.lower()

            if command == "exit":
                break

            if command == "/clear":
                self.context.clear()
                print("Context cleared.")
                continue

            if command == "/tokens":
                tokens = self.context.estimate_tokens()
                print(f"Context size: ~{tokens} tokens")
                continue

            # Add user message to context
            self.context.add_user_message(user_input)

            # Get response (may involve tool calls)
            response = self._process()

            print(f"\nAgent: {response}")

    def _process(self) -> str:
        """Process current context, handle tools, return final response.

        Calls the LLM up to 10 times. While a response carries tool calls,
        they are recorded, executed, and their results added to the context;
        the first response without tool calls is stored and returned.

        Returns:
            The final assistant text, or "Error: Max turns exceeded" when
            every turn asked for tools.
        """
        max_turns = 10  # Prevent infinite loops

        for _ in range(max_turns):
            # Send context to LLM
            response = self.llm.generate(
                messages=self.context.get_messages(),
                tools=self.tools
            )

            # Tool calls may sit at the top level of the response or inside
            # response["message"], next to the content read below.
            message = response.get("message") or {}
            tool_calls = response.get("tool_calls") or message.get("tool_calls")

            if tool_calls:
                # Add assistant's tool call to context
                self.context.add_assistant_message("", tool_calls=tool_calls)

                # Execute tools
                for tool_call in tool_calls:
                    result = self._execute_tool(tool_call)
                    self.context.add_tool_result(tool_call["id"], result)

                # Loop back to get next LLM response
                continue

            # No tool calls, we have final response
            content = response["message"]["content"]
            self.context.add_assistant_message(content)
            return content

        return "Error: Max turns exceeded"

Testing Context Management

# tests/test_context.py
def test_token_budget_enforcement():
    """Adding far more text than the budget allows must trigger pruning."""
    ctx = ContextManager(max_tokens=100)  # Very small budget

    # Twenty long user messages, far more than the budget can hold
    long_texts = (f"Message {i}" * 10 for i in range(20))
    for text in long_texts:
        ctx.add_user_message(text)

    # Old messages must have been pruned to fit
    total = ctx.estimate_tokens()
    assert total <= 100

def test_session_persistence():
    """A saved conversation loads back with the same messages."""
    import os
    import tempfile

    ctx = ContextManager()
    ctx.add_user_message("Hello")
    ctx.add_assistant_message("Hi there")

    # Save and load inside a temporary directory so the test leaves no
    # test_session.json behind in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        session_path = os.path.join(tmp_dir, "test_session.json")
        ctx.save(session_path)
        loaded = ContextManager.load(session_path)

    assert len(loaded.messages) == 2
    assert loaded.messages[0].content == "Hello"
    assert loaded.messages[1].content == "Hi there"

Next Steps

  1. Implement basic ContextManager with sliding window
  2. Test with multi-turn conversations
  3. Add session persistence
  4. Experiment with summarization (Phase 4+)
  5. Consider semantic memory for long-term projects

Read testing-guide.md for testing strategies.