Skip to content

Commit 5c2c1ea

Browse files
authored
refactor(llm): framework-owned provider registry (#1773)
Add register_provider/get_provider_names to LLMFramework protocol. our public API delegates to the active framework instead of importing from LangChain internals directly. LangChainFramework implements the new methods, routing to its internal chat/llm provider registries. Backward compat aliases register_chat_provider/register_llm_provider still work.
1 parent c2319b0 commit 5c2c1ea

6 files changed

Lines changed: 201 additions & 18 deletions

File tree

nemoguardrails/cli/providers.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import logging
17+
import warnings
1718
from typing import List, Literal, Optional, Tuple, cast
1819

1920
import typer
@@ -31,9 +32,14 @@
3132

3233
def _list_providers() -> None:
3334
"""List all available providers."""
34-
console.print("\n[bold]Text Completion Providers:[/]")
35-
for provider in sorted(get_llm_provider_names()):
36-
console.print(f" • {provider}")
35+
# Suppress deprecation warning: get_llm_provider_names is deprecated for
36+
# external callers but the CLI intentionally shows both categories until
37+
# text completion providers are removed in 0.23.0.
38+
with warnings.catch_warnings():
39+
warnings.simplefilter("ignore", DeprecationWarning)
40+
console.print("\n[bold]Text Completion Providers:[/]")
41+
for provider in sorted(get_llm_provider_names()):
42+
console.print(f" • {provider}")
3743

3844
console.print("\n[bold]Chat Completion Providers:[/]")
3945
for provider in sorted(get_chat_provider_names()):
@@ -45,7 +51,10 @@ def _get_provider_completions(
4551
) -> List[str]:
4652
"""Get list of providers based on type."""
4753
if provider_type == "text completion":
48-
return sorted(get_llm_provider_names())
54+
# See comment in _list_providers for why we suppress this warning.
55+
with warnings.catch_warnings():
56+
warnings.simplefilter("ignore", DeprecationWarning)
57+
return sorted(get_llm_provider_names())
4958
elif provider_type == "chat completion":
5059
return sorted(get_chat_provider_names())
5160
return []

nemoguardrails/integrations/langchain/llm_adapter.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,37 @@ async def stream_async(
210210

211211

212212
class LangChainFramework:
213+
def register_provider(self, name: str, provider_cls: Any) -> None:
214+
from nemoguardrails.integrations.langchain.providers.providers import (
215+
register_chat_provider as _register_chat,
216+
)
217+
218+
_register_chat(name, provider_cls)
219+
220+
def register_llm_provider(self, name: str, provider_cls: Any) -> None:
221+
from nemoguardrails.integrations.langchain.providers.providers import (
222+
register_llm_provider as _register_llm,
223+
)
224+
225+
_register_llm(name, provider_cls)
226+
227+
def get_provider_names(self) -> List[str]:
228+
return sorted(set(self.get_chat_provider_names() + self.get_llm_provider_names()))
229+
230+
def get_chat_provider_names(self) -> List[str]:
231+
from nemoguardrails.integrations.langchain.providers.providers import (
232+
get_chat_provider_names as _get_chat,
233+
)
234+
235+
return _get_chat()
236+
237+
def get_llm_provider_names(self) -> List[str]:
238+
from nemoguardrails.integrations.langchain.providers.providers import (
239+
get_llm_provider_names as _get_llm,
240+
)
241+
242+
return _get_llm()
243+
213244
def create_model(
214245
self,
215246
model_name: str,
Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,20 +13,68 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
# Re-export from new location for backwards compatibility.
17-
# Implementation moved to nemoguardrails/integrations/langchain/providers/.
18-
from nemoguardrails.integrations.langchain.providers import (
19-
get_chat_provider_names,
20-
get_community_chat_provider_names,
21-
get_llm_provider_names,
22-
register_chat_provider,
23-
register_llm_provider,
24-
)
16+
import warnings
17+
from typing import Any, List
18+
19+
from nemoguardrails.llm.frameworks import get_default_framework, get_framework
20+
21+
22+
def _active_framework():
23+
return get_framework(get_default_framework())
24+
25+
26+
def register_provider(name: str, provider_cls: Any) -> None:
27+
_active_framework().register_provider(name, provider_cls)
28+
29+
30+
def get_provider_names() -> List[str]:
31+
return _active_framework().get_provider_names()
32+
33+
34+
def register_chat_provider(name: str, provider_cls: Any) -> None:
35+
register_provider(name, provider_cls)
36+
37+
38+
def register_llm_provider(name: str, provider_cls: Any) -> None:
39+
warnings.warn(
40+
"register_llm_provider is deprecated and will be removed in 0.23.0. "
41+
"Text completion providers are being removed. Use register_chat_provider "
42+
"or register_provider instead.",
43+
DeprecationWarning,
44+
stacklevel=2,
45+
)
46+
fw = _active_framework()
47+
if hasattr(fw, "register_llm_provider"):
48+
fw.register_llm_provider(name, provider_cls) # type: ignore[attr-defined]
49+
else:
50+
fw.register_provider(name, provider_cls)
51+
52+
53+
def get_chat_provider_names() -> List[str]:
54+
fw = _active_framework()
55+
if hasattr(fw, "get_chat_provider_names"):
56+
return fw.get_chat_provider_names() # type: ignore[attr-defined]
57+
return fw.get_provider_names()
58+
59+
60+
def get_llm_provider_names() -> List[str]:
61+
warnings.warn(
62+
"get_llm_provider_names is deprecated and will be removed in 0.23.0. "
63+
"Text completion providers are being removed. Use get_provider_names instead.",
64+
DeprecationWarning,
65+
stacklevel=2,
66+
)
67+
fw = _active_framework()
68+
if hasattr(fw, "get_llm_provider_names"):
69+
return fw.get_llm_provider_names() # type: ignore[attr-defined]
70+
return fw.get_provider_names()
71+
2572

2673
__all__ = [
27-
"get_chat_provider_names",
28-
"get_community_chat_provider_names",
29-
"get_llm_provider_names",
74+
"register_provider",
75+
"get_provider_names",
3076
"register_chat_provider",
3177
"register_llm_provider",
78+
"get_chat_provider_names",
79+
"get_llm_provider_names",
3280
]

nemoguardrails/types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ class LLMFramework(Protocol):
261261
"""Protocol for pluggable LLM framework backends.
262262
263263
Each framework (LangChain, LiteLLM, etc.) implements this protocol to
264-
provide a factory for creating ``LLMModel`` instances.
264+
provide a factory for creating ``LLMModel`` instances and managing
265+
its own set of providers.
265266
266267
``model_kwargs`` carries all provider-specific configuration. Framework
267268
implementations extract what they need (e.g. LangChain pops ``mode``
@@ -274,3 +275,7 @@ def create_model(
274275
provider_name: str,
275276
model_kwargs: Optional[Dict[str, Any]] = None,
276277
) -> LLMModel: ...
278+
279+
def register_provider(self, name: str, provider_cls: Any) -> None: ...
280+
281+
def get_provider_names(self) -> List[str]: ...

tests/llm/test_frameworks.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import warnings
1617
from unittest.mock import MagicMock
1718

1819
import pytest
@@ -24,6 +25,14 @@
2425
register_framework,
2526
set_default_framework,
2627
)
28+
from nemoguardrails.llm.providers import (
29+
get_chat_provider_names,
30+
get_llm_provider_names,
31+
get_provider_names,
32+
register_chat_provider,
33+
register_llm_provider,
34+
register_provider,
35+
)
2736
from nemoguardrails.types import LLMModel
2837

2938

@@ -82,3 +91,78 @@ def test_reset_clears_registry(self):
8291
_reset_frameworks()
8392
with pytest.raises(KeyError):
8493
get_framework("temp")
94+
95+
96+
class FakeChatProvider:
97+
pass
98+
99+
100+
class FakeLLMProvider:
101+
async def _acall(self, prompt, stop=None, **kwargs):
102+
return "fake"
103+
104+
105+
@pytest.fixture(autouse=False)
106+
def clean_providers():
107+
from nemoguardrails.integrations.langchain.providers import providers as _p
108+
109+
chat_backup = dict(_p._chat_providers)
110+
llm_backup = dict(_p._llm_providers)
111+
yield
112+
_p._chat_providers.clear()
113+
_p._chat_providers.update(chat_backup)
114+
_p._llm_providers.clear()
115+
_p._llm_providers.update(llm_backup)
116+
117+
118+
@pytest.mark.usefixtures("clean_providers")
119+
class TestProviderRegistration:
120+
def test_register_provider_appears_in_get_provider_names(self):
121+
register_provider("test_provider", FakeChatProvider)
122+
assert "test_provider" in get_provider_names()
123+
124+
def test_register_chat_provider_appears_in_chat_names(self):
125+
register_chat_provider("test_chat", FakeChatProvider)
126+
assert "test_chat" in get_chat_provider_names()
127+
128+
def test_register_llm_provider_appears_in_llm_names(self):
129+
with warnings.catch_warnings():
130+
warnings.simplefilter("ignore", DeprecationWarning)
131+
register_llm_provider("test_llm", FakeLLMProvider)
132+
with warnings.catch_warnings():
133+
warnings.simplefilter("ignore", DeprecationWarning)
134+
assert "test_llm" in get_llm_provider_names()
135+
136+
def test_chat_and_llm_provider_names_are_different_subsets(self):
137+
register_chat_provider("only_chat_test", FakeChatProvider)
138+
with warnings.catch_warnings():
139+
warnings.simplefilter("ignore", DeprecationWarning)
140+
register_llm_provider("only_llm_test", FakeLLMProvider)
141+
142+
chat_names = get_chat_provider_names()
143+
with warnings.catch_warnings():
144+
warnings.simplefilter("ignore", DeprecationWarning)
145+
llm_names = get_llm_provider_names()
146+
147+
assert "only_chat_test" in chat_names
148+
assert "only_chat_test" not in llm_names
149+
assert "only_llm_test" in llm_names
150+
assert "only_llm_test" not in chat_names
151+
152+
def test_get_provider_names_returns_both(self):
153+
register_chat_provider("both_chat", FakeChatProvider)
154+
with warnings.catch_warnings():
155+
warnings.simplefilter("ignore", DeprecationWarning)
156+
register_llm_provider("both_llm", FakeLLMProvider)
157+
158+
all_names = get_provider_names()
159+
assert "both_chat" in all_names
160+
assert "both_llm" in all_names
161+
162+
def test_register_llm_provider_emits_deprecation(self):
163+
with pytest.warns(DeprecationWarning, match="removed in 0.23.0"):
164+
register_llm_provider("dep_test", FakeLLMProvider)
165+
166+
def test_get_llm_provider_names_emits_deprecation(self):
167+
with pytest.warns(DeprecationWarning, match="removed in 0.23.0"):
168+
get_llm_provider_names()

tests/test_types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,12 @@ class MockFramework:
473473
def create_model(self, model_name, provider_name, model_kwargs=None):
474474
return None
475475

476+
def register_provider(self, name, provider_cls):
477+
pass
478+
479+
def get_provider_names(self):
480+
return []
481+
476482
assert isinstance(MockFramework(), LLMFramework)
477483

478484
def test_incomplete_class_fails_protocol(self):

0 commit comments

Comments
 (0)