Skip to content

Commit 83e89ca

Browse files
Merge pull request #127 from patrickfleith/tests/refactoring-llms-tests
Tests/refactoring llms tests
2 parents 212abf9 + 725b7fe commit 83e89ca

File tree

6 files changed

+1438
-558
lines changed

6 files changed

+1438
-558
lines changed

datafast/llms.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
# LiteLLM
1717
import litellm
1818
from litellm.utils import ModelResponse
19-
from litellm import batch_completion
2019

2120
# Internal imports
2221
from .llm_utils import get_messages
@@ -115,6 +114,37 @@ def _respect_rate_limit(self) -> None:
115114
print("Waiting for rate limit...")
116115
time.sleep(sleep_time)
117116

117+
@staticmethod
118+
def _strip_code_fences(content: str) -> str:
119+
"""Strip markdown code fences from content if present.
120+
121+
Args:
122+
content: The content string that may contain code fences
123+
124+
Returns:
125+
Content with code fences removed
126+
"""
127+
if not content:
128+
return content
129+
130+
content = content.strip()
131+
132+
# Check for code fences with optional language identifier
133+
if content.startswith('```'):
134+
# Find the end of the first line (language identifier)
135+
first_newline = content.find('\n')
136+
if first_newline != -1:
137+
content = content[first_newline + 1:]
138+
else:
139+
# No newline after opening fence, remove just the fence
140+
content = content[3:]
141+
142+
# Remove closing fence
143+
if content.endswith('```'):
144+
content = content[:-3]
145+
146+
return content.strip()
147+
118148
def generate(
119149
self,
120150
prompt: str | list[str] | None = None,
@@ -176,13 +206,29 @@ def generate(
176206
raise ValueError("messages cannot be empty")
177207

178208
try:
209+
# Append JSON formatting instructions if response_format is provided
210+
json_instructions = (
211+
"\nReturn only valid JSON. To do so, don't include ```json ``` markdown "
212+
"or code fences around the JSON. Use double quotes for all keys and values. "
213+
"Escape internal quotes and newlines (use \\n). Do not include trailing commas."
214+
)
215+
179216
# Convert batch prompts to messages if needed
180217
batch_to_send = []
181218
if batch_prompts is not None:
182219
for one_prompt in batch_prompts:
183-
batch_to_send.append(get_messages(one_prompt))
220+
# Append JSON instructions to prompt if response_format is provided
221+
modified_prompt = one_prompt + json_instructions if response_format is not None else one_prompt
222+
batch_to_send.append(get_messages(modified_prompt))
184223
else:
185224
batch_to_send = batch_messages
225+
# Append JSON instructions to the last user message if response_format is provided
226+
if response_format is not None:
227+
for message_list in batch_to_send:
228+
for msg in reversed(message_list):
229+
if msg.get("role") == "user":
230+
msg["content"] += json_instructions
231+
break
186232

187233
# Enforce rate limit per batch
188234
self._respect_rate_limit()
@@ -211,11 +257,15 @@ def generate(
211257
results = []
212258
for one_response in response:
213259
content = one_response.choices[0].message.content
260+
214261
if response_format is not None:
262+
# Strip code fences before validation
263+
content = self._strip_code_fences(content)
215264
results.append(
216265
response_format.model_validate_json(content))
217266
else:
218-
results.append(content)
267+
# Strip leading/trailing whitespace for text responses
268+
results.append(content.strip() if content else content)
219269

220270
# Return single result for backward compatibility
221271
if single_input and len(results) == 1:
@@ -286,8 +336,8 @@ def __init__(
286336
api_key: str | None = None,
287337
temperature: float | None = None,
288338
max_completion_tokens: int | None = None,
289-
top_p: float | None = None,
290-
# frequency_penalty: float | None = None, # Not supported by anthropic
339+
# top_p: float | None = None, # Not properly supported by anthropic models 4.5
340+
# frequency_penalty: float | None = None, # Not supported by anthropic models 4.5
291341
):
292342
"""Initialize the Anthropic provider.
293343
@@ -303,7 +353,6 @@ def __init__(
303353
api_key=api_key,
304354
temperature=temperature,
305355
max_completion_tokens=max_completion_tokens,
306-
top_p=top_p,
307356
)
308357

309358

pytest.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[pytest]
22
markers =
33
integration: marks tests that require API connectivity (deselect with '-m "not integration"')
4+
slow: marks tests that are slow to run
45

56
# Other pytest configurations
67
testpaths = tests
@@ -16,6 +17,8 @@ log_cli_level = INFO
1617
filterwarnings =
1718
# Ignore Pydantic deprecation warnings
1819
ignore::DeprecationWarning:pydantic.*:
20+
# Ignore Pydantic serializer warnings during tests
21+
ignore::UserWarning:pydantic.main
1922
# Ignore LiteLLM deprecation warnings
2023
ignore::DeprecationWarning:litellm.*:
2124
# Ignore HTTPX deprecation warnings

0 commit comments

Comments
 (0)