Skip to content

Commit 1984965

Browse files
authored
Snapshot record then reply testing (#740)
This change implements snapshot, record then reply testing for all integration tests. Re-enables recently skipped (flaky) tests: - `test_agent_understands_other_agents` (snapshot was edited manually) - `test_supervisor_resumes_subagent_thread_across_invocations` - `test_supervisor_resumes_subagent_thread_across_invocations_structured` Introduces a deterministic thread_id mock generator, such that snapshots are deterministic, for: - `test_supervisor_resumes_subagent_thread_across_invocations` - `test_supervisor_resumes_subagent_thread_across_invocations_structured` Modified `test_tool_execution_service_access` using tool middleware, to make the test deterministic. E2E tests still call the LLMs directly.
1 parent e81d2ed commit 1984965

68 files changed

Lines changed: 16986 additions & 23 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ test = [
4545
"pytest-cov>=7.1.0",
4646
"pytest-asyncio>=1.3.0",
4747
"python-dotenv>=1.2.2",
48+
"vcrpy>=8.1.1",
4849
]
4950
release = ["build>=1.4.3", "jinja2>=3.1.6", "sphinx>=9.1.0", "twine>=6.2.0"]
5051
lint = ["basedpyright>=1.39.0", "ruff>=0.15.10"]

splunklib/ai/engines/langchain.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@
174174
_testing_force_tool_strategy = False
175175

176176

177+
def _thread_id_new_uuid() -> str:
178+
return str(uuid.uuid4())
179+
180+
177181
def _supports_provider_strategy(model: BaseChatModel) -> bool:
178182
return (
179183
model.profile is not None
@@ -365,16 +369,16 @@ async def awrap_model_call(
365369
# LLM halucinated a thread_id, start a new conversation instead.
366370
# This should not happen, since we provide an enum above, but just
367371
# in case.
368-
args.thread_id = str(uuid.uuid4())
372+
args.thread_id = _thread_id_new_uuid()
369373

370374
if args.thread_id and args.thread_id in called_thread_ids:
371375
# LLM did not listen not to issue multiple calls to the
372376
# same thread_id, start a new conversation instead.
373-
args.thread_id = str(uuid.uuid4())
377+
args.thread_id = _thread_id_new_uuid()
374378

375379
if not args.thread_id:
376380
# Generate thread_id for a new conversation.
377-
args.thread_id = str(uuid.uuid4())
381+
args.thread_id = _thread_id_new_uuid()
378382

379383
called_thread_ids.add(args.thread_id)
380384
call["args"] = asdict(args)

tests/ai_test_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ async def _buildInternalAIModel(
7575
auth=(client_id, client_secret),
7676
)
7777

78+
response.raise_for_status()
79+
7880
token = _TokenResponse.model_validate_json(response.text).access_token
7981

8082
auth_handler = _InternalAIAuth(token)

tests/ai_testlib.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
1-
from typing import override
1+
import functools
2+
import inspect
3+
import json
4+
import os
5+
from collections.abc import Callable, Coroutine
6+
from typing import Any, override
7+
from unittest.mock import patch
8+
from urllib import parse
9+
10+
import vcr
11+
from vcr.config import RecordMode
12+
from vcr.request import Request
13+
214
from splunklib.ai.model import PredefinedModel
315
from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model
416
from tests.testlib import SDKTestCase
517

18+
REDACTED_APP_KEY = "[[[--APPKEY-REDACTED-]]]"
19+
620

721
class AITestCase(SDKTestCase):
822
_model: PredefinedModel | None = None
@@ -42,3 +56,147 @@ async def model(self) -> PredefinedModel:
4256
model = await create_model(self.test_llm_settings)
4357
self._model = model
4458
return model
59+
60+
61+
def ai_snapshot_test() -> Callable[
62+
[Callable[..., Coroutine[Any, Any, None]]], Callable[..., Coroutine[Any, Any, None]]
63+
]:
64+
def decorator(
65+
fn: Callable[..., Coroutine[Any, Any, None]],
66+
) -> Callable[..., Coroutine[Any, Any, None]]:
67+
source_file = inspect.getfile(fn)
68+
test_dir = os.path.dirname(source_file)
69+
test_file = os.path.splitext(os.path.basename(source_file))[0]
70+
71+
snapshot_dir = os.path.join(test_dir, "snapshots", test_file)
72+
snapshot_filename = f"{fn.__qualname__}.json"
73+
74+
@functools.wraps(fn)
75+
async def wrapper(self: AITestCase, *args: Any, **kwargs: Any) -> None:
76+
settings = self.test_llm_settings
77+
assert settings.internal_ai is not None
78+
79+
internal_ai_hostname = parse.urlparse(
80+
settings.internal_ai.base_url
81+
).hostname
82+
assert internal_ai_hostname is not None
83+
84+
class _JSONFriendlySerializer:
85+
def deserialize(self, serialized: str) -> Any:
86+
assert settings.internal_ai is not None
87+
serialized = serialized.replace(
88+
REDACTED_APP_KEY, settings.internal_ai.app_key
89+
)
90+
91+
data = json.loads(serialized)
92+
for interaction in data.get("interactions", []):
93+
interaction["request"]["uri"] = interaction["request"][
94+
"uri"
95+
].replace("internal-ai-host", internal_ai_hostname, 1)
96+
97+
interaction["request"]["body"] = json.dumps(
98+
interaction["request"]["body"]
99+
)
100+
body = interaction["response"]["body"]
101+
interaction["response"]["body"] = {}
102+
interaction["response"]["body"]["string"] = json.dumps(body)
103+
104+
return data
105+
106+
def serialize(self, dict: Any) -> str:
107+
for interaction in dict.get("interactions", []):
108+
interaction["request"]["uri"] = interaction["request"][
109+
"uri"
110+
].replace(internal_ai_hostname, "internal-ai-host", 1)
111+
112+
body = interaction["request"]["body"]
113+
interaction["request"]["body"] = json.loads(body)
114+
115+
resp_body = interaction["response"]["body"]["string"]
116+
interaction["response"]["body"] = json.loads(resp_body)
117+
118+
out = json.dumps(dict, indent=4) + "\n"
119+
assert settings.internal_ai is not None
120+
out = out.replace(settings.internal_ai.app_key, REDACTED_APP_KEY)
121+
122+
# Assert that nothing is leaking into the public snapshots.
123+
assert internal_ai_hostname not in out.lower()
124+
assert settings.internal_ai.app_key.lower() not in out.lower()
125+
assert settings.internal_ai.base_url.lower() not in out.lower()
126+
assert settings.internal_ai.token_url.lower() not in out.lower()
127+
assert settings.internal_ai.client_id.lower() not in out.lower()
128+
assert settings.internal_ai.client_secret.lower() not in out.lower()
129+
130+
return out
131+
132+
def _before_record_request(request: Request) -> Request | None:
133+
url = parse.urlparse(request.uri) # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType]
134+
if url.hostname == internal_ai_hostname:
135+
request.headers = {}
136+
return request
137+
return None
138+
139+
def _before_record_response(response: Any) -> Any:
140+
response["headers"] = {}
141+
return response
142+
143+
def _json_body_matcher(r1: Any, r2: Any) -> None:
144+
b1 = json.loads(r1.body)
145+
b2 = json.loads(r2.body)
146+
if b1 != b2:
147+
raise AssertionError(f"Body mismatch:\n{b1}\n!=\n{b2}")
148+
149+
my_vcr = vcr.VCR(
150+
cassette_library_dir=snapshot_dir,
151+
serializer="json-friendly",
152+
record_mode=RecordMode.ONCE,
153+
match_on=[
154+
"method",
155+
"scheme",
156+
"host",
157+
"port",
158+
"path",
159+
"query",
160+
"jsonbody",
161+
],
162+
before_record_request=_before_record_request,
163+
before_record_response=_before_record_response,
164+
record_on_exception=False,
165+
drop_unused_requests=True,
166+
)
167+
my_vcr.register_serializer("json-friendly", _JSONFriendlySerializer())
168+
my_vcr.register_matcher("jsonbody", _json_body_matcher)
169+
170+
with my_vcr.use_cassette(snapshot_filename): # pyright: ignore[reportGeneralTypeIssues]
171+
await fn(self, *args, **kwargs)
172+
173+
return wrapper
174+
175+
return decorator
176+
177+
178+
def deterministic_thread_ids() -> Callable[
179+
[Callable[..., Coroutine[Any, Any, None]]], Callable[..., Coroutine[Any, Any, None]]
180+
]:
181+
def decorator(
182+
fn: Callable[..., Coroutine[Any, Any, None]],
183+
) -> Callable[..., Coroutine[Any, Any, None]]:
184+
@functools.wraps(fn)
185+
async def wrapper(self: AITestCase, *args: Any, **kwargs: Any) -> None:
186+
counter = 0
187+
188+
def _deterministic_uuid() -> str:
189+
nonlocal counter
190+
result = f"00000000-0000-0000-0000-{counter:012d}"
191+
counter += 1
192+
return result
193+
194+
with patch(
195+
"splunklib.ai.engines.langchain._thread_id_new_uuid",
196+
side_effect=_deterministic_uuid,
197+
):
198+
await fn(self, *args, **kwargs)
199+
200+
return wrapper
201+
202+
return decorator

0 commit comments

Comments
 (0)