|
| 1 | +from datafast.llms import MistralProvider |
| 2 | +from dotenv import load_dotenv |
| 3 | +import pytest |
| 4 | +from tests.test_schemas import ( |
| 5 | + SimpleResponse, |
| 6 | + LandmarkInfo, |
| 7 | + PersonaContent, |
| 8 | + QASet, |
| 9 | + MCQSet, |
| 10 | +) |
| 11 | + |
# Load variables from a local .env file at import time so the integration
# tests below can authenticate (presumably the Mistral API key — confirm
# the exact variable name against MistralProvider's implementation).
load_dotenv()
| 13 | + |
| 14 | + |
@pytest.mark.integration
class TestMistralProvider:
    """Integration tests for the Mistral provider.

    Covers plain-prompt generation, chat-style message input, structured
    (Pydantic) outputs, and explicit sampling configuration. These tests hit
    the live Mistral API and therefore require valid credentials in the
    environment.
    """

    def test_basic_text_response(self):
        """A one-word factual prompt should come back as free-form text."""
        llm = MistralProvider()
        answer = llm.generate(prompt="What is the capital of France? Answer in one word.")
        assert "Paris" in answer

    def test_structured_output(self):
        """Passing response_format should yield a populated SimpleResponse."""
        llm = MistralProvider()
        question = """What is the capital of France?
        Provide a short answer and a brief explanation of why Paris is the capital."""

        result = llm.generate(
            prompt=question,
            response_format=SimpleResponse,
        )

        assert isinstance(result, SimpleResponse)
        assert "Paris" in result.answer
        assert len(result.reasoning) > 10

    def test_with_messages(self):
        """Chat-style messages should be accepted in place of a bare prompt."""
        llm = MistralProvider()
        conversation = [
            {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."},
            {"role": "user", "content": "What is the capital of France? Answer in one word."},
        ]

        answer = llm.generate(messages=conversation)
        assert "Paris" in answer

    def test_messages_with_structured_output(self):
        """Messages input and structured output should compose correctly."""
        llm = MistralProvider()
        conversation = [
            {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."},
            {"role": "user", "content": """What is the capital of France?
            Provide a short answer and a brief explanation of why Paris is the capital."""},
        ]

        result = llm.generate(
            messages=conversation,
            response_format=SimpleResponse,
        )

        assert isinstance(result, SimpleResponse)
        assert "Paris" in result.answer
        assert len(result.reasoning) > 10

    def test_with_all_parameters(self):
        """Explicit model, temperature, token cap, and top_p should all be honored."""
        llm = MistralProvider(
            model_id="mistral-small-latest",
            temperature=0.3,
            max_completion_tokens=300,
            top_p=0.85,
        )

        answer = llm.generate(prompt="What is the capital of France? Answer in one word.")
        assert "Paris" in answer

    def test_structured_landmark_info(self):
        """An extraction prompt should populate every LandmarkInfo field sensibly."""
        llm = MistralProvider(temperature=0.6, max_completion_tokens=2000)

        prompt = """
        Extract structured landmark details about the Great Wall of China from the passage below.

        Passage:
        "The Great Wall of China stands across northern China, originally begun in 220 BCE to guard imperial borders.
        Spanning roughly 13,171 miles, it threads over mountains and deserts, symbolising centuries of engineering prowess and cultural unity.
        Construction and major reinforcement during the Ming dynasty in the 14th century gave the wall its iconic form, using stone and brick to fortify older earthen ramparts.
        Key attributes include: overall length of about 13,171 miles (importance 0.9), primary materials of stone and brick with tamped earth cores (importance 0.7), and critical Ming dynasty stewardship that restored and expanded the fortifications (importance 0.8).
        Today's visitors typically rate the experience around 4.6 out of 5, citing sweeping views and the wall's historical resonance."
        """

        result = llm.generate(prompt=prompt, response_format=LandmarkInfo)

        assert isinstance(result, LandmarkInfo)
        assert "Great Wall" in result.name
        assert "China" in result.location
        assert len(result.description) > 20
        assert result.year_built is not None
        assert len(result.attributes) >= 3

        # Each extracted attribute must carry a name, a value, and a
        # normalized importance score.
        for attribute in result.attributes:
            assert 0 <= attribute.importance <= 1
            assert len(attribute.name) > 0
            assert len(attribute.value) > 0

        assert 0 <= result.visitor_rating <= 5
| 111 | + |
| 112 | + |
@pytest.mark.integration
class TestMistralMedium:
    """Integration tests pinned to the mistral-medium-latest model.

    Exercises structured generation of persona social-media content,
    question/answer sets, and multiple-choice questions against the live
    Mistral API.
    """

    def test_persona_content_generation(self):
        """The model should emit exactly 5 tweets plus a non-trivial bio."""
        llm = MistralProvider(
            model_id="mistral-medium-latest",
            temperature=0.7,
            max_completion_tokens=1000,
        )

        prompt = """
        Generate social media content for the following persona:

        Persona: A passionate environmental scientist who loves hiking and photography,
        advocates for climate action, and enjoys sharing nature facts with humor.

        Create exactly 5 tweets and 1 bio for this persona.
        """

        result = llm.generate(prompt=prompt, response_format=PersonaContent)

        assert isinstance(result, PersonaContent)
        assert len(result.tweets) == 5
        assert all(len(tweet) > 0 for tweet in result.tweets)
        assert len(result.bio) > 20

    def test_qa_generation(self):
        """The model should produce 5 substantive question/answer pairs."""
        llm = MistralProvider(
            model_id="mistral-medium-latest",
            temperature=0.5,
            max_completion_tokens=1500,
        )

        prompt = """
        Generate exactly 5 questions and their correct answers about machine learning topics.

        Topics to cover: supervised learning, neural networks, overfitting, gradient descent, and cross-validation.

        Each question should be clear and the answer should be concise but complete.
        """

        result = llm.generate(prompt=prompt, response_format=QASet)

        assert isinstance(result, QASet)
        assert len(result.questions) == 5
        for pair in result.questions:
            assert len(pair.question) > 10
            assert len(pair.answer) > 10

    def test_mcq_generation(self):
        """The model should produce 3 MCQs, each with 1 correct and 3 distractor answers."""
        llm = MistralProvider(
            model_id="mistral-medium-latest",
            temperature=0.5,
            max_completion_tokens=1500,
        )

        prompt = """
        Generate exactly 3 multiple choice questions about machine learning.

        For each question, provide:
        - The question itself
        - One correct answer
        - Three plausible but incorrect answers

        Topics: neural networks, decision trees, and ensemble methods.
        """

        result = llm.generate(prompt=prompt, response_format=MCQSet)

        assert isinstance(result, MCQSet)
        assert len(result.questions) == 3
        for item in result.questions:
            assert len(item.question) > 10
            assert len(item.correct_answer) > 0
            assert len(item.incorrect_answers) == 3
            assert all(len(choice) > 0 for choice in item.incorrect_answers)
0 commit comments