diff --git a/garak/probes/topic.py b/garak/probes/topic.py
index 04efc4895..5aab0c39a 100644
--- a/garak/probes/topic.py
+++ b/garak/probes/topic.py
@@ -109,7 +109,9 @@ def __init__(self, config_root=_config):
         self.w = None
         try:
             self.w = wn.Wordnet(self.lexicon)
-        except sqlite3.OperationalError:
+        except (sqlite3.OperationalError, wn.Error):
+            # sqlite3.OperationalError: the wordnet database has not been created yet.
+            # wn.Error: the database exists but the requested lexicon is not installed.
             logging.debug("Downloading wordnet lexicon: %s", self.lexicon)
             download_tempfile_path = wn.download(self.lexicon)
             self.w = wn.Wordnet(self.lexicon)
diff --git a/tests/probes/test_probes_topic.py b/tests/probes/test_probes_topic.py
index 2d570c97b..b2ad86211 100644
--- a/tests/probes/test_probes_topic.py
+++ b/tests/probes/test_probes_topic.py
@@ -1,7 +1,11 @@
 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+import pathlib
 import pytest
+import sqlite3
+from unittest.mock import MagicMock, patch
+
 import wn
 
 import garak._plugins
@@ -71,3 +75,42 @@ def test_topic_wordnet_blocklist_get_initial_nodes(sysnet):
         wn.synset("oewn-00231342-n"),
         wn.synset("oewn-00232028-n"),
     ]
+
+
+@pytest.mark.parametrize(
+    "first_call_error",
+    [
+        sqlite3.OperationalError("database not initialised"),
+        wn.Error("requested lexicon is not installed"),
+    ],
+    ids=["sqlite_missing_database", "wn_missing_lexicon"],
+)
+def test_topic_wordnet_downloads_missing_lexicon(first_call_error):
+    # When the lexicon cannot be opened, the probe should download it and retry
+    # instead of letting the underlying error propagate and abort plugin loading.
+    # The recovery path keys off the exception type (sqlite3.OperationalError for
+    # a missing database, wn.Error for a missing lexicon; see
+    # https://github.com/NVIDIA/garak/issues/1230), so the message text here is
+    # only illustrative and does not need to track wn's exact wording.
+    loaded_wordnet = MagicMock(name="wn.Wordnet")
+    download_path = MagicMock(name="download_tempfile_path")
+
+    with (
+        patch.object(
+            wn, "Wordnet", side_effect=[first_call_error, loaded_wordnet]  # 1st call raises, 2nd returns
+        ) as mock_wordnet,
+        patch.object(wn, "download", return_value=download_path) as mock_download,
+        patch.object(pathlib.Path, "rmdir"),  # presumably stubs temp-dir cleanup; confirm against the probe
+    ):
+        probe = garak._plugins.load_plugin("probes.topic.WordnetBlockedWords")
+
+    assert (
+        mock_download.call_count == 1
+    ), "a missing lexicon should trigger exactly one download"
+    assert (
+        mock_wordnet.call_count == 2
+    ), "the lexicon should be reloaded after downloading it"
+    assert (
+        probe.w is loaded_wordnet
+    ), "the probe should keep the reloaded wordnet after recovery"
+    download_path.unlink.assert_called_once()