Skip to content

Commit 6cc77c5

Browse files
committed
fix(llm): keep separate chat/llm provider lists, add deprecation warnings
LangChainFramework exposes get_chat_provider_names() and get_llm_provider_names() separately so CLI and existing tests work. register_llm_provider and get_llm_provider_names emit DeprecationWarning (removal in 0.23.0). register_provider always targets chat registry. Add 7 provider registration tests with cleanup fixture to prevent polluting the global LangChain provider dicts.
1 parent 00a3c51 commit 6cc77c5

4 files changed

Lines changed: 141 additions & 16 deletions

File tree

nemoguardrails/cli/providers.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import logging
17+
import warnings
1718
from typing import List, Literal, Optional, Tuple, cast
1819

1920
import typer
@@ -31,9 +32,14 @@
3132

3233
def _list_providers() -> None:
3334
"""List all available providers."""
34-
console.print("\n[bold]Text Completion Providers:[/]")
35-
for provider in sorted(get_llm_provider_names()):
36-
console.print(f" • {provider}")
35+
# Suppress deprecation warning: get_llm_provider_names is deprecated for
36+
# external callers but the CLI intentionally shows both categories until
37+
# text completion providers are removed in 0.23.0.
38+
with warnings.catch_warnings():
39+
warnings.simplefilter("ignore", DeprecationWarning)
40+
console.print("\n[bold]Text Completion Providers:[/]")
41+
for provider in sorted(get_llm_provider_names()):
42+
console.print(f" • {provider}")
3743

3844
console.print("\n[bold]Chat Completion Providers:[/]")
3945
for provider in sorted(get_chat_provider_names()):
@@ -45,7 +51,10 @@ def _get_provider_completions(
4551
) -> List[str]:
4652
"""Get list of providers based on type."""
4753
if provider_type == "text completion":
48-
return sorted(get_llm_provider_names())
54+
# See comment in _list_providers for why we suppress this warning.
55+
with warnings.catch_warnings():
56+
warnings.simplefilter("ignore", DeprecationWarning)
57+
return sorted(get_llm_provider_names())
4958
elif provider_type == "chat completion":
5059
return sorted(get_chat_provider_names())
5160
return []

