|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import warnings |
16 | 17 | from unittest.mock import MagicMock |
17 | 18 |
|
18 | 19 | import pytest |
|
24 | 25 | register_framework, |
25 | 26 | set_default_framework, |
26 | 27 | ) |
| 28 | +from nemoguardrails.llm.providers import ( |
| 29 | + get_chat_provider_names, |
| 30 | + get_llm_provider_names, |
| 31 | + get_provider_names, |
| 32 | + register_chat_provider, |
| 33 | + register_llm_provider, |
| 34 | + register_provider, |
| 35 | +) |
27 | 36 | from nemoguardrails.types import LLMModel |
28 | 37 |
|
29 | 38 |
|
@@ -82,3 +91,78 @@ def test_reset_clears_registry(self): |
82 | 91 | _reset_frameworks() |
83 | 92 | with pytest.raises(KeyError): |
84 | 93 | get_framework("temp") |
| 94 | + |
| 95 | + |
| 96 | +class FakeChatProvider: |
| 97 | + pass |
| 98 | + |
| 99 | + |
| 100 | +class FakeLLMProvider: |
| 101 | + async def _acall(self, prompt, stop=None, **kwargs): |
| 102 | + return "fake" |
| 103 | + |
| 104 | + |
| 105 | +@pytest.fixture(autouse=False) |
| 106 | +def clean_providers(): |
| 107 | + from nemoguardrails.integrations.langchain.providers import providers as _p |
| 108 | + |
| 109 | + chat_backup = dict(_p._chat_providers) |
| 110 | + llm_backup = dict(_p._llm_providers) |
| 111 | + yield |
| 112 | + _p._chat_providers.clear() |
| 113 | + _p._chat_providers.update(chat_backup) |
| 114 | + _p._llm_providers.clear() |
| 115 | + _p._llm_providers.update(llm_backup) |
| 116 | + |
| 117 | + |
| 118 | +@pytest.mark.usefixtures("clean_providers") |
| 119 | +class TestProviderRegistration: |
| 120 | + def test_register_provider_appears_in_get_provider_names(self): |
| 121 | + register_provider("test_provider", FakeChatProvider) |
| 122 | + assert "test_provider" in get_provider_names() |
| 123 | + |
| 124 | + def test_register_chat_provider_appears_in_chat_names(self): |
| 125 | + register_chat_provider("test_chat", FakeChatProvider) |
| 126 | + assert "test_chat" in get_chat_provider_names() |
| 127 | + |
| 128 | + def test_register_llm_provider_appears_in_llm_names(self): |
| 129 | + with warnings.catch_warnings(): |
| 130 | + warnings.simplefilter("ignore", DeprecationWarning) |
| 131 | + register_llm_provider("test_llm", FakeLLMProvider) |
| 132 | + with warnings.catch_warnings(): |
| 133 | + warnings.simplefilter("ignore", DeprecationWarning) |
| 134 | + assert "test_llm" in get_llm_provider_names() |
| 135 | + |
| 136 | + def test_chat_and_llm_provider_names_are_different_subsets(self): |
| 137 | + register_chat_provider("only_chat_test", FakeChatProvider) |
| 138 | + with warnings.catch_warnings(): |
| 139 | + warnings.simplefilter("ignore", DeprecationWarning) |
| 140 | + register_llm_provider("only_llm_test", FakeLLMProvider) |
| 141 | + |
| 142 | + chat_names = get_chat_provider_names() |
| 143 | + with warnings.catch_warnings(): |
| 144 | + warnings.simplefilter("ignore", DeprecationWarning) |
| 145 | + llm_names = get_llm_provider_names() |
| 146 | + |
| 147 | + assert "only_chat_test" in chat_names |
| 148 | + assert "only_chat_test" not in llm_names |
| 149 | + assert "only_llm_test" in llm_names |
| 150 | + assert "only_llm_test" not in chat_names |
| 151 | + |
| 152 | + def test_get_provider_names_returns_both(self): |
| 153 | + register_chat_provider("both_chat", FakeChatProvider) |
| 154 | + with warnings.catch_warnings(): |
| 155 | + warnings.simplefilter("ignore", DeprecationWarning) |
| 156 | + register_llm_provider("both_llm", FakeLLMProvider) |
| 157 | + |
| 158 | + all_names = get_provider_names() |
| 159 | + assert "both_chat" in all_names |
| 160 | + assert "both_llm" in all_names |
| 161 | + |
| 162 | + def test_register_llm_provider_emits_deprecation(self): |
| 163 | + with pytest.warns(DeprecationWarning, match="removed in 0.23.0"): |
| 164 | + register_llm_provider("dep_test", FakeLLMProvider) |
| 165 | + |
| 166 | + def test_get_llm_provider_names_emits_deprecation(self): |
| 167 | + with pytest.warns(DeprecationWarning, match="removed in 0.23.0"): |
| 168 | + get_llm_provider_names() |
0 commit comments