Commit 6e581da

refactor(llm): migrate all callers to LLMResponse.content
Update ~35 call sites across generation actions, library actions, and eval modules to read generated text via .content, now that llm_call() returns an LLMResponse instead of a str.
1 parent 150894a commit 6e581da

21 files changed: 311 additions & 246 deletions
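
The mechanical change at each call site is the same: llm_call() now resolves to an LLMResponse object instead of a plain string, so callers wrap the await in parentheses and read .content. A minimal runnable sketch of the before/after shape, assuming LLMResponse is a simple wrapper exposing a content: str field (its exact definition is not part of this commit's visible diff):

import asyncio
from dataclasses import dataclass

# Assumed stand-in for the real LLMResponse; the diff below only shows that
# it exposes a .content attribute holding the generated text.
@dataclass
class LLMResponse:
    content: str

# Stand-in for nemoguardrails.actions.llm.utils.llm_call, which after this
# change returns an LLMResponse instead of a str.
async def llm_call(llm, prompt: str, **kwargs) -> LLMResponse:
    return LLMResponse(content=f"echo: {prompt}")

async def main():
    # Before: result = await llm_call(llm, prompt)   # result was a str
    # After:  unwrap the generated text via .content
    result = (await llm_call(None, "hello")).content
    assert result == "echo: hello"

asyncio.run(main())

This parenthesize-and-unwrap pattern accounts for most of the line churn below: each multi-line llm_call(...) statement grows by two lines.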

nemoguardrails/actions/llm/generation.py

Lines changed: 69 additions & 55 deletions
@@ -23,11 +23,10 @@
 from dataclasses import asdict
 from functools import lru_cache
 from time import time
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Awaitable, Callable, Dict, List, Optional, cast

 from jinja2 import meta
 from jinja2.sandbox import SandboxedEnvironment
-from langchain_core.language_models import BaseChatModel, BaseLLM

 from nemoguardrails.actions.actions import ActionResult, action
 from nemoguardrails.actions.llm.utils import (
@@ -61,6 +60,7 @@
 from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
 from nemoguardrails.rails.llm.options import GenerationOptions
 from nemoguardrails.streaming import StreamingHandler
+from nemoguardrails.types import LLMModel
 from nemoguardrails.utils import (
     new_event_dict,
     new_uuid,
@@ -79,7 +79,7 @@ class LLMGenerationActions:
     def __init__(
         self,
         config: RailsConfig,
-        llm: Optional[Union[BaseLLM, BaseChatModel]],
+        llm: Optional[LLMModel],
         llm_task_manager: LLMTaskManager,
         get_embedding_search_provider_instance: Callable[[Optional[EmbeddingSearchProvider]], EmbeddingsIndex],
         verbose: bool = False,
@@ -350,7 +350,7 @@ async def generate_user_intent(
         events: List[dict],
         context: dict,
         config: RailsConfig,
-        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+        llm: Optional[LLMModel] = None,
         kb: Optional[KnowledgeBase] = None,
     ):
         """Generate the canonical form for what the user said i.e. user intent."""
@@ -369,7 +369,7 @@ async def generate_user_intent(

         # Use action specific llm if registered else fallback to main llm
         # This can be None as some code-paths use embedding lookups rather than LLM generation
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm

         streaming_handler = streaming_handler_var.get()

@@ -435,11 +435,13 @@ async def generate_user_intent(
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_USER_INTENT.value))

         # We make this call with temperature 0 to have it as deterministic as possible.
-        result = await llm_call(
-            generation_llm,
-            prompt,
-            llm_params={"temperature": self.config.lowest_temperature},
-        )
+        result = (
+            await llm_call(
+                generation_llm,
+                prompt,
+                llm_params={"temperature": self.config.lowest_temperature},
+            )
+        ).content

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(Task.GENERATE_USER_INTENT, output=result)
@@ -512,12 +514,14 @@ async def generate_user_intent(

             streaming_handler: Optional[StreamingHandler] = streaming_handler_var.get()

-            text = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                llm_params=llm_params,
-            )
+            text = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    llm_params=llm_params,
+                )
+            ).content
             text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)

         else:
@@ -541,13 +545,15 @@ async def generate_user_intent(
             generation_options: Optional[GenerationOptions] = generation_options_var.get()
             llm_params = (generation_options and generation_options.llm_params) or {}

-            result = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                stop=["User:"],
-                llm_params=llm_params,
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    stop=["User:"],
+                    llm_params=llm_params,
+                )
+            ).content

             text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
             text = text.strip()
@@ -599,15 +605,15 @@ async def _search_flows_index(self, text, max_results):
         return final_results[0:max_results]

     @action(is_system_action=True)
-    async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] = None):
+    async def generate_next_steps(self, events: List[dict], llm: Optional[LLMModel] = None):
         """Generate the next step in the current conversation flow.

         Currently, only generates a next step after a user intent.
         """
         log.info("Phase 2 :: Generating next step ...")

         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm

         # The last event should be the "StartInternalSystemAction" and the one before it the "UserIntent".
         event = get_last_user_intent_event(events)
@@ -644,11 +650,13 @@ async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] =
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_NEXT_STEPS.value))

         # We use temperature 0 for next step prediction as well
-        result = await llm_call(
-            generation_llm,
-            prompt,
-            llm_params={"temperature": self.config.lowest_temperature},
-        )
+        result = (
+            await llm_call(
+                generation_llm,
+                prompt,
+                llm_params={"temperature": self.config.lowest_temperature},
+            )
+        ).content

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(Task.GENERATE_NEXT_STEPS, output=result)
@@ -754,12 +762,12 @@ def _render_string(
         return template.render(render_context)

     @action(is_system_action=True)
-    async def generate_bot_message(self, events: List[dict], context: dict, llm: Optional[BaseLLM] = None):
+    async def generate_bot_message(self, events: List[dict], context: dict, llm: Optional[LLMModel] = None):
         """Generate a bot message based on the desired bot intent."""
         log.info("Phase 3 :: Generating bot message ...")

         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm

         # The last event should be the "StartInternalSystemAction" and the one before it the "BotIntent".
         event = get_last_bot_intent_event(events)
@@ -905,12 +913,14 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt

             if not prompt:
                 raise RuntimeError("No prompt found to generate bot message")
-            result = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                llm_params=llm_params,
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    llm_params=llm_params,
+                )
+            ).content

             result = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)

@@ -959,12 +969,14 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt
             generation_options: Optional[GenerationOptions] = generation_options_var.get()
             llm_params = (generation_options and generation_options.llm_params) or {}

-            result = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                llm_params=llm_params,
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    llm_params=llm_params,
+                )
+            ).content

             log.info(
                 "--- :: LLM Bot Message Generation call took %.2f seconds",
@@ -1027,7 +1039,7 @@ async def generate_value(
         instructions: str,
         events: List[dict],
         var_name: Optional[str] = None,
-        llm: Optional[BaseLLM] = None,
+        llm: Optional[LLMModel] = None,
     ):
         """Generate a value in the context of the conversation.

@@ -1038,7 +1050,7 @@ async def generate_value(
         :param llm: Custom llm model to generate_value
         """
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm

         last_event = events[-1]
         assert last_event["type"] == "StartInternalSystemAction"
@@ -1073,11 +1085,13 @@
         # Initialize the LLMCallInfo object
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_VALUE.value))

-        result = await llm_call(
-            generation_llm,
-            prompt,
-            llm_params={"temperature": self.config.lowest_temperature},
-        )
+        result = (
+            await llm_call(
+                generation_llm,
+                prompt,
+                llm_params={"temperature": self.config.lowest_temperature},
+            )
+        ).content

         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(Task.GENERATE_VALUE, output=result)
@@ -1103,7 +1117,7 @@ async def generate_value(
     async def generate_intent_steps_message(
         self,
         events: List[dict],
-        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+        llm: Optional[LLMModel] = None,
         kb: Optional[KnowledgeBase] = None,
     ):
         """Generate all three main Guardrails phases with a single LLM call.
@@ -1121,7 +1135,7 @@ async def generate_intent_steps_message(
                 "Cannot generate user intent from this event type."
             )
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm

         streaming_handler = streaming_handler_var.get()

@@ -1285,7 +1299,7 @@ async def generate_intent_steps_message(
                 **llm_params,
                 "temperature": self.config.lowest_temperature,
             }
-            result = await llm_call(generation_llm, prompt, llm_params=additional_params)
+            result = (await llm_call(generation_llm, prompt, llm_params=additional_params)).content

             # Parse the output using the associated parser
             result = self.llm_task_manager.parse_task_output(Task.GENERATE_INTENT_STEPS_MESSAGE, output=result)
@@ -1352,7 +1366,7 @@ async def generate_intent_steps_message(
             # We make this call with temperature 0 to have it as deterministic as possible.
             gen_options: Optional[GenerationOptions] = generation_options_var.get()
             llm_params = (gen_options and gen_options.llm_params) or {}
-            result = await llm_call(generation_llm, prompt, llm_params=llm_params)
+            result = (await llm_call(generation_llm, prompt, llm_params=llm_params)).content

             result = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
             text = result.strip()
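
A second change threaded through the signatures above replaces every inline Optional[Union[BaseLLM, BaseChatModel]] annotation with the newly imported nemoguardrails.types.LLMModel, which also lets this module drop its direct langchain_core import. A plausible sketch of that alias, assuming it simply names the union the old annotations spelled out (the real definition in nemoguardrails/types.py may differ):

from typing import Union

from langchain_core.language_models import BaseChatModel, BaseLLM

# Hypothetical reconstruction: one alias for "anything llm_call accepts",
# replacing the Union[BaseLLM, BaseChatModel] spelled out at each call site.
LLMModel = Union[BaseLLM, BaseChatModel]

Centralizing the alias means a future model type touches one definition rather than every signature in this file.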
