-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathtest_gating.py
More file actions
288 lines (231 loc) · 10.4 KB
/
test_gating.py
File metadata and controls
288 lines (231 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""
PraisonAI Test Gating Plugin
This module provides automatic marker assignment and skip/gating enforcement
for the PraisonAI test suite. It ensures tests are properly classified and
gated based on provider requirements and network access.
Environment Variables:
- PRAISONAI_TEST_TIER: smoke|main|extended|nightly (default: main)
- PRAISONAI_ALLOW_NETWORK: 0|1 (default: 0)
- PRAISONAI_LIVE_TESTS: 0|1 (default: 0)
- PRAISONAI_TEST_PROVIDERS: comma-separated list or 'all' (default: openai)
- PRAISONAI_LOCAL_SERVICES: 0|1 (default: 0)
"""
import os
import re
import socket
from pathlib import Path
from typing import Set, Dict, Optional
import pytest
# Provider detection patterns (case-insensitive), keyed by the pytest marker
# name they trigger. The OpenAI pattern accepts model ids with suffixes
# (gpt-4o, gpt-4o-mini, gpt-3.5-turbo, gpt-5, gpt4) — a bare `gpt-[34]\b`
# misses ids like "gpt-4o" because there is no word boundary between "4" and "o".
PROVIDER_PATTERNS: Dict[str, re.Pattern] = {
    'provider_openai': re.compile(r'\b(openai|gpt-?[345]\w*|chatgpt)\b', re.IGNORECASE),
    'provider_anthropic': re.compile(r'\b(anthropic|claude)\b', re.IGNORECASE),
    'provider_google': re.compile(r'\b(google|gemini|palm|vertex)\b', re.IGNORECASE),
    'provider_ollama': re.compile(r'\b(ollama)\b', re.IGNORECASE),
    'provider_grok_xai': re.compile(r'\b(grok|xai|x\.ai)\b', re.IGNORECASE),
    'provider_groq': re.compile(r'\b(groq)\b', re.IGNORECASE),
    'provider_cohere': re.compile(r'\b(cohere)\b', re.IGNORECASE),
}
# Provider marker -> environment variable that must be set (non-empty) for
# the provider to count as available. None means no API key is checked; for
# Ollama, availability is decided by a local service check instead.
PROVIDER_ENV_KEYS: Dict[str, Optional[str]] = {
    'provider_openai': 'OPENAI_API_KEY',
    'provider_anthropic': 'ANTHROPIC_API_KEY',
    'provider_google': 'GOOGLE_API_KEY',
    'provider_ollama': None,  # Requires service check
    'provider_grok_xai': 'XAI_API_KEY',
    'provider_groq': 'GROQ_API_KEY',
    'provider_cohere': 'COHERE_API_KEY',
}
# Memo of test-file text keyed by str(path), so each file is read at most
# once. Filled by _get_file_content; emptied at session start and end.
_file_content_cache: Dict[str, str] = dict()
def _get_test_tier() -> str:
    """
    Get the current test tier from PRAISONAI_TEST_TIER.

    The value is normalised (surrounding whitespace stripped, lowercased) so
    that e.g. ``" Smoke"`` still selects the smoke tier. An unset or blank
    value yields the default tier, ``'main'``.

    Returns:
        The lowercased tier name.
    """
    return os.environ.get('PRAISONAI_TEST_TIER', 'main').strip().lower() or 'main'
def _is_network_allowed() -> bool:
    """Return True when PRAISONAI_ALLOW_NETWORK or PRAISONAI_LIVE_TESTS is exactly '1'."""
    flags = (
        os.environ.get('PRAISONAI_ALLOW_NETWORK', '0'),
        os.environ.get('PRAISONAI_LIVE_TESTS', '0'),
    )
    return any(flag == '1' for flag in flags)
def _get_allowed_providers() -> Set[str]:
    """
    Get the set of provider marker names enabled via PRAISONAI_TEST_PROVIDERS.

    The variable holds a comma-separated list of provider names (``openai``,
    ``anthropic``, ``grok_xai``, ...) or ``all``. Names are case-insensitive.
    Surrounding whitespace is ignored, as are empty entries (e.g. from a
    trailing comma). ``xai`` and ``grok`` are accepted as aliases for
    ``grok_xai``, the suffix of the ``provider_grok_xai`` marker.

    Returns:
        Marker names such as ``{'provider_openai'}`` (the default), or every
        key of PROVIDER_ENV_KEYS when the value is ``all``.
    """
    providers_str = os.environ.get('PRAISONAI_TEST_PROVIDERS', 'openai').strip()
    if providers_str.lower() == 'all':
        return set(PROVIDER_ENV_KEYS.keys())
    aliases = {'xai': 'grok_xai', 'grok': 'grok_xai'}
    names = (part.strip().lower() for part in providers_str.split(','))
    return {f'provider_{aliases.get(name, name)}' for name in names if name}
def _is_local_services_allowed() -> bool:
    """Return True when PRAISONAI_LOCAL_SERVICES is exactly '1' (local services such as Docker may be used)."""
    flag = os.environ.get('PRAISONAI_LOCAL_SERVICES', '0')
    return flag == '1'
def _check_ollama_available(host: str = '127.0.0.1', port: int = 11434, timeout: float = 0.5) -> bool:
    """
    Check whether something accepts TCP connections on the local Ollama address.

    Only a TCP connect is attempted (no HTTP request), so any service
    listening on the port counts as available.

    Args:
        host: Address to probe; defaults to the IPv4 loopback address.
        port: Port to probe; defaults to 11434, the port named in the
            "Ollama not running" skip reason.
        timeout: Connect timeout in seconds.

    Returns:
        True if the connection succeeded, False on any socket-level error.
    """
    try:
        # The context manager closes the socket even if settimeout/connect_ex raise.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(timeout)
            return sock.connect_ex((host, port)) == 0
    except OSError:  # includes socket.timeout and socket.gaierror
        return False
def _check_provider_available(provider_marker: str) -> tuple[bool, str]:
    """
    Decide whether the provider behind *provider_marker* can be exercised.

    Ollama is probed on its local port. Other providers need their API-key
    environment variable (from PROVIDER_ENV_KEYS) to be non-empty. Markers
    with no key configured are allowed.

    Returns:
        ``(True, "")`` when usable, otherwise ``(False, reason)``.
    """
    if provider_marker == 'provider_ollama':
        if not _check_ollama_available():
            return False, "Ollama not running on localhost:11434"
        return True, ""

    env_key = PROVIDER_ENV_KEYS.get(provider_marker)
    if not env_key:
        # Unknown provider (or no key requirement): allow by default.
        return True, ""
    if not os.environ.get(env_key):
        return False, f"{env_key} not set"
    return True, ""
def _get_file_content(filepath: Path) -> str:
    """
    Return the text of *filepath*, reading it at most once per session.

    The file is decoded as UTF-8 with undecodable bytes dropped, so the result
    does not depend on the platform's locale encoding. Unreadable paths
    (missing, permission denied, a directory, ...) are cached as the empty
    string. Provider detection then finds nothing in them instead of aborting
    collection.

    Args:
        filepath: Path of the test file to read.

    Returns:
        The (possibly empty) decoded file content.
    """
    key = str(filepath)
    cached = _file_content_cache.get(key)
    if cached is not None:
        return cached
    try:
        content = filepath.read_text(encoding='utf-8', errors='ignore')
    except (OSError, ValueError):
        content = ""
    _file_content_cache[key] = content
    return content
# Paths that should NEVER have provider markers auto-assigned.
# These are test infrastructure files that may contain provider keywords
# but are not actual provider tests. Entries are matched as case-insensitive
# substrings. Write directory separators as '/'; backslashes in the checked
# paths are normalised to '/' before matching.
EXCLUDED_PATHS = (
    '_pytest_plugins',
    '_meta',
    'test_test_command',
    'test_gating',
    'test_network_guard',
    'conftest',
    'fixtures/',
)


def _is_excluded_path(filepath_str: str, nodeid: str = '') -> bool:
    """
    Check whether a test location should be excluded from provider auto-detection.

    This keeps test-infrastructure files from being classified as provider
    tests just because they mention provider keywords in their own
    validation/testing logic.

    Args:
        filepath_str: Filesystem path of the test file (either separator style).
        nodeid: Optional pytest node id; checked independently of the path.

    Returns:
        True if any EXCLUDED_PATHS entry occurs in the path or the node id.
    """
    # Check each string separately: concatenating them could create a spurious
    # match spanning the end of the path and the start of the node id.
    candidates = [s.lower().replace('\\', '/') for s in (filepath_str, nodeid) if s]
    return any(
        excluded.lower() in candidate
        for candidate in candidates
        for excluded in EXCLUDED_PATHS
    )
def _detect_providers_in_file(filepath: Path) -> Set[str]:
    """
    Return the provider markers whose PROVIDER_PATTERNS regex occurs in a test file.

    Infrastructure paths recognised by _is_excluded_path (plugin tests, meta,
    fixtures) always yield an empty set. A new set is returned on every call,
    so callers may mutate it.
    """
    if _is_excluded_path(str(filepath)):
        return set()
    text = _get_file_content(filepath)
    return {name for name, regex in PROVIDER_PATTERNS.items() if regex.search(text)}
def _get_test_type_from_path(nodeid: str) -> Optional[str]:
    """
    Determine the test type from directory naming conventions in a node id.

    Only the file-path part of the node id (before the first ``::``) is
    inspected, so parametrize ids such as ``test_fetch[/live/]`` cannot change
    the classification. Both ``/`` and ``\\`` separators are accepted. A
    directory at the very start of the node id (e.g. ``unit/test_x.py`` when
    pytest's rootdir is the tests directory) is recognised as well.

    Args:
        nodeid: A pytest node id such as ``tests/unit/test_x.py::test_y``.

    Returns:
        ``'unit'``, ``'integration'`` or ``'e2e'`` (``live`` directories also
        map to ``'e2e'``), or None when no convention matches. When several
        directories match, the first in that order wins.
    """
    path_part = nodeid.split('::', 1)[0].lower().replace('\\', '/')
    # A leading '/' lets a first-segment directory match the '/<dir>/' form.
    segments = '/' + path_part
    for directory, test_type in (
        ('unit', 'unit'),
        ('integration', 'integration'),
        ('e2e', 'e2e'),
        ('live', 'e2e'),
    ):
        if f'/{directory}/' in segments:
            return test_type
    return None
def pytest_configure(config):
    """
    Register the markers this plugin assigns and reset per-session state.

    The plugin adds markers dynamically in pytest_collection_modifyitems.
    Registering them here stops pytest from emitting PytestUnknownMarkWarning
    for each one and lets the plugin run under ``--strict-markers``.
    Registering a marker that is also declared in an ini file is harmless.

    Args:
        config: The pytest Config object.
    """
    marker_docs = {
        'unit': 'unit test (auto-assigned from a unit/ directory)',
        'integration': 'integration test (auto-assigned from an integration/ directory)',
        'e2e': 'end-to-end test (auto-assigned from an e2e/ or live/ directory)',
        'network': 'requires network access; skipped unless PRAISONAI_ALLOW_NETWORK=1 or PRAISONAI_LIVE_TESTS=1',
        'real': 'alias for network',
        'slow': 'slow test; skipped in the smoke tier',
        'local_service': 'requires local services; skipped unless PRAISONAI_LOCAL_SERVICES=1',
    }
    for name, doc in marker_docs.items():
        config.addinivalue_line('markers', f'{name}: {doc}')
    for provider in PROVIDER_PATTERNS:
        provider_name = provider[len('provider_'):]
        config.addinivalue_line(
            'markers',
            f'{provider}: test uses the {provider_name} provider (auto-detected)',
        )

    # Clear file content cache at start of session
    _file_content_cache.clear()
def _assign_gating_markers(item) -> tuple[Set[str], Set[str]]:
    """
    Add auto-detected test-type, provider and network markers to one test item.

    Args:
        item: A collected pytest item.

    Returns:
        ``(marker_names, provider_marker_names)`` after the additions.
        ``marker_names`` is everything ``item.iter_markers()`` reports.
    """
    existing_markers = {m.name for m in item.iter_markers()}

    # 1. Test type marker derived from path conventions.
    test_type = _get_test_type_from_path(item.nodeid)
    if test_type and test_type not in existing_markers:
        item.add_marker(getattr(pytest.mark, test_type))

    # 2. Provider markers from file content and nodeid. Unit tests and
    #    infrastructure paths (see _is_excluded_path) are never auto-tagged.
    if item.fspath and test_type != 'unit':
        filepath = Path(item.fspath)
        if not _is_excluded_path(str(filepath), item.nodeid):
            detected_providers = _detect_providers_in_file(filepath)
            detected_providers.update(
                marker
                for marker, pattern in PROVIDER_PATTERNS.items()
                if pattern.search(item.nodeid)
            )
            # Sorted so markers are added in a stable order across runs.
            for provider in sorted(detected_providers):
                if provider not in existing_markers:
                    item.add_marker(getattr(pytest.mark, provider))

    existing_markers = {m.name for m in item.iter_markers()}

    # 3. Any provider marker, or the 'real' alias, implies network access.
    provider_markers = {m for m in existing_markers if m.startswith('provider_')}
    if (provider_markers or 'real' in existing_markers) and 'network' not in existing_markers:
        item.add_marker(pytest.mark.network)
        existing_markers = {m.name for m in item.iter_markers()}

    return existing_markers, provider_markers


def _gating_skip_reason(
    markers: Set[str],
    provider_markers: Set[str],
    tier: str,
    network_allowed: bool,
    allowed_providers: Set[str],
    local_services_allowed: bool,
    availability: Dict[str, tuple[bool, str]],
) -> Optional[str]:
    """
    Return the first gating rule that skips a test with *markers*, or None to run it.

    Rules, in order:
    - Smoke tier: skip integration/e2e tests, then slow tests.
    - Skip network tests unless network access is allowed.
    - Skip providers that are not enabled or not available.
    - Skip local-service tests unless local services are allowed.

    *availability* memoises _check_provider_available results. The checks
    read only environment variables and a local port, so one result per
    provider serves the whole collection.
    """
    if tier == 'smoke':
        if 'integration' in markers or 'e2e' in markers:
            return "Smoke tier: skipping non-unit tests"
        if 'slow' in markers:
            return "Smoke tier: skipping slow tests"

    if 'network' in markers and not network_allowed:
        return "Network tests disabled. Set PRAISONAI_ALLOW_NETWORK=1 or PRAISONAI_LIVE_TESTS=1"

    # Sorted so the reported reason is deterministic for multi-provider tests.
    for provider in sorted(provider_markers):
        if provider not in allowed_providers:
            return f"Provider {provider} not in PRAISONAI_TEST_PROVIDERS"
        if network_allowed:
            if provider not in availability:
                availability[provider] = _check_provider_available(provider)
            available, reason = availability[provider]
            if not available:
                return reason

    if 'local_service' in markers and not local_services_allowed:
        return "Local service tests disabled. Set PRAISONAI_LOCAL_SERVICES=1"

    return None


def pytest_collection_modifyitems(config, items):
    """
    Auto-assign markers and apply skip logic based on gating rules.

    This hook runs after test collection and, for every item:
    1. Adds test type markers (unit/integration/e2e) based on path
    2. Adds provider markers based on file content analysis (not for unit tests)
    3. Adds the network marker if any provider marker, or 'real', is present
    4. Adds at most one skip marker, based on environment configuration

    Args:
        config: The pytest Config object (unused; required by the hook signature).
        items: Collected test items, modified in place.
    """
    tier = _get_test_tier()
    network_allowed = _is_network_allowed()
    allowed_providers = _get_allowed_providers()
    local_services_allowed = _is_local_services_allowed()
    # Probe each provider (e.g. the Ollama port) once, not once per item.
    availability: Dict[str, tuple[bool, str]] = {}

    for item in items:
        markers, provider_markers = _assign_gating_markers(item)
        reason = _gating_skip_reason(
            markers,
            provider_markers,
            tier,
            network_allowed,
            allowed_providers,
            local_services_allowed,
            availability,
        )
        if reason is not None:
            item.add_marker(pytest.mark.skip(reason=reason))
def pytest_sessionfinish(session, exitstatus):
    """
    Drop the cached test-file contents once the test session has finished.

    ``session`` and ``exitstatus`` are required by the hook signature and are
    not used here.
    """
    _file_content_cache.clear()