Skip to content

Commit 2b15b8e

Browse files
authored
Merge pull request #34 from m-misiura/fix_guardrail_checks_tests
Fix failing tests for guardrails checks
2 parents 954df0e + aec4d12 commit 2b15b8e

11 files changed

Lines changed: 282 additions & 58 deletions

File tree

nemoguardrails/server/schemas/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def validate_config_ids(cls, data: Any) -> Any:
138138
config_fields = [data.get("config_id"), data.get("config_ids"), data.get("config")]
139139
non_none_count = sum(1 for field in config_fields if field is not None)
140140
if non_none_count > 1:
141-
raise ValueError("Only one of config, config_id, or config_ids should be specified")
141+
raise ValueError("Only one of config_id or config_ids should be specified")
142142
return data
143143

144144
@field_validator("config_ids", mode="before")

scripts/discover_required_models.py

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -77,42 +77,22 @@ def get_active_guardrails(self) -> List[str]:
7777
logging.error(f"Missing directory: {library_path}")
7878
sys.exit(1)
7979

80-
available = [
81-
item.name
82-
for item in library_path.iterdir()
83-
if item.is_dir() and not item.name.startswith("_")
84-
]
85-
return (
86-
available
87-
if include_closed
88-
else [gr for gr in available if gr not in closed_source]
89-
)
80+
available = [item.name for item in library_path.iterdir() if item.is_dir() and not item.name.startswith("_")]
81+
return available if include_closed else [gr for gr in available if gr not in closed_source]
9082

9183
@staticmethod
9284
def _extract_from_ast(tree: ast.AST) -> Dict[str, Set[str]]:
9385
models = {k: set() for k in ModelDiscoverer.MODEL_KEYS}
9486
for node in ast.walk(tree):
9587
if (
9688
isinstance(node, ast.Call)
97-
and getattr(getattr(node.func, "attr", None), "lower", lambda: "")()
98-
== "load"
89+
and getattr(getattr(node.func, "attr", None), "lower", lambda: "")() == "load"
9990
and getattr(getattr(node.func, "value", None), "id", None) == "spacy"
10091
):
101-
if (
102-
node.args
103-
and isinstance(node.args[0], ast.Constant)
104-
and isinstance(node.args[0].value, str)
105-
):
92+
if node.args and isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
10693
models["spacy"].add(node.args[0].value)
107-
if (
108-
isinstance(node, ast.Call)
109-
and getattr(node.func, "id", None) == "SentenceTransformer"
110-
):
111-
if (
112-
node.args
113-
and isinstance(node.args[0], ast.Constant)
114-
and isinstance(node.args[0].value, str)
115-
):
94+
if isinstance(node, ast.Call) and getattr(node.func, "id", None) == "SentenceTransformer":
95+
if node.args and isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
11696
name = node.args[0].value
11797
if not name.startswith("sentence-transformers/"):
11898
name = f"sentence-transformers/{name}"
@@ -130,11 +110,7 @@ def _extract_from_ast(tree: ast.AST) -> Dict[str, Set[str]]:
130110
and getattr(node.func, "attr", None) == "download"
131111
and getattr(getattr(node.func, "value", None), "id", None) == "nltk"
132112
):
133-
if (
134-
node.args
135-
and isinstance(node.args[0], ast.Constant)
136-
and isinstance(node.args[0].value, str)
137-
):
113+
if node.args and isinstance(node.args[0], ast.Constant) and isinstance(node.args[0].value, str):
138114
models["nltk"].add(node.args[0].value)
139115
return models
140116

@@ -182,9 +158,7 @@ def discover(self) -> Dict[str, Set[str]]:
182158
def print_summary(self):
183159
active_guardrails = self.get_active_guardrails()
184160
print(f"Discovering models for profile: {self.profile}")
185-
print(
186-
f"Active guardrails ({len(active_guardrails)}): {', '.join(active_guardrails)}"
187-
)
161+
print(f"Active guardrails ({len(active_guardrails)}): {', '.join(active_guardrails)}")
188162
for category in self.MODEL_KEYS:
189163
models = self.models[category]
190164
if models:

scripts/filter_guardrails.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +15,6 @@
1515
# limitations under the License.
1616

