@@ -1342,3 +1342,230 @@ async def test_error_handler_also_replaces_correctly(self, mock_rails_factory):
13421342 assert result ["messages" ][1 ].content == middleware .blocked_output_message
13431343 assert isinstance (result ["messages" ][1 ], AIMessage )
13441344 assert result ["messages" ][2 ] is trailing_msg
1345+
1346+
1347+ class TestModifiedStatus :
1348+ @pytest .mark .asyncio
1349+ async def test_input_modified_replaces_last_human_message (self , mock_rails_factory ):
1350+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "sanitized input" )
1351+ middleware = create_middleware_with_rails (mock_rails )
1352+
1353+ state = {"messages" : [HumanMessage (content = "original input with PII" )]}
1354+ result = await middleware .abefore_model (state , None )
1355+
1356+ assert result is not None
1357+ assert "jump_to" not in result
1358+ assert len (result ["messages" ]) == 1
1359+ assert isinstance (result ["messages" ][0 ], HumanMessage )
1360+ assert result ["messages" ][0 ].content == "sanitized input"
1361+
1362+ @pytest .mark .asyncio
1363+ async def test_input_modified_preserves_surrounding_messages (self , mock_rails_factory ):
1364+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "redacted" )
1365+ middleware = create_middleware_with_rails (mock_rails )
1366+
1367+ system_msg = SystemMessage (content = "Be helpful" )
1368+ state = {
1369+ "messages" : [
1370+ system_msg ,
1371+ HumanMessage (content = "my SSN is 123-45-6789" ),
1372+ ]
1373+ }
1374+ result = await middleware .abefore_model (state , None )
1375+
1376+ assert len (result ["messages" ]) == 2
1377+ assert result ["messages" ][0 ] is system_msg
1378+ assert isinstance (result ["messages" ][1 ], HumanMessage )
1379+ assert result ["messages" ][1 ].content == "redacted"
1380+
1381+ @pytest .mark .asyncio
1382+ async def test_input_modified_with_multi_turn_history (self , mock_rails_factory ):
1383+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "cleaned message" )
1384+ middleware = create_middleware_with_rails (mock_rails )
1385+
1386+ first_human = HumanMessage (content = "Hello" )
1387+ first_ai = AIMessage (content = "Hi there" )
1388+ state = {
1389+ "messages" : [
1390+ first_human ,
1391+ first_ai ,
1392+ HumanMessage (content = "my email is foo@bar.com" ),
1393+ ]
1394+ }
1395+ result = await middleware .abefore_model (state , None )
1396+
1397+ assert len (result ["messages" ]) == 3
1398+ assert result ["messages" ][0 ] is first_human
1399+ assert result ["messages" ][1 ] is first_ai
1400+ assert result ["messages" ][2 ].content == "cleaned message"
1401+
1402+ @pytest .mark .asyncio
1403+ async def test_output_modified_replaces_last_ai_message (self , mock_rails_factory ):
1404+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "sanitized output" )
1405+ middleware = create_middleware_with_rails (mock_rails )
1406+
1407+ state = {
1408+ "messages" : [
1409+ HumanMessage (content = "Hello" ),
1410+ AIMessage (content = "original response with PII" ),
1411+ ]
1412+ }
1413+ result = await middleware .aafter_model (state , None )
1414+
1415+ assert result is not None
1416+ assert len (result ["messages" ]) == 2
1417+ assert isinstance (result ["messages" ][0 ], HumanMessage )
1418+ assert result ["messages" ][0 ].content == "Hello"
1419+ assert isinstance (result ["messages" ][1 ], AIMessage )
1420+ assert result ["messages" ][1 ].content == "sanitized output"
1421+
1422+ @pytest .mark .asyncio
1423+ async def test_output_modified_preserves_trailing_messages (self , mock_rails_factory ):
1424+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "redacted output" )
1425+ middleware = create_middleware_with_rails (mock_rails )
1426+
1427+ trailing = SystemMessage (content = "trailing" )
1428+ state = {
1429+ "messages" : [
1430+ HumanMessage (content = "Hello" ),
1431+ AIMessage (content = "bad output" ),
1432+ trailing ,
1433+ ]
1434+ }
1435+ result = await middleware .aafter_model (state , None )
1436+
1437+ assert len (result ["messages" ]) == 3
1438+ assert isinstance (result ["messages" ][1 ], AIMessage )
1439+ assert result ["messages" ][1 ].content == "redacted output"
1440+ assert result ["messages" ][2 ] is trailing
1441+
1442+ @pytest .mark .asyncio
1443+ async def test_output_modified_replaces_only_last_ai_message (self , mock_rails_factory ):
1444+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "fixed" )
1445+ middleware = create_middleware_with_rails (mock_rails )
1446+
1447+ state = {
1448+ "messages" : [
1449+ HumanMessage (content = "Hello" ),
1450+ AIMessage (content = "First response" ),
1451+ HumanMessage (content = "Follow up" ),
1452+ AIMessage (content = "Second response with PII" ),
1453+ ]
1454+ }
1455+ result = await middleware .aafter_model (state , None )
1456+
1457+ assert len (result ["messages" ]) == 4
1458+ assert result ["messages" ][1 ].content == "First response"
1459+ assert result ["messages" ][3 ].content == "fixed"
1460+
1461+ def test_sync_before_model_handles_modified (self , mock_rails_factory ):
1462+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "sanitized" )
1463+ middleware = create_middleware_with_rails (mock_rails )
1464+
1465+ state = {"messages" : [HumanMessage (content = "PII content" )]}
1466+ result = middleware .before_model (state , None )
1467+
1468+ assert result is not None
1469+ assert len (result ["messages" ]) == 1
1470+ assert isinstance (result ["messages" ][0 ], HumanMessage )
1471+ assert result ["messages" ][0 ].content == "sanitized"
1472+
1473+ def test_sync_after_model_handles_modified (self , mock_rails_factory ):
1474+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "sanitized output" )
1475+ middleware = create_middleware_with_rails (mock_rails )
1476+
1477+ state = {
1478+ "messages" : [
1479+ HumanMessage (content = "Hello" ),
1480+ AIMessage (content = "PII output" ),
1481+ ]
1482+ }
1483+ result = middleware .after_model (state , None )
1484+
1485+ assert result is not None
1486+ assert len (result ["messages" ]) == 2
1487+ assert result ["messages" ][1 ].content == "sanitized output"
1488+
1489+ @pytest .mark .asyncio
1490+ async def test_input_modified_with_empty_content (self , mock_rails_factory ):
1491+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "" )
1492+ middleware = create_middleware_with_rails (mock_rails )
1493+
1494+ state = {"messages" : [HumanMessage (content = "sensitive data" )]}
1495+ result = await middleware .abefore_model (state , None )
1496+
1497+ assert result is not None
1498+ assert len (result ["messages" ]) == 1
1499+ assert isinstance (result ["messages" ][0 ], HumanMessage )
1500+ assert result ["messages" ][0 ].content == ""
1501+
1502+ @pytest .mark .asyncio
1503+ async def test_input_modified_preserves_message_metadata (self , mock_rails_factory ):
1504+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "redacted" )
1505+ middleware = create_middleware_with_rails (mock_rails )
1506+
1507+ original = HumanMessage (
1508+ content = "my SSN is 123-45-6789" ,
1509+ id = "msg-123" ,
1510+ name = "user1" ,
1511+ additional_kwargs = {"source" : "web" },
1512+ )
1513+ state = {"messages" : [original ]}
1514+ result = await middleware .abefore_model (state , None )
1515+
1516+ modified = result ["messages" ][0 ]
1517+ assert modified .content == "redacted"
1518+ assert modified .id == "msg-123"
1519+ assert modified .name == "user1"
1520+ assert modified .additional_kwargs == {"source" : "web" }
1521+
1522+ @pytest .mark .asyncio
1523+ async def test_output_modified_preserves_message_metadata (self , mock_rails_factory ):
1524+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "safe output" )
1525+ middleware = create_middleware_with_rails (mock_rails )
1526+
1527+ original_ai = AIMessage (
1528+ content = "PII in response" ,
1529+ id = "ai-456" ,
1530+ name = "assistant" ,
1531+ additional_kwargs = {"model" : "gpt-4" },
1532+ )
1533+ state = {
1534+ "messages" : [
1535+ HumanMessage (content = "Hello" ),
1536+ original_ai ,
1537+ ]
1538+ }
1539+ result = await middleware .aafter_model (state , None )
1540+
1541+ modified = result ["messages" ][1 ]
1542+ assert modified .content == "safe output"
1543+ assert modified .id == "ai-456"
1544+ assert modified .name == "assistant"
1545+ assert modified .additional_kwargs == {"model" : "gpt-4" }
1546+
1547+ @pytest .mark .asyncio
1548+ async def test_output_modified_preserves_tool_calls (self , mock_rails_factory ):
1549+ mock_rails = mock_rails_factory (status = RailStatus .MODIFIED , content = "sanitized" )
1550+ middleware = create_middleware_with_rails (mock_rails )
1551+
1552+ tool_call = ToolCall (name = "search" , args = {"q" : "test" }, id = "tc-1" )
1553+ original_ai = AIMessage (
1554+ content = "PII response" ,
1555+ id = "ai-789" ,
1556+ tool_calls = [tool_call ],
1557+ )
1558+ state = {
1559+ "messages" : [
1560+ HumanMessage (content = "Hello" ),
1561+ original_ai ,
1562+ ]
1563+ }
1564+ result = await middleware .aafter_model (state , None )
1565+
1566+ modified = result ["messages" ][1 ]
1567+ assert modified .content == "sanitized"
1568+ assert modified .id == "ai-789"
1569+ assert len (modified .tool_calls ) == 1
1570+ assert modified .tool_calls [0 ]["name" ] == "search"
1571+ assert modified .tool_calls [0 ]["id" ] == "tc-1"
0 commit comments