Skip to content

Commit cf3f219

Browse files
authored
fix(iorails): Fix failing tests due to ModelManager refactor (#1791)
1 parent af242d8 commit cf3f219

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

tests/guardrails/test_iorails.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -345,29 +345,29 @@ def iorails_input_only(self):
345345
@pytest.mark.asyncio
346346
async def test_generate_async_calls_start(self, iorails):
347347
"""generate_async() calls start() automatically before running the pipeline."""
348-
iorails.model_manager.start = AsyncMock()
348+
iorails.engine_registry.start = AsyncMock()
349349
iorails.rails_manager.is_input_safe = AsyncMock(return_value=RailResult(is_safe=True))
350-
iorails.model_manager.generate_async = AsyncMock(return_value="ok")
350+
iorails.engine_registry.model_call = AsyncMock(return_value="ok")
351351
iorails.rails_manager.is_output_safe = AsyncMock(return_value=RailResult(is_safe=True))
352352

353353
assert not iorails._running
354354
await iorails.generate_async([{"role": "user", "content": "hi"}])
355355

356-
iorails.model_manager.start.assert_called_once()
356+
iorails.engine_registry.start.assert_called_once()
357357
assert iorails._running
358358

359359
@pytest.mark.asyncio
360360
async def test_generate_async_start_is_idempotent(self, iorails):
361361
"""Two generate_async() calls only trigger start() once."""
362-
iorails.model_manager.start = AsyncMock()
362+
iorails.engine_registry.start = AsyncMock()
363363
iorails.rails_manager.is_input_safe = AsyncMock(return_value=RailResult(is_safe=True))
364-
iorails.model_manager.generate_async = AsyncMock(return_value="ok")
364+
iorails.engine_registry.model_call = AsyncMock(return_value="ok")
365365
iorails.rails_manager.is_output_safe = AsyncMock(return_value=RailResult(is_safe=True))
366366

367367
await iorails.generate_async([{"role": "user", "content": "hi"}])
368368
await iorails.generate_async([{"role": "user", "content": "hi"}])
369369

370-
iorails.model_manager.start.assert_called_once()
370+
iorails.engine_registry.start.assert_called_once()
371371

372372
@pytest.mark.asyncio
373373
async def test_stream_async_calls_start(self, iorails_input_only):
@@ -376,16 +376,16 @@ async def test_stream_async_calls_start(self, iorails_input_only):
376376
async def mock_stream(model_type, messages, **kwargs):
377377
yield "hello"
378378

379-
iorails_input_only.model_manager.start = AsyncMock()
379+
iorails_input_only.engine_registry.start = AsyncMock()
380380
iorails_input_only.rails_manager.is_input_safe = AsyncMock(return_value=RailResult(is_safe=True))
381-
iorails_input_only.model_manager.stream_async = mock_stream
381+
iorails_input_only.engine_registry.stream_model_call = mock_stream
382382

383383
assert not iorails_input_only._running
384384
chunks = [
385385
chunk async for chunk in iorails_input_only.stream_async(messages=[{"role": "user", "content": "hi"}])
386386
]
387387

388-
iorails_input_only.model_manager.start.assert_called_once()
388+
iorails_input_only.engine_registry.start.assert_called_once()
389389
assert iorails_input_only._running
390390
assert chunks == ["hello"]
391391

@@ -396,19 +396,19 @@ async def test_stream_async_start_is_idempotent(self, iorails_input_only):
396396
async def mock_stream(model_type, messages, **kwargs):
397397
yield "hi"
398398

399-
iorails_input_only.model_manager.start = AsyncMock()
399+
iorails_input_only.engine_registry.start = AsyncMock()
400400
iorails_input_only.rails_manager.is_input_safe = AsyncMock(return_value=RailResult(is_safe=True))
401-
iorails_input_only.model_manager.stream_async = mock_stream
401+
iorails_input_only.engine_registry.stream_model_call = mock_stream
402402

403403
_ = [chunk async for chunk in iorails_input_only.stream_async(messages=[{"role": "user", "content": "hi"}])]
404404
_ = [chunk async for chunk in iorails_input_only.stream_async(messages=[{"role": "user", "content": "hi"}])]
405405

406-
iorails_input_only.model_manager.start.assert_called_once()
406+
iorails_input_only.engine_registry.start.assert_called_once()
407407

408408
@pytest.mark.asyncio
409409
async def test_stream_async_propagates_start_failure(self, iorails_input_only):
410410
"""start() failure inside stream_async propagates to the caller."""
411-
iorails_input_only.model_manager.start = AsyncMock(side_effect=RuntimeError("engine unavailable"))
411+
iorails_input_only.engine_registry.start = AsyncMock(side_effect=RuntimeError("engine unavailable"))
412412

413413
with pytest.raises(RuntimeError, match="engine unavailable"):
414414
_ = [chunk async for chunk in iorails_input_only.stream_async(messages=[{"role": "user", "content": "hi"}])]

0 commit comments

Comments
 (0)