# lm.py -- LiteLLM-backed language-model wrapper (batched completions, response caching, usage accounting).
import hashlib
import logging
import warnings
from typing import Any
import litellm
import numpy as np
from litellm import batch_completion, completion_cost
from litellm.types.utils import ChatCompletionTokenLogprob, Choices, ModelResponse
from litellm.utils import token_counter
from openai._exceptions import OpenAIError
from tokenizers import Tokenizer
from tqdm import tqdm
import lotus
from lotus.cache import CacheFactory
from lotus.types import (
LMOutput,
LMStats,
LogprobsForCascade,
LogprobsForFilterCascade,
LotusUsageLimitException,
UsageLimit,
)
# Silence chatty third-party loggers: only CRITICAL records from them get through.
for _noisy_logger_name in ("LiteLLM", "httpx"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.CRITICAL)
del _noisy_logger_name
class LM:
    """Client around LiteLLM that batches chat completions and tracks usage.

    Responses are cached when ``lotus.settings.enable_cache`` is set. Token and
    cost usage is accumulated in two ledgers on ``self.stats``:

    * virtual usage -- what would have been spent if nothing were cached;
    * physical usage -- what was actually spent on uncached calls.

    After every response each ledger is compared against its own ``UsageLimit``,
    and ``LotusUsageLimitException`` is raised once any field exceeds its limit.
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        temperature: float = 0.0,
        max_ctx_len: int = 128000,
        max_tokens: int = 512,
        max_batch_size: int = 64,
        tokenizer: Tokenizer | None = None,
        cache=None,
        physical_usage_limit: UsageLimit | None = None,
        virtual_usage_limit: UsageLimit | None = None,
        **kwargs: Any,
    ):
        """Language Model class for interacting with various LLM providers.

        Args:
            model (str): Name of the model to use. Defaults to "gpt-4o-mini".
            temperature (float): Sampling temperature. Defaults to 0.0.
            max_ctx_len (int): Maximum context length in tokens. Defaults to 128000.
            max_tokens (int): Maximum number of tokens to generate. Defaults to 512.
            max_batch_size (int): Maximum number of concurrent requests; forwarded to
                ``batch_completion`` as ``max_workers``. Defaults to 64.
            tokenizer (Tokenizer | None): Custom HuggingFace tokenizer used by
                ``count_tokens``. Defaults to None (use the model's default tokenizer).
            cache: Cache instance to use. Defaults to None, meaning
                ``CacheFactory.create_default_cache()``.
            physical_usage_limit (UsageLimit | None): Limit on actual (uncached) usage.
                Defaults to None, meaning a fresh ``UsageLimit()`` for this instance.
            virtual_usage_limit (UsageLimit | None): Limit on usage counted as if no
                caching were used. Defaults to None, meaning a fresh ``UsageLimit()``.
            **kwargs: Additional keyword arguments passed to the underlying LLM API.
        """
        self.model = model
        self.max_ctx_len = max_ctx_len
        self.max_tokens = max_tokens
        self.max_batch_size = max_batch_size
        self.tokenizer = tokenizer
        self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
        self.stats: LMStats = LMStats()
        # Build limits per instance: a `UsageLimit()` default in the signature would be
        # one object shared by every LM, so mutating one LM's limit would affect all.
        self.physical_usage_limit = physical_usage_limit if physical_usage_limit is not None else UsageLimit()
        self.virtual_usage_limit = virtual_usage_limit if virtual_usage_limit is not None else UsageLimit()
        # `is not None` rather than `or`, so a caller-supplied cache is never swapped for
        # the default just because it evaluates falsy (e.g. if it defines __len__ and is empty).
        self.cache = cache if cache is not None else CacheFactory.create_default_cache()

    def __call__(
        self,
        messages: list[list[dict[str, str]]],
        show_progress_bar: bool = True,
        progress_bar_desc: str = "Processing uncached messages",
        **kwargs: Any,
    ) -> LMOutput:
        """Run a batch of chat conversations through the model.

        Args:
            messages: One conversation (a list of message dicts) per request.
            show_progress_bar: Whether to display a tqdm bar for the uncached calls.
            progress_bar_desc: Label for the progress bar.
            **kwargs: Per-call overrides merged on top of the constructor kwargs.

        Returns:
            LMOutput whose ``outputs`` holds the first choice's text for each request,
            in input order. Its ``logprobs`` holds per-token logprobs when ``logprobs``
            was requested, otherwise None.

        Raises:
            LotusUsageLimitException: If a virtual or physical usage limit is exceeded.
            Exception: The first error object returned by ``batch_completion``. It is
                re-raised only after the successful responses in the batch have been
                counted (and cached, when caching is enabled).
        """
        all_kwargs = {**self.kwargs, **kwargs}

        # Set top_logprobs if logprobs requested
        if all_kwargs.get("logprobs", False):
            all_kwargs.setdefault("top_logprobs", 10)

        # Read the setting once so the whole call behaves consistently.
        use_cache = lotus.settings.enable_cache

        cached_responses: list[ModelResponse | None] = []
        if use_cache:
            # Check cache and separate cached and uncached messages (None marks a miss).
            keys = [self._hash_messages(msg, all_kwargs) for msg in messages]
            cached_responses = [self.cache.get(key) for key in keys]
            uncached_data = [(msg, key) for msg, key, resp in zip(messages, keys, cached_responses) if resp is None]
        else:
            uncached_data = [(msg, "no-cache") for msg in messages]

        self.stats.cache_hits += len(messages) - len(uncached_data)

        # Process uncached messages in batches
        uncached_responses = self._process_uncached_messages(
            uncached_data, all_kwargs, show_progress_bar, progress_bar_desc
        )

        # Record (and cache) every successful response before surfacing a provider
        # error. This applies whether or not caching is on, so a failure never reaches
        # _get_top_choice, and a retry does not pay again for the calls that succeeded.
        first_error: Exception | None = None
        for resp, (_, key) in zip(uncached_responses, uncached_data):
            if isinstance(resp, Exception):
                if first_error is None:
                    first_error = resp
                continue
            self._update_stats(resp, is_cached=False)
            if use_cache:
                self._cache_response(resp, key)
        if first_error is not None:
            raise first_error

        # Update virtual stats for cached responses
        for resp in cached_responses:
            if resp is not None:
                self._update_stats(resp, is_cached=True)

        # Merge all responses in original order and extract outputs
        all_responses = (
            self._merge_responses(cached_responses, uncached_responses) if use_cache else uncached_responses
        )
        outputs = [self._get_top_choice(resp) for resp in all_responses]
        logprobs = (
            [self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
        )
        return LMOutput(outputs=outputs, logprobs=logprobs)

    def _process_uncached_messages(self, uncached_data, all_kwargs, show_progress_bar, progress_bar_desc):
        """Send the uncached conversations to the provider in one ``batch_completion`` call.

        Args:
            uncached_data: ``(messages, cache_key)`` pairs; only the messages are sent.
            all_kwargs: Completion kwargs forwarded to ``batch_completion``.
            show_progress_bar: Whether to display the tqdm bar.
            progress_bar_desc: Label for the tqdm bar.

        Returns:
            One entry per input, in input order. A failed request may show up as an
            exception object instead of a ``ModelResponse``, so callers must check.
        """
        if not uncached_data:
            # Everything was served from the cache: skip the provider call and an empty bar.
            return []

        batch = [msg for msg, _ in uncached_data]
        total_calls = len(batch)
        # batch_completion is a single blocking call, so the bar jumps straight to the
        # total once it returns. The context manager closes the bar even if it raises.
        with tqdm(
            total=total_calls,
            desc=progress_bar_desc,
            disable=not show_progress_bar,
            bar_format="{l_bar}{bar} {n}/{total} LM calls [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
        ) as pbar:
            uncached_responses = batch_completion(
                self.model, batch, drop_params=True, max_workers=self.max_batch_size, **all_kwargs
            )
            pbar.update(total_calls)
        return uncached_responses

    def _cache_response(self, response, hash):
        """Insert a successful response into the cache under key ``hash``.

        Raises:
            Exception: ``response`` itself, if it is an error object rather than a
                response. Errors are never cached.
        """
        if isinstance(response, Exception):
            raise response
        self.cache.insert(hash, response)

    def _hash_messages(self, messages: list[dict[str, str]], kwargs: dict[str, Any]) -> str:
        """Hash messages and kwargs to create a unique key for the cache."""
        # NOTE: str(kwargs) depends on dict insertion order. Normalizing it would change
        # every key and so invalidate existing cache entries, so it is left as-is.
        to_hash = str(self.model) + str(messages) + str(kwargs)
        return hashlib.sha256(to_hash.encode()).hexdigest()

    def _merge_responses(
        self, cached_responses: list[ModelResponse | None], uncached_responses: list[ModelResponse]
    ) -> list[ModelResponse]:
        """Merge cached and uncached responses, maintaining order.

        Each ``None`` slot in ``cached_responses`` (a cache miss) is filled, in order,
        with the next entry from ``uncached_responses``.
        """
        uncached_iter = iter(uncached_responses)
        return [resp if resp is not None else next(uncached_iter) for resp in cached_responses]

    def _check_usage_limit(self, usage: LMStats.TotalUsage, limit: UsageLimit, usage_type: str):
        """Raise LotusUsageLimitException if any usage field exceeds its limit."""
        if (
            usage.prompt_tokens > limit.prompt_tokens_limit
            or usage.completion_tokens > limit.completion_tokens_limit
            or usage.total_tokens > limit.total_tokens_limit
            or usage.total_cost > limit.total_cost_limit
        ):
            raise LotusUsageLimitException(f"Usage limit exceeded. Current {usage_type} usage: {usage}, Limit: {limit}")

    def _update_usage_stats(self, usage: LMStats.TotalUsage, response: ModelResponse, cost: float | None):
        """Add a response's token counts (and cost, when known) to a usage ledger."""
        if hasattr(response, "usage"):
            usage.prompt_tokens += response.usage.prompt_tokens
            usage.completion_tokens += response.usage.completion_tokens
            usage.total_tokens += response.usage.total_tokens
            if cost is not None:
                usage.total_cost += cost

    def _update_stats(self, response: ModelResponse, is_cached: bool = False):
        """Account for one response in the virtual (and, if uncached, physical) ledgers.

        The cost comes from ``completion_cost``. It is treated as unknown (None) when
        pricing is missing (NotFoundError, logged at debug level) or when cost
        calculation fails for any other reason (debug log plus a warning).
        Limits are checked right after each ledger is updated.
        """
        if not hasattr(response, "usage"):
            return

        # Calculate cost once
        try:
            cost = completion_cost(completion_response=response)
        except litellm.exceptions.NotFoundError as e:
            # Sometimes the model's pricing information is not available
            lotus.logger.debug(f"Error updating completion cost: {e}")
            cost = None
        except Exception as e:
            # Handle any other unexpected errors when calculating cost
            lotus.logger.debug(f"Unexpected error calculating completion cost: {e}")
            warnings.warn(
                "Error calculating completion cost - cost metrics will be inaccurate. Enable debug logging for details."
            )
            cost = None

        # Always update virtual usage
        self._update_usage_stats(self.stats.virtual_usage, response, cost)
        self._check_usage_limit(self.stats.virtual_usage, self.virtual_usage_limit, "virtual")

        # Only update physical usage for non-cached responses
        if not is_cached:
            self._update_usage_stats(self.stats.physical_usage, response, cost)
            self._check_usage_limit(self.stats.physical_usage, self.physical_usage_limit, "physical")

    def _get_top_choice(self, response: ModelResponse) -> str:
        """Return the text content of the first choice.

        Raises:
            ValueError: If that choice has no message content.
        """
        choice = response.choices[0]
        assert isinstance(choice, Choices)
        if choice.message.content is None:
            raise ValueError(f"No content in response: {response}")
        return choice.message.content

    def _get_top_choice_logprobs(self, response: ModelResponse) -> list[ChatCompletionTokenLogprob]:
        """Return the per-token logprobs of the first choice.

        Raises:
            ValueError: If the response carries no logprobs for that choice.
        """
        choice = response.choices[0]
        assert isinstance(choice, Choices)
        choice_logprobs = getattr(choice, "logprobs", None)
        if choice_logprobs is None:
            # Mirror _get_top_choice: fail with a readable message instead of
            # "'NoneType' object is not subscriptable".
            raise ValueError(f"No logprobs in response: {response}")
        logprobs = choice_logprobs["content"]
        return [ChatCompletionTokenLogprob(**logprob) for logprob in logprobs]

    def format_logprobs_for_cascade(self, logprobs: list[list[ChatCompletionTokenLogprob]]) -> LogprobsForCascade:
        """Split per-response logprobs into token strings and probabilities (exp(logprob))."""
        all_tokens = []
        all_confidences = []
        for resp_logprobs in logprobs:
            tokens = [logprob.token for logprob in resp_logprobs]
            confidences = [np.exp(logprob.logprob) for logprob in resp_logprobs]
            all_tokens.append(tokens)
            all_confidences.append(confidences)
        return LogprobsForCascade(tokens=all_tokens, confidences=all_confidences)

    def format_logprobs_for_filter_cascade(
        self, logprobs: list[list[ChatCompletionTokenLogprob]]
    ) -> LogprobsForFilterCascade:
        """Like ``format_logprobs_for_cascade``, plus a "True" probability per response.

        For each response, ``true_probs`` is P(True) / (P(True) + P(False)). It is taken
        from the top_logprobs of the first token position whose alternatives include
        both "True" and "False". If no position has both, the value falls back to 1
        when "True" is among the generated tokens, else 0.
        """
        # Get base cascade format first
        base_cascade = self.format_logprobs_for_cascade(logprobs)
        all_true_probs = []

        def get_normalized_true_prob(token_probs: dict[str, float]) -> float | None:
            if "True" in token_probs and "False" in token_probs:
                true_prob = token_probs["True"]
                false_prob = token_probs["False"]
                return true_prob / (true_prob + false_prob)
            return None

        # Get true probabilities for filter cascade
        for resp_idx, response_logprobs in enumerate(logprobs):
            true_prob = None
            for logprob in response_logprobs:
                token_probs = {top.token: np.exp(top.logprob) for top in logprob.top_logprobs}
                true_prob = get_normalized_true_prob(token_probs)
                if true_prob is not None:
                    break

            # Default to 1 if "True" in tokens, 0 if not
            if true_prob is None:
                true_prob = 1 if "True" in base_cascade.tokens[resp_idx] else 0

            all_true_probs.append(true_prob)

        return LogprobsForFilterCascade(
            tokens=base_cascade.tokens, confidences=base_cascade.confidences, true_probs=all_true_probs
        )

    def count_tokens(self, messages: list[dict[str, str]] | str) -> int:
        """Count tokens in messages using either custom tokenizer or model's default tokenizer.

        A plain string is wrapped as a single user message first.
        """
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        custom_tokenizer: dict[str, Any] | None = None
        if self.tokenizer:
            custom_tokenizer = dict(type="huggingface_tokenizer", tokenizer=self.tokenizer)

        return token_counter(
            custom_tokenizer=custom_tokenizer,
            model=self.model,
            messages=messages,
        )

    def print_total_usage(self):
        """Print virtual/physical cost, token totals and cache hits to stdout."""
        print("\n=== Usage Statistics ===")
        print("Virtual = Total usage if no caching was used")
        print("Physical = Actual usage with caching applied\n")
        print(f"Virtual Cost: ${self.stats.virtual_usage.total_cost:,.6f}")
        print(f"Physical Cost: ${self.stats.physical_usage.total_cost:,.6f}")
        print(f"Virtual Tokens: {self.stats.virtual_usage.total_tokens:,}")
        print(f"Physical Tokens: {self.stats.physical_usage.total_tokens:,}")
        print(f"Cache Hits: {self.stats.cache_hits:,}\n")

    def reset_stats(self):
        """Discard all accumulated usage statistics."""
        self.stats = LMStats()

    def reset_cache(self, max_size: int | None = None):
        """Reset the underlying cache, forwarding ``max_size`` to ``cache.reset``."""
        self.cache.reset(max_size)

    def get_model_name(self) -> str:
        """Return the lower-cased model name without provider prefix or ``:version`` suffix.

        For example, "ollama/DeepSeek-R1:7b" becomes "deepseek-r1".
        """
        raw_model = self.model
        if not raw_model:
            return ""

        # If a slash is present, assume the model name is after the last slash.
        candidate = raw_model.rsplit("/", 1)[-1]

        # If a colon is present, assume the model version is appended and remove it.
        candidate = candidate.split(":", 1)[0]

        return candidate.lower()

    def is_deepseek(self) -> bool:
        """Return True when the normalized model name starts with "deepseek-r1"."""
        model_name = self.get_model_name()
        return model_name.startswith("deepseek-r1")