Skip to content

Commit 488ca06

Browse files
Merge pull request #141 from aidenerdogan/feature/adding-mistral-provider4litellm
feat: add MistralProvider via LiteLLM with integration tests and READ…
2 parents 0b265b0 + 1262dee commit 488ca06

File tree

3 files changed

+242
-1
lines changed

3 files changed

+242
-1
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Currently we support the following LLM providers:
2929
- ✔︎ Anthropic
3030
- ✔︎ Google Gemini
3131
- ✔︎ Ollama (local LLM server)
32+
- ✔︎ Mistral AI
3233
- ⏳ more to come...
3334

3435

@@ -53,6 +54,7 @@ Other keys depend on which LLM providers you use.
5354
GEMINI_API_KEY=XXXX
5455
OPENAI_API_KEY=sk-XXXX
5556
ANTHROPIC_API_KEY=sk-ant-XXXXX
57+
MISTRAL_API_KEY=XXXX
5658
HF_TOKEN=hf_XXXXX
5759
```
5860

@@ -161,6 +163,7 @@ How to proceed?
161163
- openai API key
162164
- anthropic API key
163165
- gemini API key
166+
- mistral API key
164167
- an ollama server running (use `ollama serve` from command line)
165168
9. Commit your change, push to your fork and create a pull request from your fork branch to datafast main branch.
166169
10. Explain your pull request in a clear and concise way, I'll review it as soon as possible.

datafast/llms.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""LLM providers for datafast using LiteLLM.
22
3-
This module provides classes for different LLM providers (OpenAI, Anthropic, Gemini)
3+
This module provides classes for different LLM providers (OpenAI, Anthropic, Gemini, Mistral)
44
with a unified interface using LiteLLM under the hood.
55
"""
66

@@ -726,4 +726,50 @@ def __init__(
726726
top_p = top_p,
727727
frequency_penalty = frequency_penalty,
728728
timeout = timeout,
729+
)
730+
731+
732+
class MistralProvider(LLMProvider):
    """Mistral AI provider using litellm."""

    def __init__(
        self,
        model_id: str = "mistral-small-latest",
        api_key: str | None = None,
        temperature: float | None = None,
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
        rpm_limit: int | None = None,
        timeout: int | None = None,
    ):
        """Set up a Mistral-backed LLM provider.

        Args:
            model_id: The model ID (defaults to mistral-small-latest)
            api_key: API key (if None, will get from MISTRAL_API_KEY env var)
            temperature: Temperature for generation (0.0 to 1.0)
            max_completion_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter (0.0 to 1.0)
            frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
            rpm_limit: Requests per minute limit for rate limiting
            timeout: Request timeout in seconds
        """
        # Everything is delegated to the shared base-class initializer;
        # this subclass only supplies the provider identity below.
        super().__init__(
            model_id=model_id,
            api_key=api_key,
            temperature=temperature,
            max_completion_tokens=max_completion_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            rpm_limit=rpm_limit,
            timeout=timeout,
        )

    @property
    def provider_name(self) -> str:
        # Routing prefix used by litellm to select the Mistral backend.
        return "mistral"

    @property
    def env_key_name(self) -> str:
        # Environment variable consulted when no explicit api_key is given.
        return "MISTRAL_API_KEY"

tests/test_mistral.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
from datafast.llms import MistralProvider
2+
from dotenv import load_dotenv
3+
import pytest
4+
from tests.test_schemas import (
5+
SimpleResponse,
6+
LandmarkInfo,
7+
PersonaContent,
8+
QASet,
9+
MCQSet,
10+
)
11+
12+
load_dotenv()
13+
14+
15+
@pytest.mark.integration
class TestMistralProvider:
    """Integration tests for the Mistral provider: plain prompts, message lists,
    structured outputs, and explicit generation parameters."""

    def test_basic_text_response(self):
        """A plain prompt should yield a free-text answer."""
        llm = MistralProvider()
        answer = llm.generate(prompt="What is the capital of France? Answer in one word.")
        assert "Paris" in answer

    def test_structured_output(self):
        """A response_format model should be parsed into a SimpleResponse."""
        llm = MistralProvider()
        query = """What is the capital of France?
        Provide a short answer and a brief explanation of why Paris is the capital."""

        result = llm.generate(
            prompt=query,
            response_format=SimpleResponse,
        )

        assert isinstance(result, SimpleResponse)
        assert "Paris" in result.answer
        assert len(result.reasoning) > 10

    def test_with_messages(self):
        """The messages-based calling convention should work like a prompt."""
        llm = MistralProvider()
        chat = [
            {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."},
            {"role": "user", "content": "What is the capital of France? Answer in one word."},
        ]

        answer = llm.generate(messages=chat)
        assert "Paris" in answer

    def test_messages_with_structured_output(self):
        """Messages input combined with structured output parsing."""
        llm = MistralProvider()
        chat = [
            {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."},
            {"role": "user", "content": """What is the capital of France?
            Provide a short answer and a brief explanation of why Paris is the capital."""},
        ]

        result = llm.generate(
            messages=chat,
            response_format=SimpleResponse,
        )

        assert isinstance(result, SimpleResponse)
        assert "Paris" in result.answer
        assert len(result.reasoning) > 10

    def test_with_all_parameters(self):
        """Every tunable generation parameter passed explicitly."""
        llm = MistralProvider(
            model_id="mistral-small-latest",
            temperature=0.3,
            max_completion_tokens=300,
            top_p=0.85,
        )

        answer = llm.generate(prompt="What is the capital of France? Answer in one word.")
        assert "Paris" in answer

    def test_structured_landmark_info(self):
        """Extract a nested LandmarkInfo structure from a factual passage."""
        llm = MistralProvider(temperature=0.6, max_completion_tokens=2000)

        query = """
        Extract structured landmark details about the Great Wall of China from the passage below.

        Passage:
        "The Great Wall of China stands across northern China, originally begun in 220 BCE to guard imperial borders.
        Spanning roughly 13,171 miles, it threads over mountains and deserts, symbolising centuries of engineering prowess and cultural unity.
        Construction and major reinforcement during the Ming dynasty in the 14th century gave the wall its iconic form, using stone and brick to fortify older earthen ramparts.
        Key attributes include: overall length of about 13,171 miles (importance 0.9), primary materials of stone and brick with tamped earth cores (importance 0.7), and critical Ming dynasty stewardship that restored and expanded the fortifications (importance 0.8).
        Today's visitors typically rate the experience around 4.6 out of 5, citing sweeping views and the wall's historical resonance."
        """

        result = llm.generate(prompt=query, response_format=LandmarkInfo)

        assert isinstance(result, LandmarkInfo)
        assert "Great Wall" in result.name
        assert "China" in result.location
        assert len(result.description) > 20
        assert result.year_built is not None
        assert len(result.attributes) >= 3

        # Each extracted attribute must carry a name, a value, and a
        # normalized importance score.
        for detail in result.attributes:
            assert 0 <= detail.importance <= 1
            assert len(detail.name) > 0
            assert len(detail.value) > 0

        assert 0 <= result.visitor_rating <= 5
112+
113+
@pytest.mark.integration
class TestMistralMedium:
    """Integration tests targeting the mistral-medium-latest model with
    larger structured-generation tasks."""

    def test_persona_content_generation(self):
        """Generate a fixed number of tweets plus a bio for a persona."""
        llm = MistralProvider(
            model_id="mistral-medium-latest",
            temperature=0.7,
            max_completion_tokens=1000,
        )

        query = """
        Generate social media content for the following persona:

        Persona: A passionate environmental scientist who loves hiking and photography,
        advocates for climate action, and enjoys sharing nature facts with humor.

        Create exactly 5 tweets and 1 bio for this persona.
        """

        result = llm.generate(prompt=query, response_format=PersonaContent)

        assert isinstance(result, PersonaContent)
        assert len(result.tweets) == 5
        assert all(len(post) > 0 for post in result.tweets)
        assert len(result.bio) > 20

    def test_qa_generation(self):
        """Generate exactly five machine-learning Q&A pairs."""
        llm = MistralProvider(
            model_id="mistral-medium-latest",
            temperature=0.5,
            max_completion_tokens=1500,
        )

        query = """
        Generate exactly 5 questions and their correct answers about machine learning topics.

        Topics to cover: supervised learning, neural networks, overfitting, gradient descent, and cross-validation.

        Each question should be clear and the answer should be concise but complete.
        """

        result = llm.generate(prompt=query, response_format=QASet)

        assert isinstance(result, QASet)
        assert len(result.questions) == 5
        for pair in result.questions:
            assert len(pair.question) > 10
            assert len(pair.answer) > 10

    def test_mcq_generation(self):
        """Generate three multiple-choice questions with distractors."""
        llm = MistralProvider(
            model_id="mistral-medium-latest",
            temperature=0.5,
            max_completion_tokens=1500,
        )

        query = """
        Generate exactly 3 multiple choice questions about machine learning.

        For each question, provide:
        - The question itself
        - One correct answer
        - Three plausible but incorrect answers

        Topics: neural networks, decision trees, and ensemble methods.
        """

        result = llm.generate(prompt=query, response_format=MCQSet)

        assert isinstance(result, MCQSet)
        assert len(result.questions) == 3
        for item in result.questions:
            assert len(item.question) > 10
            assert len(item.correct_answer) > 0
            assert len(item.incorrect_answers) == 3
            assert all(len(choice) > 0 for choice in item.incorrect_answers)

0 commit comments

Comments
 (0)