nemoguardrails/integrations/langchain/llm_adapter.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,27 +201,35 @@ async def stream_async(
201201

202202
class LangChainFramework:
203203
def register_provider(self, name: str, provider_cls: Any) -> None:
204-
from langchain_core.language_models import BaseChatModel
205-
206204
from nemoguardrails.integrations.langchain.providers.providers import (
207205
register_chat_provider as _register_chat,
208206
)
207+
208+
_register_chat(name, provider_cls)
209+
210+
def register_llm_provider(self, name: str, provider_cls: Any) -> None:
209211
from nemoguardrails.integrations.langchain.providers.providers import (
210212
register_llm_provider as _register_llm,
211213
)
212214

213-
if isinstance(provider_cls, type) and issubclass(provider_cls, BaseChatModel):
214-
_register_chat(name, provider_cls)
215-
else:
216-
_register_llm(name, provider_cls)
215+
_register_llm(name, provider_cls)
217216

218217
def get_provider_names(self) -> List[str]:
218+
return sorted(set(self.get_chat_provider_names() + self.get_llm_provider_names()))
219+
220+
def get_chat_provider_names(self) -> List[str]:
221+
from nemoguardrails.integrations.langchain.providers.providers import (
222+
get_chat_provider_names as _get_chat,
223+
)
224+
225+
return _get_chat()
226+
227+
def get_llm_provider_names(self) -> List[str]:
219228
from nemoguardrails.integrations.langchain.providers.providers import (
220-
get_chat_provider_names,
221-
get_llm_provider_names,
229+
get_llm_provider_names as _get_llm,
222230
)
223231

224-
return sorted(set(get_chat_provider_names() + get_llm_provider_names()))
232+
return _get_llm()
225233

226234
def create_model(
227235
self,

nemoguardrails/llm/providers/__init__.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import warnings
1617
from typing import Any, List
1718

1819
from nemoguardrails.llm.frameworks import get_default_framework, get_framework
@@ -35,15 +36,38 @@ def register_chat_provider(name: str, provider_cls: Any) -> None:
3536

3637

3738
def register_llm_provider(name: str, provider_cls: Any) -> None:
38-
register_provider(name, provider_cls)
39+
warnings.warn(
40+
"register_llm_provider is deprecated and will be removed in 0.23.0. "
41+
"Text completion providers are being removed. Use register_chat_provider "
42+
"or register_provider instead.",
43+
DeprecationWarning,
44+
stacklevel=2,
45+
)
46+
fw = _active_framework()
47+
if hasattr(fw, "register_llm_provider"):
48+
fw.register_llm_provider(name, provider_cls) # type: ignore[attr-defined]
49+
else:
50+
fw.register_provider(name, provider_cls)
3951

4052

4153
def get_chat_provider_names() -> List[str]:
42-
return get_provider_names()
54+
fw = _active_framework()
55+
if hasattr(fw, "get_chat_provider_names"):
56+
return fw.get_chat_provider_names() # type: ignore[attr-defined]
57+
return fw.get_provider_names()
4358

4459

4560
def get_llm_provider_names() -> List[str]:
46-
return get_provider_names()
61+
warnings.warn(
62+
"get_llm_provider_names is deprecated and will be removed in 0.23.0. "
63+
"Text completion providers are being removed. Use get_provider_names instead.",
64+
DeprecationWarning,
65+
stacklevel=2,
66+
)
67+
fw = _active_framework()
68+
if hasattr(fw, "get_llm_provider_names"):
69+
return fw.get_llm_provider_names() # type: ignore[attr-defined]
70+
return fw.get_provider_names()
4771

4872

4973
__all__ = [

tests/llm/test_frameworks.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import warnings
1617
from unittest.mock import MagicMock
1718

1819
import pytest
@@ -24,6 +25,14 @@
2425
register_framework,
2526
set_default_framework,
2627
)
28+
from nemoguardrails.llm.providers import (
29+
get_chat_provider_names,
30+
get_llm_provider_names,
31+
get_provider_names,
32+
register_chat_provider,
33+
register_llm_provider,
34+
register_provider,
35+
)
2736
from nemoguardrails.types import LLMModel
2837

2938

@@ -81,3 +90,78 @@ def test_reset_clears_registry(self):
8190
_reset_frameworks()
8291
with pytest.raises(KeyError):
8392
get_framework("temp")
93+
94+
95+
class FakeChatProvider:
96+
pass
97+
98+
99+
class FakeLLMProvider:
100+
async def _acall(self, prompt, stop=None, **kwargs):
101+
return "fake"
102+
103+
104+
@pytest.fixture(autouse=False)
105+
def clean_providers():
106+
from nemoguardrails.integrations.langchain.providers import providers as _p
107+
108+
chat_backup = dict(_p._chat_providers)
109+
llm_backup = dict(_p._llm_providers)
110+
yield
111+
_p._chat_providers.clear()
112+
_p._chat_providers.update(chat_backup)
113+
_p._llm_providers.clear()
114+
_p._llm_providers.update(llm_backup)
115+
116+
117+
@pytest.mark.usefixtures("clean_providers")
118+
class TestProviderRegistration:
119+
def test_register_provider_appears_in_get_provider_names(self):
120+
register_provider("test_provider", FakeChatProvider)
121+
assert "test_provider" in get_provider_names()
122+
123+
def test_register_chat_provider_appears_in_chat_names(self):
124+
register_chat_provider("test_chat", FakeChatProvider)
125+
assert "test_chat" in get_chat_provider_names()
126+
127+
def test_register_llm_provider_appears_in_llm_names(self):
128+
with warnings.catch_warnings():
129+
warnings.simplefilter("ignore", DeprecationWarning)
130+
register_llm_provider("test_llm", FakeLLMProvider)
131+
with warnings.catch_warnings():
132+
warnings.simplefilter("ignore", DeprecationWarning)
133+
assert "test_llm" in get_llm_provider_names()
134+
135+
def test_chat_and_llm_provider_names_are_different_subsets(self):
136+
register_chat_provider("only_chat_test", FakeChatProvider)
137+
with warnings.catch_warnings():
138+
warnings.simplefilter("ignore", DeprecationWarning)
139+
register_llm_provider("only_llm_test", FakeLLMProvider)
140+
141+
chat_names = get_chat_provider_names()
142+
with warnings.catch_warnings():
143+
warnings.simplefilter("ignore", DeprecationWarning)
144+
llm_names = get_llm_provider_names()
145+
146+
assert "only_chat_test" in chat_names
147+
assert "only_chat_test" not in llm_names
148+
assert "only_llm_test" in llm_names
149+
assert "only_llm_test" not in chat_names
150+
151+
def test_get_provider_names_returns_both(self):
152+
register_chat_provider("both_chat", FakeChatProvider)
153+
with warnings.catch_warnings():
154+
warnings.simplefilter("ignore", DeprecationWarning)
155+
register_llm_provider("both_llm", FakeLLMProvider)
156+
157+
all_names = get_provider_names()
158+
assert "both_chat" in all_names
159+
assert "both_llm" in all_names
160+
161+
def test_register_llm_provider_emits_deprecation(self):
162+
with pytest.warns(DeprecationWarning, match="removed in 0.23.0"):
163+
register_llm_provider("dep_test", FakeLLMProvider)
164+
165+
def test_get_llm_provider_names_emits_deprecation(self):
166+
with pytest.warns(DeprecationWarning, match="removed in 0.23.0"):
167+
get_llm_provider_names()

0 commit comments

Comments
 (0)