1717
import logging
18-
import os
1918
import shutil
2019
import sys
2120
from pathlib import Path
@@ -39,9 +38,7 @@ def main():
3938
config = yaml.safe_load(f)
4039

4140
if profile not in config["profiles"]:
42-
logger.error(
43-
f"Profile '{profile}' not found. Available: {list(config['profiles'].keys())}"
44-
)
41+
logger.error(f"Profile '{profile}' not found. Available: {list(config['profiles'].keys())}")
4542
sys.exit(1)
4643

4744
include_closed_source = config["profiles"][profile]["include_closed_source"]
@@ -59,11 +56,7 @@ def main():
5956
removed_dirs = []
6057

6158
for guardrail_dir in library_path.iterdir():
62-
if (
63-
not guardrail_dir.is_dir()
64-
or guardrail_dir.name.startswith(".")
65-
or guardrail_dir.name.startswith("__")
66-
):
59+
if not guardrail_dir.is_dir() or guardrail_dir.name.startswith(".") or guardrail_dir.name.startswith("__"):
6760
continue
6861

6962
guardrail_name = guardrail_dir.name
@@ -78,9 +71,7 @@ def main():
7871
logger.info(f"Keeping {source_type}: {guardrail_name}")
7972
kept_dirs.append(guardrail_name)
8073

81-
logger.info(
82-
f"\nSummary: kept {len(kept_dirs)}, removed {len(removed_dirs)} guardrails"
83-
)
74+
logger.info(f"\nSummary: kept {len(kept_dirs)}, removed {len(removed_dirs)} guardrails")
8475

8576

8677
if __name__ == "__main__":

scripts/pre_download_required_models.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,9 +89,7 @@ def download_sentence_transformers_models(models):
8989
sentence_transformers.SentenceTransformer(model_name)
9090
logging.info(f"Downloaded Sentence Transformers model: {model_name}")
9191
except Exception as e:
92-
logging.warning(
93-
f"Failed to download Sentence Transformers model {model_name}: {e}"
94-
)
92+
logging.warning(f"Failed to download Sentence Transformers model {model_name}: {e}")
9593

9694

9795
def download_fastembed_models(models):
@@ -143,9 +141,7 @@ def download_huggingface_models(models):
143141
except Exception as e2:
144142
logging.warning(f"Failed to download {model_name}: {e2}")
145143
else:
146-
logging.warning(
147-
f"Failed to download HuggingFace model {model_name}: {e}"
148-
)
144+
logging.warning(f"Failed to download HuggingFace model {model_name}: {e}")
149145

150146

