Skip to content

Commit 8cbe943

Browse files
authored
fix(middleware): handle MODIFIED status in GuardrailsMiddleware instead of silently dropping it (#1714)
1 parent 2646995 commit 8cbe943

3 files changed

Lines changed: 249 additions & 2 deletions

File tree

docs/integration/langchain/agent-middleware.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,9 @@ guardrails = GuardrailsMiddleware(
326326

327327
Rails evaluate the `content` field of messages, not the `tool_calls` arguments. Content-based rails do not inspect PII or harmful content passed through tool call arguments (e.g., `send_email(body="SSN: 123-45-6789")`).
328328

329-
### MODIFIED Status Is Ignored
329+
### MODIFIED Status Replaces Message Content
330330

331-
When a rail modifies content (returns `RailStatus.MODIFIED`), the middleware treats it as a pass-through and the agent uses the original, unmodified content. This is by design — applying modifications to the agent's internal state could cause inconsistencies.
331+
When a rail modifies content (returns `RailStatus.MODIFIED`), the middleware replaces the relevant message with the modified content. For input rails, the last user message is replaced. For output rails, the last AI message is replaced. This enables use cases like PII redaction and content sanitization.
332332

333333
---
334334

nemoguardrails/integrations/langchain/middleware.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ async def abefore_model(self, state: AgentState, runtime: LangGraphRuntime) -> O
111111
if not messages:
112112
return None
113113

114+
last_user_message = self._get_last_user_message(messages)
115+
if not last_user_message:
116+
return None
117+
114118
rails_messages = self._convert_to_rails_messages(messages)
115119

116120
try:
@@ -125,6 +129,11 @@ async def abefore_model(self, state: AgentState, runtime: LangGraphRuntime) -> O
125129
blocked_msg = create_ai_message(self.blocked_input_message)
126130
return {"messages": messages + [blocked_msg], "jump_to": "end"}
127131

132+
if result.status == RailStatus.MODIFIED:
133+
log.info("Input modified by rail '%s': content replaced", result.rail or "unknown rail")
134+
modified_msg = last_user_message.model_copy(update={"content": result.content})
135+
return {"messages": self._replace_last_human_message(messages, modified_msg)}
136+
128137
return None
129138

130139
except GuardrailViolation:
@@ -141,6 +150,12 @@ async def abefore_model(self, state: AgentState, runtime: LangGraphRuntime) -> O
141150
blocked_msg = create_ai_message(self.blocked_input_message)
142151
return {"messages": messages + [blocked_msg], "jump_to": "end"}
143152

153+
def _replace_last_human_message(self, messages: list, replacement: HumanMessage) -> list:
154+
for i in range(len(messages) - 1, -1, -1):
155+
if is_human_message(messages[i]):
156+
return messages[:i] + [replacement] + messages[i + 1 :]
157+
return messages + [replacement]
158+
144159
def _replace_last_ai_message(self, messages: list, replacement: AIMessage) -> list:
145160
for i in range(len(messages) - 1, -1, -1):
146161
if is_ai_message(messages[i]):
@@ -173,6 +188,11 @@ async def aafter_model(self, state: AgentState, runtime: LangGraphRuntime) -> Op
173188
blocked_msg = create_ai_message(self.blocked_output_message)
174189
return {"messages": self._replace_last_ai_message(messages, blocked_msg)}
175190

191+
if result.status == RailStatus.MODIFIED:
192+
log.info("Output modified by rail '%s': content replaced", result.rail or "unknown rail")
193+
modified_msg = last_ai_message.model_copy(update={"content": result.content})
194+
return {"messages": self._replace_last_ai_message(messages, modified_msg)}
195+
176196
return None
177197

178198
except GuardrailViolation:

tests/integrations/langchain/test_middleware.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,3 +1342,230 @@ async def test_error_handler_also_replaces_correctly(self, mock_rails_factory):
13421342
assert result["messages"][1].content == middleware.blocked_output_message
13431343
assert isinstance(result["messages"][1], AIMessage)
13441344
assert result["messages"][2] is trailing_msg
1345+
1346+
1347+
class TestModifiedStatus:
1348+
@pytest.mark.asyncio
1349+
async def test_input_modified_replaces_last_human_message(self, mock_rails_factory):
1350+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="sanitized input")
1351+
middleware = create_middleware_with_rails(mock_rails)
1352+
1353+
state = {"messages": [HumanMessage(content="original input with PII")]}
1354+
result = await middleware.abefore_model(state, None)
1355+
1356+
assert result is not None
1357+
assert "jump_to" not in result
1358+
assert len(result["messages"]) == 1
1359+
assert isinstance(result["messages"][0], HumanMessage)
1360+
assert result["messages"][0].content == "sanitized input"
1361+
1362+
@pytest.mark.asyncio
1363+
async def test_input_modified_preserves_surrounding_messages(self, mock_rails_factory):
1364+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="redacted")
1365+
middleware = create_middleware_with_rails(mock_rails)
1366+
1367+
system_msg = SystemMessage(content="Be helpful")
1368+
state = {
1369+
"messages": [
1370+
system_msg,
1371+
HumanMessage(content="my SSN is 123-45-6789"),
1372+
]
1373+
}
1374+
result = await middleware.abefore_model(state, None)
1375+
1376+
assert len(result["messages"]) == 2
1377+
assert result["messages"][0] is system_msg
1378+
assert isinstance(result["messages"][1], HumanMessage)
1379+
assert result["messages"][1].content == "redacted"
1380+
1381+
@pytest.mark.asyncio
1382+
async def test_input_modified_with_multi_turn_history(self, mock_rails_factory):
1383+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="cleaned message")
1384+
middleware = create_middleware_with_rails(mock_rails)
1385+
1386+
first_human = HumanMessage(content="Hello")
1387+
first_ai = AIMessage(content="Hi there")
1388+
state = {
1389+
"messages": [
1390+
first_human,
1391+
first_ai,
1392+
HumanMessage(content="my email is foo@bar.com"),
1393+
]
1394+
}
1395+
result = await middleware.abefore_model(state, None)
1396+
1397+
assert len(result["messages"]) == 3
1398+
assert result["messages"][0] is first_human
1399+
assert result["messages"][1] is first_ai
1400+
assert result["messages"][2].content == "cleaned message"
1401+
1402+
@pytest.mark.asyncio
1403+
async def test_output_modified_replaces_last_ai_message(self, mock_rails_factory):
1404+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="sanitized output")
1405+
middleware = create_middleware_with_rails(mock_rails)
1406+
1407+
state = {
1408+
"messages": [
1409+
HumanMessage(content="Hello"),
1410+
AIMessage(content="original response with PII"),
1411+
]
1412+
}
1413+
result = await middleware.aafter_model(state, None)
1414+
1415+
assert result is not None
1416+
assert len(result["messages"]) == 2
1417+
assert isinstance(result["messages"][0], HumanMessage)
1418+
assert result["messages"][0].content == "Hello"
1419+
assert isinstance(result["messages"][1], AIMessage)
1420+
assert result["messages"][1].content == "sanitized output"
1421+
1422+
@pytest.mark.asyncio
1423+
async def test_output_modified_preserves_trailing_messages(self, mock_rails_factory):
1424+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="redacted output")
1425+
middleware = create_middleware_with_rails(mock_rails)
1426+
1427+
trailing = SystemMessage(content="trailing")
1428+
state = {
1429+
"messages": [
1430+
HumanMessage(content="Hello"),
1431+
AIMessage(content="bad output"),
1432+
trailing,
1433+
]
1434+
}
1435+
result = await middleware.aafter_model(state, None)
1436+
1437+
assert len(result["messages"]) == 3
1438+
assert isinstance(result["messages"][1], AIMessage)
1439+
assert result["messages"][1].content == "redacted output"
1440+
assert result["messages"][2] is trailing
1441+
1442+
@pytest.mark.asyncio
1443+
async def test_output_modified_replaces_only_last_ai_message(self, mock_rails_factory):
1444+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="fixed")
1445+
middleware = create_middleware_with_rails(mock_rails)
1446+
1447+
state = {
1448+
"messages": [
1449+
HumanMessage(content="Hello"),
1450+
AIMessage(content="First response"),
1451+
HumanMessage(content="Follow up"),
1452+
AIMessage(content="Second response with PII"),
1453+
]
1454+
}
1455+
result = await middleware.aafter_model(state, None)
1456+
1457+
assert len(result["messages"]) == 4
1458+
assert result["messages"][1].content == "First response"
1459+
assert result["messages"][3].content == "fixed"
1460+
1461+
def test_sync_before_model_handles_modified(self, mock_rails_factory):
1462+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="sanitized")
1463+
middleware = create_middleware_with_rails(mock_rails)
1464+
1465+
state = {"messages": [HumanMessage(content="PII content")]}
1466+
result = middleware.before_model(state, None)
1467+
1468+
assert result is not None
1469+
assert len(result["messages"]) == 1
1470+
assert isinstance(result["messages"][0], HumanMessage)
1471+
assert result["messages"][0].content == "sanitized"
1472+
1473+
def test_sync_after_model_handles_modified(self, mock_rails_factory):
1474+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="sanitized output")
1475+
middleware = create_middleware_with_rails(mock_rails)
1476+
1477+
state = {
1478+
"messages": [
1479+
HumanMessage(content="Hello"),
1480+
AIMessage(content="PII output"),
1481+
]
1482+
}
1483+
result = middleware.after_model(state, None)
1484+
1485+
assert result is not None
1486+
assert len(result["messages"]) == 2
1487+
assert result["messages"][1].content == "sanitized output"
1488+
1489+
@pytest.mark.asyncio
1490+
async def test_input_modified_with_empty_content(self, mock_rails_factory):
1491+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="")
1492+
middleware = create_middleware_with_rails(mock_rails)
1493+
1494+
state = {"messages": [HumanMessage(content="sensitive data")]}
1495+
result = await middleware.abefore_model(state, None)
1496+
1497+
assert result is not None
1498+
assert len(result["messages"]) == 1
1499+
assert isinstance(result["messages"][0], HumanMessage)
1500+
assert result["messages"][0].content == ""
1501+
1502+
@pytest.mark.asyncio
1503+
async def test_input_modified_preserves_message_metadata(self, mock_rails_factory):
1504+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="redacted")
1505+
middleware = create_middleware_with_rails(mock_rails)
1506+
1507+
original = HumanMessage(
1508+
content="my SSN is 123-45-6789",
1509+
id="msg-123",
1510+
name="user1",
1511+
additional_kwargs={"source": "web"},
1512+
)
1513+
state = {"messages": [original]}
1514+
result = await middleware.abefore_model(state, None)
1515+
1516+
modified = result["messages"][0]
1517+
assert modified.content == "redacted"
1518+
assert modified.id == "msg-123"
1519+
assert modified.name == "user1"
1520+
assert modified.additional_kwargs == {"source": "web"}
1521+
1522+
@pytest.mark.asyncio
1523+
async def test_output_modified_preserves_message_metadata(self, mock_rails_factory):
1524+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="safe output")
1525+
middleware = create_middleware_with_rails(mock_rails)
1526+
1527+
original_ai = AIMessage(
1528+
content="PII in response",
1529+
id="ai-456",
1530+
name="assistant",
1531+
additional_kwargs={"model": "gpt-4"},
1532+
)
1533+
state = {
1534+
"messages": [
1535+
HumanMessage(content="Hello"),
1536+
original_ai,
1537+
]
1538+
}
1539+
result = await middleware.aafter_model(state, None)
1540+
1541+
modified = result["messages"][1]
1542+
assert modified.content == "safe output"
1543+
assert modified.id == "ai-456"
1544+
assert modified.name == "assistant"
1545+
assert modified.additional_kwargs == {"model": "gpt-4"}
1546+
1547+
@pytest.mark.asyncio
1548+
async def test_output_modified_preserves_tool_calls(self, mock_rails_factory):
1549+
mock_rails = mock_rails_factory(status=RailStatus.MODIFIED, content="sanitized")
1550+
middleware = create_middleware_with_rails(mock_rails)
1551+
1552+
tool_call = ToolCall(name="search", args={"q": "test"}, id="tc-1")
1553+
original_ai = AIMessage(
1554+
content="PII response",
1555+
id="ai-789",
1556+
tool_calls=[tool_call],
1557+
)
1558+
state = {
1559+
"messages": [
1560+
HumanMessage(content="Hello"),
1561+
original_ai,
1562+
]
1563+
}
1564+
result = await middleware.aafter_model(state, None)
1565+
1566+
modified = result["messages"][1]
1567+
assert modified.content == "sanitized"
1568+
assert modified.id == "ai-789"
1569+
assert len(modified.tool_calls) == 1
1570+
assert modified.tool_calls[0]["name"] == "search"
1571+
assert modified.tool_calls[0]["id"] == "tc-1"

0 commit comments

Comments
 (0)