|
1 | 1 | import json |
| 2 | +import logging |
2 | 3 |
|
3 | 4 | from langchain.agents import AgentState |
4 | 5 | from langchain.agents.middleware.types import before_model |
|
7 | 8 |
|
8 | 9 | from .config import Config |
9 | 10 |
|
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
10 | 13 | INDIVIDUAL_MIN_LENGTH = 100 |
| 14 | +# Approximate characters per token across providers |
| 15 | +CHARS_PER_TOKEN = 4 |
11 | 16 |
|
12 | 17 |
|
13 | 18 | def collect_long_strings(obj): |
@@ -88,3 +93,53 @@ def truncate_tool_messages(state: AgentState, runtime: Runtime) -> AgentState: |
88 | 93 | else: |
89 | 94 | modified_messages.append(msg) |
90 | 95 | return {"messages": modified_messages} |
| 96 | + |
| 97 | + |
| 98 | +def _estimate_tokens(text): |
| 99 | + """Estimate token count using character-based approximation.""" |
| 100 | + return len(text) // CHARS_PER_TOKEN |
| 101 | + |
| 102 | + |
| 103 | +def _message_content(msg): |
| 104 | + """Extract text content from a message dict or object.""" |
| 105 | + if isinstance(msg, dict): |
| 106 | + return msg.get("content", "") |
| 107 | + return getattr(msg, "content", "") |
| 108 | + |
| 109 | + |
| 110 | +def trim_messages_to_token_limit(messages): |
| 111 | + """ |
| 112 | + Trim conversation history from the oldest messages to fit within the token |
| 113 | + budget derived from MAX_CONTENT_LENGTH. |
| 114 | + The most recent message (the new user turn) is always kept. |
| 115 | + """ |
| 116 | + max_tokens = Config.MAX_CONTENT_LENGTH // CHARS_PER_TOKEN |
| 117 | + |
| 118 | + if not messages: |
| 119 | + return messages |
| 120 | + |
| 121 | + # Estimate per-message tokens |
| 122 | + token_counts = [_estimate_tokens(_message_content(m)) for m in messages] |
| 123 | + total_tokens = sum(token_counts) |
| 124 | + |
| 125 | + if total_tokens <= max_tokens: |
| 126 | + return messages |
| 127 | + |
| 128 | + # Always keep the last message; trim from the front |
| 129 | + trimmed = list(messages) |
| 130 | + trimmed_tokens = list(token_counts) |
| 131 | + |
| 132 | + while len(trimmed) > 1 and sum(trimmed_tokens) > max_tokens: |
| 133 | + trimmed.pop(0) |
| 134 | + trimmed_tokens.pop(0) |
| 135 | + |
| 136 | + logger.info( |
| 137 | + "Trimmed conversation history from %d to %d messages " |
| 138 | + "(estimated tokens: %d -> %d, limit: %d)", |
| 139 | + len(messages), |
| 140 | + len(trimmed), |
| 141 | + total_tokens, |
| 142 | + sum(trimmed_tokens), |
| 143 | + max_tokens, |
| 144 | + ) |
| 145 | + return trimmed |
0 commit comments