Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nemoguardrails/llm/output_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def nemoguard_parse_prompt_safety(response: str) -> Sequence[Union[bool, str]]:
parsed_json_result = json.loads(response)
assert "User Safety" in parsed_json_result
result = parsed_json_result["User Safety"].lower()
if "Safety Categories" in parsed_json_result:
safety_categories = [cat.strip() for cat in parsed_json_result["Safety Categories"].split(",")]
if "Violated Categories" in parsed_json_result:
safety_categories = [cat.strip() for cat in parsed_json_result["Violated Categories"].split(",")]
else:
safety_categories = []
except Exception:
Expand Down Expand Up @@ -199,8 +199,8 @@ def nemoguard_parse_response_safety(response: str) -> Sequence[Union[bool, str]]
parsed_json_result = json.loads(response)
assert "Response Safety" in parsed_json_result
result = parsed_json_result["Response Safety"].lower()
if "Safety Categories" in parsed_json_result:
safety_categories = [cat.strip() for cat in parsed_json_result["Safety Categories"].split(",")]
if "Violated Categories" in parsed_json_result:
safety_categories = [cat.strip() for cat in parsed_json_result["Violated Categories"].split(",")]
else:
safety_categories = []
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions tests/guardrails/test_content_safety_iorails_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@
UNSAFE_JSON = json.dumps(
{
"User Safety": "unsafe",
"Safety Categories": "S1: Violence, S3: Criminal Planning/Confessions",
"Violated Categories": "S1: Violence, S3: Criminal Planning/Confessions",
}
)
SAFE_OUTPUT_JSON = json.dumps({"User Safety": "safe", "Response Safety": "safe"})
UNSAFE_OUTPUT_JSON = json.dumps(
{
"User Safety": "safe",
"Response Safety": "unsafe",
"Safety Categories": "S17: Malware",
"Violated Categories": "S17: Malware",
}
)

Expand Down
4 changes: 2 additions & 2 deletions tests/guardrails/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
<END CONVERSATION>

