|
1 | | -from typing import override |
| 1 | +import functools |
| 2 | +import inspect |
| 3 | +import json |
| 4 | +import os |
| 5 | +from collections.abc import Callable, Coroutine |
| 6 | +from typing import Any, override |
| 7 | +from unittest.mock import patch |
| 8 | +from urllib import parse |
| 9 | + |
| 10 | +import vcr |
| 11 | +from vcr.config import RecordMode |
| 12 | +from vcr.request import Request |
| 13 | + |
2 | 14 | from splunklib.ai.model import PredefinedModel |
3 | 15 | from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model |
4 | 16 | from tests.testlib import SDKTestCase |
5 | 17 |
|
| 18 | +REDACTED_APP_KEY = "[[[--APPKEY-REDACTED-]]]" |
| 19 | + |
6 | 20 |
|
7 | 21 | class AITestCase(SDKTestCase): |
8 | 22 | _model: PredefinedModel | None = None |
@@ -42,3 +56,147 @@ async def model(self) -> PredefinedModel: |
42 | 56 | model = await create_model(self.test_llm_settings) |
43 | 57 | self._model = model |
44 | 58 | return model |
| 59 | + |
| 60 | + |
def ai_snapshot_test() -> Callable[
    [Callable[..., Coroutine[Any, Any, None]]], Callable[..., Coroutine[Any, Any, None]]
]:
    """Return a decorator that runs an async AI test inside a vcr cassette.

    The cassette is stored next to the test module, under
    ``snapshots/<test module name>/<test qualname>.json``. Only requests whose
    host equals the hostname of ``settings.internal_ai.base_url`` are kept by
    the request hook; request and response headers are dropped, the hostname
    is replaced with ``internal-ai-host`` and the app key with
    ``REDACTED_APP_KEY`` before the snapshot text is produced.

    Raises (from the wrapped test):
        AssertionError: if the internal AI settings are missing, if the base
            URL has no hostname, or if a sensitive setting is still present in
            the serialized snapshot.
    """

    def decorator(
        fn: Callable[..., Coroutine[Any, Any, None]],
    ) -> Callable[..., Coroutine[Any, Any, None]]:
        # Resolve the snapshot location once, at decoration time.
        source_file = inspect.getfile(fn)
        test_dir = os.path.dirname(source_file)
        test_file = os.path.splitext(os.path.basename(source_file))[0]

        snapshot_dir = os.path.join(test_dir, "snapshots", test_file)
        snapshot_filename = f"{fn.__qualname__}.json"

        @functools.wraps(fn)
        async def wrapper(self: AITestCase, *args: Any, **kwargs: Any) -> None:
            settings = self.test_llm_settings
            internal_ai = settings.internal_ai
            if internal_ai is None:
                raise AssertionError("snapshot tests require internal AI settings")

            internal_ai_hostname = parse.urlparse(internal_ai.base_url).hostname
            if internal_ai_hostname is None:
                raise AssertionError("internal AI base_url has no hostname")

            class _JSONFriendlySerializer:
                """Cassette serializer that stores bodies as nested JSON."""

                def deserialize(self, serialized: str) -> Any:
                    """Turn snapshot text back into the dict layout vcr replays."""
                    # Restore the real app key before parsing the snapshot.
                    serialized = serialized.replace(
                        REDACTED_APP_KEY, internal_ai.app_key
                    )

                    data = json.loads(serialized)
                    for interaction in data.get("interactions", []):
                        request = interaction["request"]
                        request["uri"] = request["uri"].replace(
                            "internal-ai-host", internal_ai_hostname, 1
                        )
                        # Bodies are stored as JSON values in the snapshot;
                        # re-encode them as strings here.
                        request["body"] = json.dumps(request["body"])
                        interaction["response"]["body"] = {
                            "string": json.dumps(interaction["response"]["body"])
                        }

                    return data

                def serialize(self, cassette_dict: Any) -> str:
                    """Render the cassette dict as redacted, indented JSON text.

                    Raises:
                        AssertionError: if any sensitive internal AI setting
                            is still present in the rendered text.
                    """
                    for interaction in cassette_dict.get("interactions", []):
                        request = interaction["request"]
                        request["uri"] = request["uri"].replace(
                            internal_ai_hostname, "internal-ai-host", 1
                        )
                        # NOTE: assumes every recorded body is JSON text;
                        # json.loads raises otherwise.
                        request["body"] = json.loads(request["body"])
                        interaction["response"]["body"] = json.loads(
                            interaction["response"]["body"]["string"]
                        )

                    out = json.dumps(cassette_dict, indent=4) + "\n"
                    # str.replace with an empty needle would insert the
                    # placeholder between every character, so skip it.
                    if internal_ai.app_key:
                        out = out.replace(internal_ai.app_key, REDACTED_APP_KEY)

                    # Guard against anything leaking into the public snapshots.
                    # Explicit raises (instead of assert statements) keep the
                    # guard active under ``python -O`` and let the message name
                    # the leaking field without echoing its value.
                    lowered = out.lower()
                    sensitive = {
                        "hostname": internal_ai_hostname,
                        "app_key": internal_ai.app_key,
                        "base_url": internal_ai.base_url,
                        "token_url": internal_ai.token_url,
                        "client_id": internal_ai.client_id,
                        "client_secret": internal_ai.client_secret,
                    }
                    for label, value in sensitive.items():
                        # An empty value cannot leak and is contained in every
                        # string, so it must not trip the guard.
                        if value and value.lower() in lowered:
                            raise AssertionError(
                                f"internal AI {label} leaked into the snapshot"
                            )

                    return out

            def _before_record_request(request: Request) -> Request | None:
                """Keep internal-AI requests (headers removed); return None otherwise."""
                url = parse.urlparse(request.uri)  # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
                if url.hostname == internal_ai_hostname:
                    request.headers = {}
                    return request
                return None

            def _before_record_response(response: Any) -> Any:
                """Drop every response header before the response is stored."""
                response["headers"] = {}
                return response

            def _json_body_matcher(r1: Any, r2: Any) -> None:
                """Match requests on their parsed JSON bodies, not raw text."""
                b1 = json.loads(r1.body)
                b2 = json.loads(r2.body)
                if b1 != b2:
                    raise AssertionError(f"Body mismatch:\n{b1}\n!=\n{b2}")

            my_vcr = vcr.VCR(
                cassette_library_dir=snapshot_dir,
                serializer="json-friendly",
                record_mode=RecordMode.ONCE,
                match_on=[
                    "method",
                    "scheme",
                    "host",
                    "port",
                    "path",
                    "query",
                    "jsonbody",
                ],
                before_record_request=_before_record_request,
                before_record_response=_before_record_response,
                record_on_exception=False,
                drop_unused_requests=True,
            )
            my_vcr.register_serializer("json-friendly", _JSONFriendlySerializer())
            my_vcr.register_matcher("jsonbody", _json_body_matcher)

            with my_vcr.use_cassette(snapshot_filename):  # pyright: ignore[reportGeneralTypeIssues]
                await fn(self, *args, **kwargs)

        return wrapper

    return decorator
