-
Notifications
You must be signed in to change notification settings - Fork 675
refactor(jailbreak): Use onnx instead of pickle to load model #1715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
82633e2
fbece23
63a9a54
759e8ef
947e011
be31f73
3e7204b
c956cdc
ba907cb
b170e50
411d969
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,18 +13,24 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import os | ||
| from typing import Tuple | ||
|
|
||
| import numpy as np | ||
|
|
||
|
|
||
| class SnowflakeEmbed: | ||
| def __init__(self): | ||
| import torch | ||
| from transformers import AutoModel, AutoTokenizer | ||
|
|
||
| self.device = "cuda:0" if torch.cuda.is_available() else "cpu" | ||
| self.tokenizer = AutoTokenizer.from_pretrained("Snowflake/snowflake-arctic-embed-m-long") | ||
| device = os.environ.get("JAILBREAK_CHECK_DEVICE") | ||
| if device is None: | ||
| self.device = "cuda" if torch.cuda.is_available() else "cpu" | ||
| else: | ||
| self.device = device | ||
| self.tokenizer = AutoTokenizer.from_pretrained( | ||
| "Snowflake/snowflake-arctic-embed-m-long", | ||
| trust_remote_code=True, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this needed?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Which part? All of that is contained in the model card instructions on how to use the model. |
||
| ) | ||
|
erickgalinkin marked this conversation as resolved.
|
||
| self.model = AutoModel.from_pretrained( | ||
| "Snowflake/snowflake-arctic-embed-m-long", | ||
| trust_remote_code=True, | ||
|
|
@@ -43,16 +49,19 @@ def __call__(self, text: str): | |
|
|
||
| class JailbreakClassifier: | ||
| def __init__(self, random_forest_path: str): | ||
| import pickle | ||
| from onnxruntime import InferenceSession | ||
|
|
||
| self.embed = SnowflakeEmbed() | ||
| with open(random_forest_path, "rb") as fd: | ||
| self.classifier = pickle.load(fd) | ||
| # See https://onnx.ai/sklearn-onnx/auto_examples/plot_convert_decision_function.html | ||
| self.classifier = InferenceSession(random_forest_path, providers=["CPUExecutionProvider"]) | ||
|
|
||
| def __call__(self, text: str) -> Tuple[bool, float]: | ||
| e = self.embed(text) | ||
| probs = self.classifier.predict_proba([e]) | ||
| classification = np.argmax(probs) | ||
| prob = np.max(probs) | ||
| res = self.classifier.run(None, {"X": [e]}) | ||
| # InferenceSession returns a result where the first item is equivalent to argmax over probabilities | ||
| classification = res[0].item() | ||
| # The second is a list of dicts of probabilities -- the slice res[1][:2] should have only one element. | ||
| # We access the dict entry for the class. | ||
| prob = res[1][0][classification] | ||
| score = -prob if classification == 0 else prob | ||
| return bool(classification), float(score) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dead sklearn monkeypatch and stale test intent as in below
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No test covers the new
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,9 +24,9 @@ | |
|
|
||
| def test_lazy_import_does_not_require_heavy_deps(): | ||
| """ | ||
| Importing the checks module should not require torch, transformers, or sklearn unless model-based classifier is used. | ||
| Importing the checks module should not require torch, transformers, or onnxruntime unless model-based classifier is used. | ||
| """ | ||
| with mock.patch.dict(sys.modules, {"torch": None, "transformers": None, "sklearn": None}): | ||
| with mock.patch.dict(sys.modules, {"torch": None, "transformers": None, "onnxruntime": None}): | ||
| import nemoguardrails.library.jailbreak_detection.model_based.checks as checks | ||
|
|
||
| # Just importing and calling unrelated functions should not raise ImportError | ||
|
|
@@ -38,20 +38,20 @@ def test_lazy_import_does_not_require_heavy_deps(): | |
|
|
||
| def test_model_based_classifier_imports(monkeypatch): | ||
| """ | ||
| Instantiating JailbreakClassifier should require sklearn and pickle, and use SnowflakeEmbed which requires torch/transformers. | ||
| Instantiating JailbreakClassifier should require onnxruntime, and use SnowflakeEmbed which requires torch/transformers. | ||
| """ | ||
| # Mock dependencies | ||
| fake_rf = mock.MagicMock() | ||
| fake_embed = mock.MagicMock(return_value=[0.0]) | ||
| fake_pickle = types.SimpleNamespace(load=mock.MagicMock(return_value=fake_rf)) | ||
| fake_onnx = types.SimpleNamespace(InferenceSession=mock.MagicMock(return_value=fake_rf)) | ||
| fake_snowflake = mock.MagicMock(return_value=fake_embed) | ||
|
|
||
| monkeypatch.setitem( | ||
| sys.modules, | ||
| "sklearn.ensemble", | ||
| types.SimpleNamespace(RandomForestClassifier=mock.MagicMock()), | ||
| ) | ||
|
Comment on lines
49
to
53
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. stale patch |
||
| monkeypatch.setitem(sys.modules, "pickle", fake_pickle) | ||
| monkeypatch.setitem(sys.modules, "onnxruntime", fake_onnx) | ||
| monkeypatch.setitem(sys.modules, "torch", mock.MagicMock()) | ||
| monkeypatch.setitem(sys.modules, "transformers", mock.MagicMock()) | ||
|
|
||
|
|
@@ -64,7 +64,7 @@ def test_model_based_classifier_imports(monkeypatch): | |
| mock_open = mock.mock_open() | ||
| with mock.patch("builtins.open", mock_open): | ||
| # Should not raise | ||
| classifier = models.JailbreakClassifier("fake_model_path.pkl") | ||
| classifier = models.JailbreakClassifier("fake_model_path.onnx") | ||
| assert classifier is not None | ||
| # Should be callable | ||
| result = classifier("test") | ||
|
|
@@ -76,17 +76,17 @@ def test_model_based_classifier_imports(monkeypatch): | |
|
|
||
| def test_model_based_classifier_missing_deps(monkeypatch): | ||
| """ | ||
| If sklearn is missing, instantiating JailbreakClassifier should raise ImportError. | ||
| If onnxruntime is missing, instantiating JailbreakClassifier should raise ImportError. | ||
| """ | ||
| monkeypatch.setitem(sys.modules, "sklearn.ensemble", None) | ||
| monkeypatch.setitem(sys.modules, "onnxruntime", None) | ||
|
|
||
| import nemoguardrails.library.jailbreak_detection.model_based.models as models | ||
|
|
||
| # to avoid Windows permission issues | ||
| mock_open = mock.mock_open() | ||
| with mock.patch("builtins.open", mock_open): | ||
| with pytest.raises(ImportError): | ||
| models.JailbreakClassifier("fake_model_path.pkl") | ||
| models.JailbreakClassifier("fake_model_path.onnx") | ||
|
|
||
|
|
||
| # Test 4: Return None when EMBEDDING_CLASSIFIER_PATH is not set | ||
|
|
@@ -253,10 +253,61 @@ def test_initialize_model_with_valid_path(monkeypatch): | |
|
|
||
| assert result == mock_classifier | ||
|
|
||
| expected_path = str(Path(test_path).joinpath("snowflake.pkl")) | ||
| expected_path = str(Path(test_path).joinpath("snowflake.onnx")) | ||
| mock_jailbreak_classifier_class.assert_called_once_with(expected_path) | ||
|
|
||
|
|
||
| def test_initialize_model_skips_hf_hub_download_when_snowflake_onnx_exists(monkeypatch, tmp_path): | ||
| """ | ||
| When snowflake.onnx is already present under EMBEDDING_CLASSIFIER_PATH, do not call hf_hub_download. | ||
| """ | ||
| import nemoguardrails.library.jailbreak_detection.model_based.checks as checks | ||
|
|
||
| checks.initialize_model.cache_clear() | ||
|
|
||
| (tmp_path / "snowflake.onnx").write_bytes(b"") | ||
| monkeypatch.setenv("EMBEDDING_CLASSIFIER_PATH", str(tmp_path)) | ||
|
|
||
| mock_classifier = mock.MagicMock() | ||
| monkeypatch.setattr( | ||
| "nemoguardrails.library.jailbreak_detection.model_based.models.JailbreakClassifier", | ||
| mock.MagicMock(return_value=mock_classifier), | ||
| ) | ||
|
|
||
| with mock.patch("huggingface_hub.hf_hub_download") as mock_hf_hub_download: | ||
| result = checks.initialize_model() | ||
|
|
||
| assert result is mock_classifier | ||
| mock_hf_hub_download.assert_not_called() | ||
|
|
||
|
|
||
| def test_initialize_model_calls_hf_hub_download_when_snowflake_onnx_missing(monkeypatch, tmp_path): | ||
| """ | ||
| When snowflake.onnx is absent, hf_hub_download is invoked once with the NemoGuard repo and paths. | ||
| """ | ||
| import nemoguardrails.library.jailbreak_detection.model_based.checks as checks | ||
|
|
||
| checks.initialize_model.cache_clear() | ||
|
|
||
| monkeypatch.setenv("EMBEDDING_CLASSIFIER_PATH", str(tmp_path)) | ||
|
|
||
| mock_classifier = mock.MagicMock() | ||
| monkeypatch.setattr( | ||
| "nemoguardrails.library.jailbreak_detection.model_based.models.JailbreakClassifier", | ||
| mock.MagicMock(return_value=mock_classifier), | ||
| ) | ||
|
|
||
| with mock.patch("huggingface_hub.hf_hub_download") as mock_hf_hub_download: | ||
| result = checks.initialize_model() | ||
|
|
||
| assert result is mock_classifier | ||
| mock_hf_hub_download.assert_called_once_with( | ||
| repo_id="nvidia/NemoGuard-JailbreakDetect", | ||
| filename="snowflake.onnx", | ||
| local_dir=str(tmp_path), | ||
| ) | ||
|
|
||
|
|
||
| # Test 10: Test that NvEmbedE5 class no longer exists | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we pin
revision?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We certainly could. it won't hurt anything.