Provide your safety assessment for the above conversation in the below output JSON format:
{"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}
{"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Violated Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}

Do not include anything other than the output JSON in your response.
Output JSON:
Expand Down Expand Up @@ -97,7 +97,7 @@
<END CONVERSATION>

Provide your safety assessment for the above conversation in the below output JSON format:
{"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}
{"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Violated Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.}

Do not include anything other than the output JSON in your response.
Output JSON:
Expand Down
2 changes: 1 addition & 1 deletion tests/guardrails/test_iorails_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ async def capture_then_pass(messages):

SAFE_INPUT_JSON = json.dumps({"User Safety": "safe"})
SAFE_OUTPUT_JSON = json.dumps({"User Safety": "safe", "Response Safety": "safe"})
UNSAFE_INPUT_JSON = json.dumps({"User Safety": "unsafe", "Safety Categories": "S1: Violence"})
UNSAFE_INPUT_JSON = json.dumps({"User Safety": "unsafe", "Violated Categories": "S1: Violence"})


def _stub_deep_pipeline(iorails, main_llm_response="Hello", input_safe=True):
Expand Down
4 changes: 2 additions & 2 deletions tests/guardrails/test_rails_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
)

SAFE_INPUT_JSON = json.dumps({"User Safety": "safe"})
UNSAFE_INPUT_JSON = json.dumps({"User Safety": "unsafe", "Safety Categories": "S1: Violence"})
UNSAFE_INPUT_JSON = json.dumps({"User Safety": "unsafe", "Violated Categories": "S1: Violence"})
SAFE_OUTPUT_JSON = json.dumps({"User Safety": "safe", "Response Safety": "safe"})
UNSAFE_OUTPUT_JSON = json.dumps(
{
"User Safety": "safe",
"Response Safety": "unsafe",
"Safety Categories": "S17: Malware",
"Violated Categories": "S17: Malware",
}
)
MESSAGES = [{"role": "user", "content": "hello"}]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_content_safety_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def test_content_safety_input_with_nemoguard_parser_unsafe_with_categories
self,
):
"""Test input action with real nemoguard_parse_prompt_safety parser - unsafe with categories."""
json_response = '{"User Safety": "unsafe", "Safety Categories": "S1, S8, S10"}'
json_response = '{"User Safety": "unsafe", "Violated Categories": "S1, S8, S10"}'
parsed_result = nemoguard_parse_prompt_safety(json_response)
llms, mock_task_manager = _create_mock_setup([json_response], parsed_result)
context = _create_input_context("Potentially harmful content")
Expand All @@ -181,7 +181,7 @@ async def test_content_safety_input_with_nemoguard_parser_unsafe_with_categories
[
('{"Response Safety": "safe"}', True, []),
(
'{"Response Safety": "unsafe", "Safety Categories": "Violence, Hate Speech"}',
'{"Response Safety": "unsafe", "Violated Categories": "Violence, Hate Speech"}',
False,
["Violence", "Hate Speech"],
),
Expand Down Expand Up @@ -251,13 +251,13 @@ def test_iterable_unpacking_with_is_content_safe_outputs(self, response, expecte
[
('{"User Safety": "safe"}', True, []),
(
'{"User Safety": "unsafe", "Safety Categories": "S1, S8"}',
'{"User Safety": "unsafe", "Violated Categories": "S1, S8"}',
False,
["S1", "S8"],
),
('{"Response Safety": "safe"}', True, []),
(
'{"Response Safety": "unsafe", "Safety Categories": "Violence, Hate"}',
'{"Response Safety": "unsafe", "Violated Categories": "Violence, Hate"}',
False,
["Violence", "Hate"],
),
Expand Down
50 changes: 32 additions & 18 deletions tests/test_content_safety_output_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,22 @@ def test_unsafe_user_content_json(self):
assert is_safe is False
assert violated_policies == []

def test_unsafe_with_safety_categories(self):
"""Test parsing unsafe content with safety categories."""
response = '{"User Safety": "unsafe", "Safety Categories": "S1, S8, S10"}'
def test_unsafe_with_violated_categories(self):
"""Test parsing unsafe content with violation categories (correct NemoGuard key)."""
response = '{"User Safety": "unsafe", "Violated Categories": "S1, S8, S10"}'
is_safe, *violated_policies = nemoguard_parse_prompt_safety(response)
assert is_safe is False
assert "S1" in violated_policies
assert "S8" in violated_policies
assert "S10" in violated_policies

def test_wrong_key_safety_categories_yields_no_categories(self):
"""Regression: old wrong key 'Safety Categories' should not extract categories."""
response = '{"User Safety": "unsafe", "Safety Categories": "S1, S8"}'
is_safe, *violated_policies = nemoguard_parse_prompt_safety(response)
assert is_safe is False
assert violated_policies == []

def test_case_insensitive_safety_status(self):
"""Test parsing is case insensitive for safety status."""
response = '{"User Safety": "SAFE"}'
Expand All @@ -151,16 +158,16 @@ def test_case_insensitive_safety_status(self):

def test_categories_with_whitespace_trimming(self):
"""Test parsing categories with extra whitespace gets trimmed."""
response = '{"User Safety": "unsafe", "Safety Categories": " S1 , S8 , S10 "}'
response = '{"User Safety": "unsafe", "Violated Categories": " S1 , S8 , S10 "}'
is_safe, *violated_policies = nemoguard_parse_prompt_safety(response)
assert is_safe is False
assert "S1" in violated_policies
assert "S8" in violated_policies
assert "S10" in violated_policies

def test_empty_safety_categories(self):
"""Test parsing with empty safety categories string."""
response = '{"User Safety": "unsafe", "Safety Categories": ""}'
def test_empty_violated_categories(self):
"""Test parsing with empty violated categories string."""
response = '{"User Safety": "unsafe", "Violated Categories": ""}'
is_safe, *violated_policies = nemoguard_parse_prompt_safety(response)
assert is_safe is False
assert violated_policies == [""]
Expand All @@ -187,15 +194,15 @@ def test_missing_user_safety_field(self):
assert violated_policies == ["JSON parsing failed"]

def test_single_category(self):
"""Test parsing with single safety category."""
response = '{"User Safety": "unsafe", "Safety Categories": "Violence"}'
"""Test parsing with single violated category."""
response = '{"User Safety": "unsafe", "Violated Categories": "Violence"}'
is_safe, *violated_policies = nemoguard_parse_prompt_safety(response)
assert is_safe is False
assert violated_policies == ["Violence"]

def test_complex_category_names(self):
"""Test parsing with descriptive category names."""
response = '{"User Safety": "unsafe", "Safety Categories": "Violence, Hate Speech, Sexual Content"}'
response = '{"User Safety": "unsafe", "Violated Categories": "Violence, Hate Speech, Sexual Content"}'
is_safe, *violated_policies = nemoguard_parse_prompt_safety(response)
assert is_safe is False
assert "Violence" in violated_policies
Expand All @@ -220,15 +227,22 @@ def test_unsafe_response_content_json(self):
assert is_safe is False
assert violated_policies == []

def test_unsafe_with_safety_categories(self):
"""Test parsing unsafe response with safety categories."""
response = '{"Response Safety": "unsafe", "Safety Categories": "S1, S8, S10"}'
def test_unsafe_with_violated_categories(self):
"""Test parsing unsafe response with violated categories (correct NemoGuard key)."""
response = '{"Response Safety": "unsafe", "Violated Categories": "S1, S8, S10"}'
is_safe, *violated_policies = nemoguard_parse_response_safety(response)
assert is_safe is False
assert "S1" in violated_policies
assert "S8" in violated_policies
assert "S10" in violated_policies

def test_wrong_key_safety_categories_yields_no_categories(self):
"""Regression: old wrong key 'Safety Categories' should not extract categories."""
response = '{"Response Safety": "unsafe", "Safety Categories": "S1, S8"}'
is_safe, *violated_policies = nemoguard_parse_response_safety(response)
assert is_safe is False
assert violated_policies == []

def test_case_insensitive_safety_status(self):
"""Test parsing is case insensitive for safety status."""
response = '{"Response Safety": "SAFE"}'
Expand All @@ -238,7 +252,7 @@ def test_case_insensitive_safety_status(self):

def test_categories_with_whitespace_trimming(self):
"""Test parsing categories with extra whitespace gets trimmed."""
response = '{"Response Safety": "unsafe", "Safety Categories": " S1 , S8 , S10 "}'
response = '{"Response Safety": "unsafe", "Violated Categories": " S1 , S8 , S10 "}'
is_safe, *violated_policies = nemoguard_parse_response_safety(response)
assert is_safe is False
assert "S1" in violated_policies
Expand Down Expand Up @@ -268,7 +282,7 @@ def test_missing_response_safety_field(self):

def test_full_nemoguard_response(self):
"""Test parsing a full NemoGuard response with both user and response safety."""
response = '{"User Safety": "safe", "Response Safety": "unsafe", "Safety Categories": "S1, S8"}'
response = '{"User Safety": "safe", "Response Safety": "unsafe", "Violated Categories": "S1, S8"}'
is_safe, *violated_policies = nemoguard_parse_response_safety(response)
assert is_safe is False
assert "S1" in violated_policies
Expand Down Expand Up @@ -298,13 +312,13 @@ def test_llama_guard_typical_responses(self):

def test_nemoguard_content_safety_responses(self):
"""Test typical NemoGuard ContentSafety model responses."""
response = '{"User Safety": "unsafe", "Safety Categories": "S1: Violence, S8: Hate/Identity Hate"}'
response = '{"User Safety": "unsafe", "Violated Categories": "S1: Violence, S8: Hate/Identity Hate"}'
is_safe, *violated_policies = nemoguard_parse_prompt_safety(response)
assert is_safe is False
assert "S1: Violence" in violated_policies
assert "S8: Hate/Identity Hate" in violated_policies

response = '{"User Safety": "safe", "Response Safety": "unsafe", "Safety Categories": "S11: Sexual Content"}'
response = '{"User Safety": "safe", "Response Safety": "unsafe", "Violated Categories": "S11: Sexual Content"}'
is_safe, *violated_policies = nemoguard_parse_response_safety(response)
assert is_safe is False
assert violated_policies == ["S11: Sexual Content"]
Expand Down Expand Up @@ -359,7 +373,7 @@ def test_starred_unpacking_compatibility(self):
assert is_safe is True
assert violated_policies == []

response = '{"Response Safety": "unsafe", "Safety Categories": "S1, S8"}'
response = '{"Response Safety": "unsafe", "Violated Categories": "S1, S8"}'
result = nemoguard_parse_response_safety(response)
is_safe, *violated_policies = result
assert is_safe is False
Expand Down