| 176 | + |
| 177 | + |
def deterministic_thread_ids() -> Callable[
    [Callable[..., Coroutine[Any, Any, None]]], Callable[..., Coroutine[Any, Any, None]]
]:
    """Return a decorator that makes thread ids predictable inside a test.

    While the decorated coroutine runs,
    ``splunklib.ai.engines.langchain._thread_id_new_uuid`` is patched so that
    successive calls yield ``00000000-0000-0000-0000-000000000000``,
    ``00000000-0000-0000-0000-000000000001`` and so on. The sequence starts
    again from zero on every invocation of the decorated test.
    """

    def decorator(
        fn: Callable[..., Coroutine[Any, Any, None]],
    ) -> Callable[..., Coroutine[Any, Any, None]]:
        @functools.wraps(fn)
        async def patched_test(self: AITestCase, *args: Any, **kwargs: Any) -> None:
            def _id_sequence() -> Any:
                # Endless stream of zero-padded, monotonically increasing ids.
                index = 0
                while True:
                    yield f"00000000-0000-0000-0000-{index:012d}"
                    index += 1

            ids = _id_sequence()

            def _next_id() -> str:
                return next(ids)

            thread_id_patch = patch(
                "splunklib.ai.engines.langchain._thread_id_new_uuid",
                side_effect=_next_id,
            )
            with thread_id_patch:
                await fn(self, *args, **kwargs)

        return patched_test

    return decorator
0 commit comments