
Commit a7085e0

refactor(llm)!: atomic switch to LLMModel protocol (#1760)
Part of the LangChain decoupling stack. Introduced a framework abstraction layer enabling support for multiple LLM providers beyond LangChain, with new public APIs: get_default_framework(), register_framework(), and set_default_framework(). Replaced LangChain-specific LLM types with a unified LLMModel interface throughout the codebase. Added a LangChainLLMAdapter for backward compatibility with existing LangChain implementations.
1 parent 6a7be49 commit a7085e0

40 files changed

Lines changed: 904 additions & 1665 deletions
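The commit message names a LangChainLLMAdapter for backward compatibility, but the adapter itself is not part of the hunks shown below. As a rough sketch of the intended usage, assuming the adapter wraps an existing LangChain chat model and that its import path looks roughly like the one used here (both the path and the constructor are assumptions, not taken from this diff):

# Hypothetical sketch only: the import path and constructor of LangChainLLMAdapter
# are assumptions; only the class name comes from the commit message.
from langchain_openai import ChatOpenAI

from nemoguardrails import LLMRails, RailsConfig
from nemoguardrails.llm.models.langchain import LangChainLLMAdapter  # assumed path

config = RailsConfig.from_path("./config")

# An existing LangChain model keeps working by wrapping it in the adapter,
# which exposes the unified LLMModel interface used throughout the codebase.
langchain_model = ChatOpenAI(model="gpt-4o-mini")
rails = LLMRails(config, llm=LangChainLLMAdapter(langchain_model))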

nemoguardrails/__init__.py

Lines changed: 13 additions & 1 deletion
@@ -48,5 +48,17 @@
 # Use the original LLMRails class
 from nemoguardrails.rails import LLMRails
 
+from nemoguardrails.llm.frameworks import (  # noqa: E402
+    get_default_framework,
+    register_framework,
+    set_default_framework,
+)
+
 __version__ = version("nemoguardrails")
-__all__ = ["LLMRails", "RailsConfig"]
+__all__ = [
+    "LLMRails",
+    "RailsConfig",
+    "get_default_framework",
+    "register_framework",
+    "set_default_framework",
+]
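With the change above, the three framework functions become importable directly from the package root. A minimal sketch of that surface; only the names come from this diff, and the accepted arguments and return values are assumptions:

# Names come from the new __all__ above; arguments and return values are assumptions.
from nemoguardrails import (
    get_default_framework,
    register_framework,
    set_default_framework,
)

print(get_default_framework())      # presumably "langchain" on a default install (assumed)
set_default_framework("langchain")  # assumed to accept a framework identifier

# register_framework(...) presumably hooks in an additional provider integration;
# its exact signature is not visible in this commit, so it is only imported here.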

nemoguardrails/actions/llm/generation.py

Lines changed: 69 additions & 55 deletions
@@ -23,11 +23,10 @@
 from dataclasses import asdict
 from functools import lru_cache
 from time import time
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
 
 from jinja2 import meta
 from jinja2.sandbox import SandboxedEnvironment
-from langchain_core.language_models import BaseChatModel, BaseLLM
 
 from nemoguardrails.actions.actions import ActionResult, action
 from nemoguardrails.actions.llm.utils import (
@@ -61,6 +60,7 @@
 from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
 from nemoguardrails.rails.llm.options import GenerationOptions
 from nemoguardrails.streaming import StreamingHandler
+from nemoguardrails.types import LLMModel
 from nemoguardrails.utils import (
     new_event_dict,
     new_uuid,
@@ -79,7 +79,7 @@ class LLMGenerationActions:
     def __init__(
         self,
         config: RailsConfig,
-        llm: Optional[Union[BaseLLM, BaseChatModel]],
+        llm: Optional[LLMModel],
         llm_task_manager: LLMTaskManager,
         get_embedding_search_provider_instance: Callable[[Optional[EmbeddingSearchProvider]], EmbeddingsIndex],
         verbose: bool = False,
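The LLMModel type replacing Union[BaseLLM, BaseChatModel] in these signatures is imported from nemoguardrails.types, whose definition is not included in this commit. Purely to illustrate how a structural protocol of this kind removes the hard dependency on LangChain base classes, it could be shaped roughly like the sketch below; the method name and response type are assumptions, chosen only to match how the call sites in this file consume the result (an awaited call whose text is read from .content):

# Illustrative sketch only: the real LLMModel lives in nemoguardrails/types.py
# and is not shown in this diff. Method and attribute names are assumptions.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Protocol, runtime_checkable


@dataclass
class LLMResponse:
    content: str  # the generated text, matching the .content reads in the hunks below


@runtime_checkable
class LLMModel(Protocol):
    async def agenerate(  # assumed method name
        self, prompt: str, params: Optional[Dict[str, Any]] = None
    ) -> LLMResponse: ...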
@@ -350,7 +350,7 @@ async def generate_user_intent(
         events: List[dict],
         context: dict,
         config: RailsConfig,
-        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+        llm: Optional[LLMModel] = None,
         kb: Optional[KnowledgeBase] = None,
     ):
         """Generate the canonical form for what the user said i.e. user intent."""
@@ -369,7 +369,7 @@ async def generate_user_intent(
 
         # Use action specific llm if registered else fallback to main llm
         # This can be None as some code-paths use embedding lookups rather than LLM generation
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         streaming_handler = streaming_handler_var.get()
 
@@ -424,11 +424,13 @@ async def generate_user_intent(
             llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_USER_INTENT.value))
 
             # We make this call with temperature 0 to have it as deterministic as possible.
-            result = await llm_call(
-                generation_llm,
-                prompt,
-                llm_params={"temperature": self.config.lowest_temperature},
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    llm_params={"temperature": self.config.lowest_temperature},
+                )
+            ).content
 
             # Parse the output using the associated parser
             result = self.llm_task_manager.parse_task_output(Task.GENERATE_USER_INTENT, output=result)
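This is the pattern repeated at every call site below: llm_call no longer yields the generated string directly, and the text is read from the .content attribute of the returned object. For code that calls llm_call from a custom action, the adjustment would look roughly like this sketch (the action itself is hypothetical; only the calling convention is taken from the diff, and llm_call lives in nemoguardrails.actions.llm.utils, the module this file already imports from):

# Hypothetical custom action; illustrates only the .content adjustment implied
# by this hunk.
from nemoguardrails.actions import action
from nemoguardrails.actions.llm.utils import llm_call


@action()
async def summarize_conversation(llm=None):
    # Before this commit the awaited result was the generated text itself;
    # after it, the text is carried on the response object's .content field.
    response = await llm_call(llm, "Summarize the conversation so far.")
    return response.content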
@@ -501,12 +503,14 @@ async def generate_user_intent(
 
                 streaming_handler: Optional[StreamingHandler] = streaming_handler_var.get()
 
-                text = await llm_call(
-                    generation_llm,
-                    prompt,
-                    streaming_handler=streaming_handler,
-                    llm_params=llm_params,
-                )
+                text = (
+                    await llm_call(
+                        generation_llm,
+                        prompt,
+                        streaming_handler=streaming_handler,
+                        llm_params=llm_params,
+                    )
+                ).content
                 text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)
 
             else:
@@ -530,13 +534,15 @@ async def generate_user_intent(
                 generation_options: Optional[GenerationOptions] = generation_options_var.get()
                 llm_params = (generation_options and generation_options.llm_params) or {}
 
-                result = await llm_call(
-                    generation_llm,
-                    prompt,
-                    streaming_handler=streaming_handler,
-                    stop=["User:"],
-                    llm_params=llm_params,
-                )
+                result = (
+                    await llm_call(
+                        generation_llm,
+                        prompt,
+                        streaming_handler=streaming_handler,
+                        stop=["User:"],
+                        llm_params=llm_params,
+                    )
+                ).content
 
                 text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
                 text = text.strip()
@@ -588,15 +594,15 @@ async def _search_flows_index(self, text, max_results):
         return final_results[0:max_results]
 
     @action(is_system_action=True)
-    async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] = None):
+    async def generate_next_steps(self, events: List[dict], llm: Optional[LLMModel] = None):
         """Generate the next step in the current conversation flow.
 
         Currently, only generates a next step after a user intent.
         """
         log.info("Phase 2 :: Generating next step ...")
 
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         # The last event should be the "StartInternalSystemAction" and the one before it the "UserIntent".
         event = get_last_user_intent_event(events)
@@ -633,11 +639,13 @@ async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] =
             llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_NEXT_STEPS.value))
 
             # We use temperature 0 for next step prediction as well
-            result = await llm_call(
-                generation_llm,
-                prompt,
-                llm_params={"temperature": self.config.lowest_temperature},
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    llm_params={"temperature": self.config.lowest_temperature},
+                )
+            ).content
 
             # Parse the output using the associated parser
             result = self.llm_task_manager.parse_task_output(Task.GENERATE_NEXT_STEPS, output=result)
@@ -743,12 +751,12 @@ def _render_string(
         return template.render(render_context)
 
     @action(is_system_action=True)
-    async def generate_bot_message(self, events: List[dict], context: dict, llm: Optional[BaseLLM] = None):
+    async def generate_bot_message(self, events: List[dict], context: dict, llm: Optional[LLMModel] = None):
         """Generate a bot message based on the desired bot intent."""
         log.info("Phase 3 :: Generating bot message ...")
 
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         # The last event should be the "StartInternalSystemAction" and the one before it the "BotIntent".
         event = get_last_bot_intent_event(events)
@@ -894,12 +902,14 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt
 
             if not prompt:
                 raise RuntimeError("No prompt found to generate bot message")
-            result = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                llm_params=llm_params,
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    llm_params=llm_params,
+                )
+            ).content
 
             result = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
 
@@ -948,12 +958,14 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt
             generation_options: Optional[GenerationOptions] = generation_options_var.get()
             llm_params = (generation_options and generation_options.llm_params) or {}
 
-            result = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                llm_params=llm_params,
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    llm_params=llm_params,
+                )
+            ).content
 
             log.info(
                 "--- :: LLM Bot Message Generation call took %.2f seconds",
@@ -1016,7 +1028,7 @@ async def generate_value(
         instructions: str,
         events: List[dict],
         var_name: Optional[str] = None,
-        llm: Optional[BaseLLM] = None,
+        llm: Optional[LLMModel] = None,
     ):
         """Generate a value in the context of the conversation.
 
@@ -1027,7 +1039,7 @@ async def generate_value(
         :param llm: Custom llm model to generate_value
         """
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         last_event = events[-1]
         assert last_event["type"] == "StartInternalSystemAction"
@@ -1062,11 +1074,13 @@ async def generate_value(
         # Initialize the LLMCallInfo object
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_VALUE.value))
 
-        result = await llm_call(
-            generation_llm,
-            prompt,
-            llm_params={"temperature": self.config.lowest_temperature},
-        )
+        result = (
+            await llm_call(
+                generation_llm,
+                prompt,
+                llm_params={"temperature": self.config.lowest_temperature},
+            )
+        ).content
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(Task.GENERATE_VALUE, output=result)
@@ -1092,7 +1106,7 @@ async def generate_value(
     async def generate_intent_steps_message(
         self,
         events: List[dict],
-        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+        llm: Optional[LLMModel] = None,
         kb: Optional[KnowledgeBase] = None,
     ):
         """Generate all three main Guardrails phases with a single LLM call.
@@ -1110,7 +1124,7 @@ async def generate_intent_steps_message(
                 "Cannot generate user intent from this event type."
             )
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         streaming_handler = streaming_handler_var.get()
 
@@ -1274,7 +1288,7 @@ async def generate_intent_steps_message(
                 **llm_params,
                 "temperature": self.config.lowest_temperature,
             }
-            result = await llm_call(generation_llm, prompt, llm_params=additional_params)
+            result = (await llm_call(generation_llm, prompt, llm_params=additional_params)).content
 
             # Parse the output using the associated parser
             result = self.llm_task_manager.parse_task_output(Task.GENERATE_INTENT_STEPS_MESSAGE, output=result)
@@ -1341,7 +1355,7 @@ async def generate_intent_steps_message(
             # We make this call with temperature 0 to have it as deterministic as possible.
             gen_options: Optional[GenerationOptions] = generation_options_var.get()
             llm_params = (gen_options and gen_options.llm_params) or {}
-            result = await llm_call(generation_llm, prompt, llm_params=llm_params)
+            result = (await llm_call(generation_llm, prompt, llm_params=llm_params)).content
 
             result = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
             text = result.strip()
