|
16 | 16 | # Consolidation operation |
17 | 17 | ConsolidateContext, |
18 | 18 | ConsolidateResult, |
| 19 | + CreateBankContext, |
19 | 20 | Extension, |
20 | 21 | HttpExtension, |
21 | 22 | OperationValidationError, |
@@ -202,6 +203,30 @@ async def on_consolidate_complete(self, result: ConsolidateResult) -> None: |
202 | 203 | self.post_consolidate_calls.append(result) |
203 | 204 |
|
204 | 205 |
|
| 206 | +class CreateBankRejectingValidator(OperationValidatorExtension): |
| 207 | + """Validator that records create-bank checks and can reject them.""" |
| 208 | + |
| 209 | + def __init__(self, *, reject: bool = True): |
| 210 | + super().__init__({}) |
| 211 | + self.reject = reject |
| 212 | + self.create_bank_calls: list[CreateBankContext] = [] |
| 213 | + |
| 214 | + async def validate_retain(self, ctx: RetainContext) -> ValidationResult: |
| 215 | + return ValidationResult.accept() |
| 216 | + |
| 217 | + async def validate_recall(self, ctx: RecallContext) -> ValidationResult: |
| 218 | + return ValidationResult.accept() |
| 219 | + |
| 220 | + async def validate_reflect(self, ctx: ReflectContext) -> ValidationResult: |
| 221 | + return ValidationResult.accept() |
| 222 | + |
| 223 | + async def validate_create_bank(self, ctx: CreateBankContext) -> ValidationResult: |
| 224 | + self.create_bank_calls.append(ctx) |
| 225 | + if self.reject: |
| 226 | + return ValidationResult.reject("bank creation not allowed", status_code=402) |
| 227 | + return ValidationResult.accept() |
| 228 | + |
| 229 | + |
205 | 230 | class TestMemoryEngineValidation: |
206 | 231 | """Tests for validation integration with MemoryEngine. |
207 | 232 |
|
@@ -297,6 +322,109 @@ async def test_reflect_validation(self, memory_with_validator): |
297 | 322 |
|
298 | 323 | assert "limit exceeded" in str(exc_info.value).lower() |
299 | 324 |
|
| 325 | + @pytest.mark.asyncio |
| 326 | + async def test_retain_validates_create_bank_for_missing_bank(self, memory): |
| 327 | + """Retain may lazily create a bank, so missing-bank retain validates create_bank.""" |
| 328 | + bank_id = "test-retain-create-bank-validation" |
| 329 | + ctx = RequestContext() |
| 330 | + validator = CreateBankRejectingValidator() |
| 331 | + memory._operation_validator = validator |
| 332 | + |
| 333 | + with pytest.raises(OperationValidationError) as exc_info: |
| 334 | + await memory.retain_batch_async( |
| 335 | + bank_id=bank_id, |
| 336 | + contents=[{"content": "Should not create a bank"}], |
| 337 | + request_context=ctx, |
| 338 | + ) |
| 339 | + |
| 340 | + assert "bank creation not allowed" in str(exc_info.value) |
| 341 | + assert len(validator.create_bank_calls) == 1 |
| 342 | + assert validator.create_bank_calls[0].bank_id == bank_id |
| 343 | + assert await memory.get_bank_profile(bank_id, request_context=ctx, create_if_missing=False) is None |
| 344 | + |
| 345 | + @pytest.mark.asyncio |
| 346 | + async def test_async_retain_validates_create_bank_for_missing_bank(self, memory): |
| 347 | + """Async retain validates create_bank before creating its parent operation.""" |
| 348 | + bank_id = "test-async-retain-create-bank-validation" |
| 349 | + ctx = RequestContext() |
| 350 | + validator = CreateBankRejectingValidator() |
| 351 | + memory._operation_validator = validator |
| 352 | + |
| 353 | + with pytest.raises(OperationValidationError) as exc_info: |
| 354 | + await memory.submit_async_retain( |
| 355 | + bank_id=bank_id, |
| 356 | + contents=[{"content": "Should not create a bank"}], |
| 357 | + request_context=ctx, |
| 358 | + ) |
| 359 | + |
| 360 | + assert "bank creation not allowed" in str(exc_info.value) |
| 361 | + assert len(validator.create_bank_calls) == 1 |
| 362 | + assert validator.create_bank_calls[0].bank_id == bank_id |
| 363 | + assert await memory.get_bank_profile(bank_id, request_context=ctx, create_if_missing=False) is None |
| 364 | + |
| 365 | + @pytest.mark.asyncio |
| 366 | + async def test_create_bank_validation_skips_existing_bank(self, memory): |
| 367 | + """Existing banks do not require create_bank validation.""" |
| 368 | + bank_id = "test-create-bank-existing" |
| 369 | + ctx = RequestContext() |
| 370 | + await memory.get_bank_profile(bank_id, request_context=ctx) |
| 371 | + |
| 372 | + validator = CreateBankRejectingValidator() |
| 373 | + memory._operation_validator = validator |
| 374 | + |
| 375 | + created = await memory._ensure_bank_exists(bank_id, ctx) |
| 376 | + |
| 377 | + assert created is False |
| 378 | + assert validator.create_bank_calls == [] |
| 379 | + |
| 380 | + @pytest.mark.asyncio |
| 381 | + async def test_get_bank_profile_validates_create_bank_for_missing_bank(self, memory): |
| 382 | + """The default auto-create profile path validates create_bank.""" |
| 383 | + bank_id = "test-profile-create-bank-validation" |
| 384 | + ctx = RequestContext() |
| 385 | + validator = CreateBankRejectingValidator() |
| 386 | + memory._operation_validator = validator |
| 387 | + |
| 388 | + with pytest.raises(OperationValidationError) as exc_info: |
| 389 | + await memory.get_bank_profile(bank_id, request_context=ctx) |
| 390 | + |
| 391 | + assert "bank creation not allowed" in str(exc_info.value) |
| 392 | + assert len(validator.create_bank_calls) == 1 |
| 393 | + assert validator.create_bank_calls[0].bank_id == bank_id |
| 394 | + assert await memory.get_bank_profile(bank_id, request_context=ctx, create_if_missing=False) is None |
| 395 | + |
| 396 | + @pytest.mark.asyncio |
| 397 | + async def test_http_create_or_update_validates_create_bank(self, api_client, memory): |
| 398 | + bank_id = "test-http-put-create-bank-validation" |
| 399 | + validator = CreateBankRejectingValidator() |
| 400 | + memory._operation_validator = validator |
| 401 | + |
| 402 | + resp = await api_client.put( |
| 403 | + f"/v1/default/banks/{bank_id}", |
| 404 | + json={"name": "Blocked Bank"}, |
| 405 | + ) |
| 406 | + |
| 407 | + assert resp.status_code == 402 |
| 408 | + assert resp.json()["detail"] == "bank creation not allowed" |
| 409 | + assert len(validator.create_bank_calls) == 1 |
| 410 | + assert validator.create_bank_calls[0].bank_id == bank_id |
| 411 | + |
| 412 | + @pytest.mark.asyncio |
| 413 | + async def test_http_template_import_validates_create_bank(self, api_client, memory): |
| 414 | + bank_id = "test-http-import-create-bank-validation" |
| 415 | + validator = CreateBankRejectingValidator() |
| 416 | + memory._operation_validator = validator |
| 417 | + |
| 418 | + resp = await api_client.post( |
| 419 | + f"/v1/default/banks/{bank_id}/import", |
| 420 | + json={"version": "1"}, |
| 421 | + ) |
| 422 | + |
| 423 | + assert resp.status_code == 402 |
| 424 | + assert resp.json()["detail"] == "bank creation not allowed" |
| 425 | + assert len(validator.create_bank_calls) == 1 |
| 426 | + assert validator.create_bank_calls[0].bank_id == bank_id |
| 427 | + |
300 | 428 |
|
301 | 429 | @pytest.fixture |
302 | 430 | def memory_with_validator(memory): |
|
0 commit comments