-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_pipelines.py
More file actions
398 lines (317 loc) · 14 KB
/
test_pipelines.py
File metadata and controls
398 lines (317 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import pytest
import pytest_asyncio
import sys
import os
# Add src directory to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
import dspy
from src.dspy_setup import setup_dspy_basic, setup_dspy
from src.basic_examples import BasicPipeline
from src.pydantic_integration import AnalysisModule, QueryInput, ValidatedRAGPipeline
from src.advanced_patterns import ResilientQAPipeline, CachedRAGPipeline, AsyncRAGPipeline
from dspy.evaluate import Evaluate
# Logfire is optional for this test suite. When src.logfire_setup cannot be
# imported, fall back to "no manager" plus a pass-through decorator so the
# tests below run unchanged.
try:
    from src.logfire_setup import get_logfire_manager, logfire_span

    logfire_manager = get_logfire_manager()
    LOGFIRE_AVAILABLE = True
except ImportError:
    LOGFIRE_AVAILABLE = False
    logfire_manager = None

    def logfire_span(name, **kwargs):
        """No-op stand-in for the real ``logfire_span`` decorator factory.

        The span name and any keyword attributes are accepted and ignored;
        the returned decorator hands back the decorated function unchanged.
        """
        return lambda func: func
class TestBasicPipelines:
    """Test basic DSPy pipeline functionality"""

    # Question/topic cases used when the sample data cannot be loaded or
    # yields no cases (an empty list would let test_answer_quality pass
    # vacuously).
    FALLBACK_TEST_CASES = [
        {
            "question": "What is machine learning?",
            "expected_topics": ["algorithms", "data", "learning", "artificial", "intelligence"],
        },
        {
            "question": "How does photosynthesis work?",
            "expected_topics": ["plants", "sunlight", "energy", "carbon", "oxygen"],
        },
        {
            "question": "What is Python programming?",
            "expected_topics": ["programming", "language", "code", "software", "development"],
        },
    ]

    # Context for test_context_handling when the sample data provides none.
    # Both fallback paths use this one string so they cannot drift apart.
    DEFAULT_DSPY_CONTEXT = (
        "DSPy is a framework for programming language models. "
        "It provides automatic optimization of prompts. "
        "DSPy uses signatures to define task interfaces."
    )

    def setup_method(self):
        """Configure DSPy, build the pipeline and load the test cases.

        Skips the test when any of this fails (most likely a missing API key).
        """
        if LOGFIRE_AVAILABLE:
            logfire_manager.log_event("Setting up test environment", "info", test_class="TestBasicPipelines")
        try:
            self.lm = setup_dspy_basic()
            self.pipeline = BasicPipeline()
            self.test_cases = self._load_test_cases()
        except Exception as e:
            pytest.skip(f"Cannot setup DSPy (likely missing API key): {e}")

    def _load_test_cases(self):
        """Return test cases from the sample data, or the built-in fallback cases.

        An ImportError from ``src.util`` propagates so setup_method skips the
        test, as for the other setup failures. Any error while loading or
        converting the sample data, or an empty result, falls back to
        FALLBACK_TEST_CASES.
        """
        from src.util import load_sample_data, get_test_cases
        try:
            cases = list(get_test_cases(load_sample_data()))
        except Exception:
            cases = []
        return cases or list(self.FALLBACK_TEST_CASES)

    def test_pipeline_initialization(self):
        """Test that pipeline initializes correctly"""
        assert self.pipeline is not None
        for attr in ('forward', 'qa', 'rag'):
            assert hasattr(self.pipeline, attr), f"pipeline is missing '{attr}'"

    @logfire_span("test_answer_generation", component="tests")
    def test_answer_generation(self):
        """Test basic answer generation"""
        question = "What is artificial intelligence?"
        if LOGFIRE_AVAILABLE:
            logfire_manager.log_event("Testing answer generation", "info", question=question)
        result = self.pipeline(question)
        assert hasattr(result, 'answer')
        assert isinstance(result.answer, str)
        assert len(result.answer.strip()) > 0
        assert len(result.answer) > 10  # Meaningful answer length
        if LOGFIRE_AVAILABLE:
            logfire_manager.log_event("Answer generation test passed", "info", answer_length=len(result.answer))

    def test_answer_quality(self):
        """Test answer quality and relevance"""
        for case in self.test_cases[:2]:  # Test first 2 to avoid API limits
            question = case["question"]
            result = self.pipeline(question)
            assert len(result.answer) > 20, f"Answer too short for: {question}"
            # At least one expected topic should appear. Compare
            # case-insensitively on both sides, since topics loaded from the
            # sample data are not guaranteed to be lowercase.
            answer_lower = result.answer.lower()
            topic_found = any(topic.lower() in answer_lower for topic in case["expected_topics"])
            assert topic_found, f"No relevant topics found in answer for: {question}"

    def test_context_handling(self):
        """Test pipeline handling with custom context"""
        from src.util import load_sample_data
        try:
            custom_context = load_sample_data().get('contexts', {}).get(
                'dspy_context', self.DEFAULT_DSPY_CONTEXT
            )
        except Exception:
            custom_context = self.DEFAULT_DSPY_CONTEXT
        question = "What is DSPy?"
        result = self.pipeline(question, context=custom_context)
        # Check the attribute exists before dereferencing it.
        assert hasattr(result, 'answer')
        assert len(result.answer) > 0
        assert "dspy" in result.answer.lower()
class TestPydanticIntegration:
    """Test Pydantic integration with DSPy"""

    def setup_method(self):
        """Configure DSPy and build the analyzer and validated RAG pipeline.

        Skips the test when setup fails (most likely a missing API key).
        """
        try:
            setup_dspy()
            self.analyzer = AnalysisModule()
            self.rag_pipeline = ValidatedRAGPipeline()
        except Exception as e:
            pytest.skip(f"Cannot setup DSPy (likely missing API key): {e}")

    def test_analysis_result_validation(self):
        """Test that analysis results are properly validated"""
        sample_text = "This is a positive message about technology and innovation."
        result = self.analyzer(sample_text)
        # Check Pydantic model fields
        for field in ('sentiment', 'confidence', 'key_themes', 'summary', 'word_count'):
            assert hasattr(result, field), f"analysis result is missing '{field}'"
        # Check field types and constraints
        assert result.sentiment in ['positive', 'negative', 'neutral']
        assert 0.0 <= result.confidence <= 1.0
        assert isinstance(result.key_themes, list)
        assert 1 <= len(result.key_themes) <= 5
        assert len(result.summary) >= 10
        assert result.word_count > 0

    def test_query_input_validation(self):
        """Test query input validation"""
        # Valid query
        valid_query = QueryInput(
            question="What is machine learning?",
            max_results=3,
        )
        assert valid_query.question == "What is machine learning?"
        assert valid_query.max_results == 3
        # pydantic's ValidationError subclasses ValueError (pydantic v1 and v2).
        # Expecting ValueError keeps these checks specific: an unrelated
        # failure such as a NameError or TypeError no longer passes by
        # accident, and pydantic need not be imported here.
        # Invalid query - too short
        with pytest.raises(ValueError):
            QueryInput(question="Hi")
        # Invalid query - too many results
        with pytest.raises(ValueError):
            QueryInput(question="Valid question?", max_results=25)

    def test_validated_rag_pipeline(self):
        """Test RAG pipeline with validated inputs"""
        query = QueryInput(
            question="What is artificial intelligence?",
            include_reasoning=True,
        )
        result = self.rag_pipeline(query)
        for key in ('question', 'answer', 'context'):
            assert key in result, f"RAG result is missing '{key}'"
        assert 'reasoning' in result  # Requested via include_reasoning=True
        assert result['question'] == query.question
class TestAdvancedPatterns:
    """Test advanced DSPy patterns"""

    def setup_method(self):
        """Configure DSPy and build fresh resilient and cached pipelines.

        A new pipeline is built per test, so cache and metric state does not
        leak between tests. Skips when setup fails (most likely a missing
        API key).
        """
        try:
            setup_dspy()
            self.resilient_pipeline = ResilientQAPipeline(max_retries=2)
            self.cached_pipeline = CachedRAGPipeline(cache_size=10)
        except Exception as e:
            pytest.skip(f"Cannot setup DSPy (likely missing API key): {e}")

    def test_resilient_pipeline_success(self):
        """Test resilient pipeline success case"""
        context = "Python is a programming language created by Guido van Rossum."
        question = "Who created Python?"
        result = self.resilient_pipeline(context=context, question=question)
        for attr in ('answer', 'confidence', 'method_used', 'attempts'):
            assert hasattr(result, attr), f"result is missing '{attr}'"
        assert result.confidence > 0

    def test_resilient_pipeline_metrics(self):
        """Test resilient pipeline metrics collection"""
        context = "Test context for metrics."
        questions = ["Question 1?", "Question 2?"]
        for question in questions:
            self.resilient_pipeline(context=context, question=question)
        metrics = self.resilient_pipeline.get_metrics()
        for key in ('total_calls', 'successful_calls', 'success_rate'):
            assert key in metrics, f"metrics are missing '{key}'"
        assert metrics['total_calls'] >= len(questions)

    def test_cached_pipeline_caching(self):
        """Test caching functionality"""
        question = "What is caching?"
        # First call - should be a cache miss; only its side effect matters.
        self.cached_pipeline(question=question)
        stats_after_first = self.cached_pipeline.get_cache_stats()
        # Second call - should be a cache hit
        result2 = self.cached_pipeline(question=question)
        stats_after_second = self.cached_pipeline.get_cache_stats()
        # Verify caching worked
        assert stats_after_second['cache_hits'] > stats_after_first['cache_hits']
        assert hasattr(result2, 'from_cache')
        assert result2.from_cache is True

    def test_cache_size_management(self):
        """Test cache size management"""
        small_cache = CachedRAGPipeline(cache_size=2)
        # Fill cache beyond capacity: 3 distinct questions, cache size 2
        for question in ("Q1?", "Q2?", "Q3?"):
            small_cache(question=question)
        stats = small_cache.get_cache_stats()
        assert int(stats['cache_size']) <= 2  # Should not exceed max size
@pytest.mark.asyncio
class TestAsyncPipeline:
    """Test asynchronous pipeline functionality"""

    def setup_method(self):
        """Configure DSPy and build the async pipeline.

        Skips the test when setup fails (most likely a missing API key).
        """
        try:
            setup_dspy()
            self.async_pipeline = AsyncRAGPipeline(max_workers=2)
        except Exception as e:
            pytest.skip(f"Cannot setup DSPy (likely missing API key): {e}")

    @staticmethod
    def _assert_result_contract(result, question):
        """Check one ``forward_async`` result dict against the expected contract.

        Every result must be a dict that echoes its question and carries a
        ``success`` flag. Successful results hold a non-None ``answer``;
        failed results hold an ``error``.
        """
        assert isinstance(result, dict)
        assert 'question' in result
        assert 'success' in result
        assert result['question'] == question
        if result['success']:
            assert 'answer' in result
            assert result['answer'] is not None
        else:
            assert 'error' in result

    async def test_async_batch_processing(self):
        """Test asynchronous batch processing"""
        questions = [
            "What is AI?",
            "What is ML?",
            "What is DL?",
        ]
        results = await self.async_pipeline.forward_async(questions)
        assert len(results) == len(questions)
        # Results are expected in the same order as the input questions.
        for question, result in zip(questions, results):
            self._assert_result_contract(result, question)

    async def test_async_error_handling(self):
        """Test async pipeline error handling"""
        # Mix of valid and potentially problematic questions
        questions = [
            "What is artificial intelligence?",
            "",  # Empty question might cause issues
            "What is machine learning?",
        ]
        results = await self.async_pipeline.forward_async(questions)
        assert len(results) == len(questions)
        # Errors must be reported per item in the same shape as successes,
        # rather than aborting the batch or returning malformed entries.
        for question, result in zip(questions, results):
            self._assert_result_contract(result, question)
class TestSystemEvaluation:
"""System-level evaluation tests"""
def setup_method(self):
"""Setup evaluation environment"""
try:
setup_dspy()
self.pipeline = BasicPipeline()
except Exception as e:
pytest.skip(f"Cannot setup DSPy (likely missing API key): {e}")
def custom_accuracy_metric(self, example, pred, trace=None):
"""Custom accuracy metric for evaluation"""
if not hasattr(pred, 'answer') or not pred.answer:
return 0.0
# Simple keyword-based accuracy
pred_lower = pred.answer.lower()
# Check for meaningful content
if len(pred_lower.strip()) < 10:
return 0.0
# Check for failure indicators
failure_terms = ['i don\'t know', 'cannot answer', 'not sure']
for term in failure_terms:
if term in pred_lower:
return 0.0
return 1.0 # Consider it correct if it's a substantial, confident answer
def test_systematic_evaluation(self):
"""Test systematic evaluation using DSPy's framework"""
# Create simple test set
test_examples = [
{"question": "What is programming?", "context": "Programming is writing code."},
{"question": "What is Python?", "context": "Python is a programming language."}
]
# Convert to DSPy examples
dspy_examples = []
for ex in test_examples:
dspy_ex = dspy.Example(
question=ex["question"],
context=ex["context"]
).with_inputs('question')
dspy_examples.append(dspy_ex)
# Run evaluation
try:
evaluator = Evaluate(
devset=dspy_examples,
metric=self.custom_accuracy_metric,
num_threads=1 # Use single thread to avoid API rate limits
)
score = evaluator(self.pipeline)
# Score should be between 0 and 1
assert 0.0 <= score <= 1.0
print(f"Evaluation score: {score}")
except Exception as e:
# Evaluation might fail due to API limits or other issues
print(f"Evaluation failed: {e}")
pytest.skip("Evaluation failed - possibly due to API limits")
# Test runner function
def run_tests():
    """Run this module's tests through pytest.

    Returns:
        The exit code from ``pytest.main`` (0 when all tests pass), or 1 when
        invoking pytest itself raised an exception.
    """
    test_args = [
        __file__,
        "-v",  # Verbose output
        "-s",  # Don't capture print statements
        "--tb=short",  # Short traceback format
    ]
    try:
        return pytest.main(test_args)
    except Exception as e:
        print(f"Test execution failed: {e}")
        return 1


if __name__ == "__main__":
    # Propagate pytest's status so shells/CI see a non-zero exit on failure.
    sys.exit(run_tests())