Skip to content

Commit 00a3c51

Browse files
committed
refactor(llm): framework-owned provider registry
Add register_provider/get_provider_names to LLMFramework protocol. Public API (llm/providers/__init__.py) delegates to the active framework instead of importing from LangChain internals directly. LangChainFramework implements the new methods, routing to its internal chat/llm provider registries. Backwards-compat aliases register_chat_provider/register_llm_provider still work.
1 parent a7f2246 commit 00a3c51

4 files changed

Lines changed: 73 additions & 15 deletions

File tree

nemoguardrails/integrations/langchain/llm_adapter.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,29 @@ async def stream_async(
200200

201201

202202
class LangChainFramework:
203+
def register_provider(self, name: str, provider_cls: Any) -> None:
204+
from langchain_core.language_models import BaseChatModel
205+
206+
from nemoguardrails.integrations.langchain.providers.providers import (
207+
register_chat_provider as _register_chat,
208+
)
209+
from nemoguardrails.integrations.langchain.providers.providers import (
210+
register_llm_provider as _register_llm,
211+
)
212+
213+
if isinstance(provider_cls, type) and issubclass(provider_cls, BaseChatModel):
214+
_register_chat(name, provider_cls)
215+
else:
216+
_register_llm(name, provider_cls)
217+
218+
def get_provider_names(self) -> List[str]:
219+
from nemoguardrails.integrations.langchain.providers.providers import (
220+
get_chat_provider_names,
221+
get_llm_provider_names,
222+
)
223+
224+
return sorted(set(get_chat_provider_names() + get_llm_provider_names()))
225+
203226
def create_model(
204227
self,
205228
model_name: str,
Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,20 +13,44 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
# Re-export from new location for backwards compatibility.
17-
# Implementation moved to nemoguardrails/integrations/langchain/providers/.
18-
from nemoguardrails.integrations.langchain.providers import (
19-
get_chat_provider_names,
20-
get_community_chat_provider_names,
21-
get_llm_provider_names,
22-
register_chat_provider,
23-
register_llm_provider,
24-
)
16+
from typing import Any, List
17+
18+
from nemoguardrails.llm.frameworks import get_default_framework, get_framework
19+
20+
21+
def _active_framework():
22+
return get_framework(get_default_framework())
23+
24+
25+
def register_provider(name: str, provider_cls: Any) -> None:
26+
_active_framework().register_provider(name, provider_cls)
27+
28+
29+
def get_provider_names() -> List[str]:
30+
return _active_framework().get_provider_names()
31+
32+
33+
def register_chat_provider(name: str, provider_cls: Any) -> None:
34+
register_provider(name, provider_cls)
35+
36+
37+
def register_llm_provider(name: str, provider_cls: Any) -> None:
38+
register_provider(name, provider_cls)
39+
40+
41+
def get_chat_provider_names() -> List[str]:
42+
return get_provider_names()
43+
44+
45+
def get_llm_provider_names() -> List[str]:
46+
return get_provider_names()
47+
2548

2649
__all__ = [
27-
"get_chat_provider_names",
28-
"get_community_chat_provider_names",
29-
"get_llm_provider_names",
50+
"register_provider",
51+
"get_provider_names",
3052
"register_chat_provider",
3153
"register_llm_provider",
54+
"get_chat_provider_names",
55+
"get_llm_provider_names",
3256
]

nemoguardrails/types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ class LLMFramework(Protocol):
259259
"""Protocol for pluggable LLM framework backends.
260260
261261
Each framework (LangChain, LiteLLM, etc.) implements this protocol to
262-
provide a factory for creating ``LLMModel`` instances.
262+
provide a factory for creating ``LLMModel`` instances and managing
263+
its own set of providers.
263264
264265
``model_kwargs`` carries all provider-specific configuration. Framework
265266
implementations extract what they need (e.g. LangChain pops ``mode``
@@ -272,3 +273,7 @@ def create_model(
272273
provider_name: str,
273274
model_kwargs: Optional[Dict[str, Any]] = None,
274275
) -> LLMModel: ...
276+
277+
def register_provider(self, name: str, provider_cls: Any) -> None: ...
278+
279+
def get_provider_names(self) -> List[str]: ...

tests/test_types.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,15 @@ def test_creation(self):
455455
class TestLLMFrameworkProtocol:
456456
def test_mock_satisfies_protocol(self):
457457
class MockFramework:
458-
def create_model(self, model_name, provider_name, mode, model_kwargs):
458+
def create_model(self, model_name, provider_name, model_kwargs=None):
459459
return None
460460

461+
def register_provider(self, name, provider_cls):
462+
pass
463+
464+
def get_provider_names(self):
465+
return []
466+
461467
assert isinstance(MockFramework(), LLMFramework)
462468

463469
def test_incomplete_class_fails_protocol(self):

0 commit comments

Comments
 (0)