151147
def download_nltk_data():
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from nemoguardrails.actions import action
17+
18+
19+
@action(is_system_action=True)
20+
async def check_forbidden_words(context: dict = {}):
21+
"""Check if the message contains forbidden words."""
22+
user_message = context.get("user_message", "").lower()
23+
24+
forbidden_categories = {
25+
"security": ["password", "hack", "exploit", "vulnerability"],
26+
"inappropriate": ["violence", "illegal", "harmful"],
27+
"competitors": ["chatgpt", "openai", "claude", "anthropic"],
28+
}
29+
30+
for category, words in forbidden_categories.items():
31+
for word in words:
32+
if word in user_message:
33+
return {"status": "blocked", "category": category, "word": word}
34+
35+
return {"status": "allowed"}
36+
37+
38+
@action(is_system_action=True)
39+
async def check_output_length(context: dict = {}):
40+
"""Check if the bot message is too long."""
41+
bot_msg = context.get("bot_message", "")
42+
return "blocked" if len(bot_msg.split()) > 100 else "allowed"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
models:
2+
- type: main
3+
engine: openai
4+
model: test
5+
6+
instructions:
7+
- type: general
8+
content: |
9+
You are a helpful assistant.
10+
11+
rails:
12+
input:
13+
flows:
14+
- check forbidden words
15+
output:
16+
flows:
17+
- check output length
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
define flow check forbidden words
2+
$result = execute check_forbidden_words
3+
4+
if $result.status == "blocked"
5+
bot inform forbidden content
6+
stop
7+
8+
define bot inform forbidden content
9+
"I can't answer questions about closed source AI models"
10+
11+
define flow check output length
12+
$result = execute check_output_length
13+
14+
if $result == "blocked"
15+
bot inform output too long
16+
stop
17+
18+
define bot inform output too long
19+
"The response is too long."
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import re
17+
18+
from nemoguardrails.actions import action
19+
20+
21+
@action(is_system_action=True)
22+
async def check_forbidden_words(context: dict = {}):
23+
"""Check if the message contains forbidden words."""
24+
user_message = context.get("user_message", "").lower()
25+
26+
forbidden_categories = {
27+
"security": ["password", "hack", "exploit", "vulnerability"],
28+
"inappropriate": ["violence", "illegal", "harmful"],
29+
"competitors": ["chatgpt", "openai", "claude", "anthropic"],
30+
}
31+
32+
for category, words in forbidden_categories.items():
33+
for word in words:
34+
if word in user_message:
35+
return {"status": "blocked", "category": category, "word": word}
36+
37+
return {"status": "allowed"}
38+
39+
40+
@action(is_system_action=True)
41+
async def check_output_length(context: dict = {}):
42+
"""Check if the bot message is too long."""
43+
bot_msg = context.get("bot_message", "")
44+
return "blocked" if len(bot_msg.split()) > 100 else "allowed"
45+
46+
47+
@action(is_system_action=True)
48+
async def check_tool_response_safety(tool_message: str = None, context: dict = None):
49+
"""Validate tool responses for sensitive data leakage."""
50+
if tool_message is None:
51+
tool_message = context.get("tool_message", "") if context else ""
52+
53+
if not tool_message:
54+
return "allowed"
55+
56+
credential_patterns = {
57+
"password": r"password[:\s=]+\w+",
58+
"api_key": r"(?:api[_\s-]?key|apikey)[:\s=]+[\w-]+",
59+
"secret": r"secret[:\s=]+\w+",
60+
"token": r"(?:access[_\s]?token|bearer)[:\s=]+[\w.-]+",
61+
"private_key": r"-----BEGIN (?:RSA |EC )?PRIVATE KEY-----",
62+
}
63+
64+
tool_message_lower = tool_message.lower()
65+
66+
for pattern_name, pattern in credential_patterns.items():
67+
if re.search(pattern, tool_message_lower):
68+
return "blocked"
69+
70+
return "allowed"
71+
72+
73+
@action(is_system_action=True)
74+
async def check_tool_call_safety(tool_calls=None, context=None):
75+
"""Validate tool calls before execution using an allow list approach."""
76+
if tool_calls is None:
77+
tool_calls = context.get("tool_calls", []) if context else []
78+
79+
allowed_tools = [
80+
"get_weather",
81+
"search_web",
82+
"read_file",
83+
"get_time",
84+
"get_stock_price",
85+
"calculate",
86+
]
87+
88+
dangerous_patterns = {
89+
"path_traversal": r"\.\./",
90+
"command_injection": r"[;&|`$]",
91+
"sql_injection": r"(?:DROP|DELETE|TRUNCATE)\s+(?:TABLE|DATABASE)",
92+
}
93+
94+
for tool_call in tool_calls:
95+
tool_name = tool_call.get("name", "")
96+
97+
if tool_name not in allowed_tools:
98+
return "blocked"
99+
100+
args = tool_call.get("args", {})
101+
for arg_name, arg_value in args.items():
102+
if isinstance(arg_value, str):
103+
for pattern_name, pattern in dangerous_patterns.items():
104+
if re.search(pattern, arg_value, re.IGNORECASE):
105+
return "blocked"
106+
107+
return "allowed"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
models:
2+
- type: main
3+
engine: openai
4+
model: test
5+
6+
instructions:
7+
- type: general
8+
content: |
9+
You are a helpful assistant.
10+
11+
passthrough: true
12+
13+
rails:
14+
input:
15+
flows:
16+
- check forbidden words
17+
output:
18+
flows:
19+
- check output length
20+
tool_input:
21+
flows:
22+
- check tool response safety
23+
tool_output:
24+
flows:
25+
- check tool call safety

0 commit comments

Comments
 (0)