from dataclasses import asdict
from functools import lru_cache
from time import time
- from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, cast
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, cast

from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
- from langchain_core.language_models import BaseChatModel, BaseLLM

from nemoguardrails.actions.actions import ActionResult, action
from nemoguardrails.actions.llm.utils import (
    ...
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions
from nemoguardrails.streaming import StreamingHandler
+ from nemoguardrails.types import LLMModel
from nemoguardrails.utils import (
    new_event_dict,
    new_uuid,
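The import hunk swaps the inline `Union[BaseLLM, BaseChatModel]` annotation for a single `LLMModel` name from `nemoguardrails.types`. The alias's definition is not part of this diff; a minimal sketch, assuming it is a plain type alias over the same two LangChain base classes the old annotations spelled out:

```python
# Hypothetical sketch of the new nemoguardrails/types.py; the real
# definition is not shown in this diff.
from typing import Union

from langchain_core.language_models import BaseChatModel, BaseLLM

# One name for "any model an LLM call can be routed to", so annotations
# stay short and future provider types need changing in only one place.
LLMModel = Union[BaseLLM, BaseChatModel]
```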
@@ -79,7 +79,7 @@ class LLMGenerationActions:
    def __init__(
        self,
        config: RailsConfig,
-       llm: Optional[Union[BaseLLM, BaseChatModel]],
+       llm: Optional[LLMModel],
        llm_task_manager: LLMTaskManager,
        get_embedding_search_provider_instance: Callable[[Optional[EmbeddingSearchProvider]], EmbeddingsIndex],
        verbose: bool = False,
@@ -350,7 +350,7 @@ async def generate_user_intent(
        events: List[dict],
        context: dict,
        config: RailsConfig,
-       llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+       llm: Optional[LLMModel] = None,
        kb: Optional[KnowledgeBase] = None,
    ):
        """Generate the canonical form of what the user said, i.e., the user intent."""
@@ -369,7 +369,7 @@ async def generate_user_intent(

        # Use the action-specific LLM if one is registered, else fall back to the main LLM.
        # This can be None, as some code paths use embedding lookups rather than LLM generation.
-       generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+       generation_llm: Optional[LLMModel] = llm if llm else self.llm

        streaming_handler = streaming_handler_var.get()
@@ -435,11 +435,13 @@ async def generate_user_intent(
        llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_USER_INTENT.value))

        # We make this call with temperature 0 to have it as deterministic as possible.
-       result = await llm_call(
-           generation_llm,
-           prompt,
-           llm_params={"temperature": self.config.lowest_temperature},
-       )
+       result = (
+           await llm_call(
+               generation_llm,
+               prompt,
+               llm_params={"temperature": self.config.lowest_temperature},
+           )
+       ).content

        # Parse the output using the associated parser
        result = self.llm_task_manager.parse_task_output(Task.GENERATE_USER_INTENT, output=result)
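This hunk introduces the change that repeats throughout the rest of the diff: the awaited `llm_call` result is parenthesized and unwrapped via `.content`. The new return type of `llm_call` is not shown here; presumably it now returns a message-like object (for example, LangChain's `AIMessage`) rather than a plain string. A minimal runnable sketch of the calling pattern, with a hypothetical stub standing in for the real `llm_call`:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class FakeMessage:
    """Stand-in for a message-like result such as LangChain's AIMessage."""

    content: str


async def llm_call(llm, prompt, llm_params=None):
    """Toy stub; the real llm_call invokes the model and returns its message."""
    return FakeMessage(content=f"echo: {prompt}")


async def main():
    # The recurring pattern: parenthesize the awaited call, then unwrap
    # the generated text with .content instead of using the result directly.
    result = (await llm_call(None, "hello", llm_params={"temperature": 0.0})).content
    assert result == "echo: hello"


asyncio.run(main())
```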
@@ -512,12 +514,14 @@ async def generate_user_intent(

        streaming_handler: Optional[StreamingHandler] = streaming_handler_var.get()

-           text = await llm_call(
-               generation_llm,
-               prompt,
-               streaming_handler=streaming_handler,
-               llm_params=llm_params,
-           )
+           text = (
+               await llm_call(
+                   generation_llm,
+                   prompt,
+                   streaming_handler=streaming_handler,
+                   llm_params=llm_params,
+               )
+           ).content
            text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)

        else:
@@ -541,13 +545,15 @@ async def generate_user_intent(
        generation_options: Optional[GenerationOptions] = generation_options_var.get()
        llm_params = (generation_options and generation_options.llm_params) or {}

-       result = await llm_call(
-           generation_llm,
-           prompt,
-           streaming_handler=streaming_handler,
-           stop=["User:"],
-           llm_params=llm_params,
-       )
+       result = (
+           await llm_call(
+               generation_llm,
+               prompt,
+               streaming_handler=streaming_handler,
+               stop=["User:"],
+               llm_params=llm_params,
+           )
+       ).content

        text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
        text = text.strip()
@@ -599,15 +605,15 @@ async def _search_flows_index(self, text, max_results):
        return final_results[0:max_results]

    @action(is_system_action=True)
-   async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] = None):
+   async def generate_next_steps(self, events: List[dict], llm: Optional[LLMModel] = None):
        """Generate the next step in the current conversation flow.

        Currently, this only generates a next step after a user intent.
        """
        log.info("Phase 2 :: Generating next step ...")

        # Use the action-specific LLM if one is registered, else fall back to the main LLM.
-       generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+       generation_llm: Optional[LLMModel] = llm if llm else self.llm

        # The last event should be the "StartInternalSystemAction" and the one before it the "UserIntent".
        event = get_last_user_intent_event(events)
@@ -644,11 +650,13 @@ async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] =
        llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_NEXT_STEPS.value))

        # We use temperature 0 for next step prediction as well
-       result = await llm_call(
-           generation_llm,
-           prompt,
-           llm_params={"temperature": self.config.lowest_temperature},
-       )
+       result = (
+           await llm_call(
+               generation_llm,
+               prompt,
+               llm_params={"temperature": self.config.lowest_temperature},
+           )
+       ).content

        # Parse the output using the associated parser
        result = self.llm_task_manager.parse_task_output(Task.GENERATE_NEXT_STEPS, output=result)
@@ -754,12 +762,12 @@ def _render_string(
        return template.render(render_context)

    @action(is_system_action=True)
-   async def generate_bot_message(self, events: List[dict], context: dict, llm: Optional[BaseLLM] = None):
+   async def generate_bot_message(self, events: List[dict], context: dict, llm: Optional[LLMModel] = None):
        """Generate a bot message based on the desired bot intent."""
        log.info("Phase 3 :: Generating bot message ...")

        # Use the action-specific LLM if one is registered, else fall back to the main LLM.
-       generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+       generation_llm: Optional[LLMModel] = llm if llm else self.llm

        # The last event should be the "StartInternalSystemAction" and the one before it the "BotIntent".
        event = get_last_bot_intent_event(events)
@@ -905,12 +913,14 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt

        if not prompt:
            raise RuntimeError("No prompt found to generate bot message")
-       result = await llm_call(
-           generation_llm,
-           prompt,
-           streaming_handler=streaming_handler,
-           llm_params=llm_params,
-       )
+       result = (
+           await llm_call(
+               generation_llm,
+               prompt,
+               streaming_handler=streaming_handler,
+               llm_params=llm_params,
+           )
+       ).content

        result = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)

@@ -959,12 +969,14 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt
        generation_options: Optional[GenerationOptions] = generation_options_var.get()
        llm_params = (generation_options and generation_options.llm_params) or {}

-       result = await llm_call(
-           generation_llm,
-           prompt,
-           streaming_handler=streaming_handler,
-           llm_params=llm_params,
-       )
+       result = (
+           await llm_call(
+               generation_llm,
+               prompt,
+               streaming_handler=streaming_handler,
+               llm_params=llm_params,
+           )
+       ).content

        log.info(
            "--- :: LLM Bot Message Generation call took %.2f seconds",
@@ -1027,7 +1039,7 @@ async def generate_value(
        instructions: str,
        events: List[dict],
        var_name: Optional[str] = None,
-       llm: Optional[BaseLLM] = None,
+       llm: Optional[LLMModel] = None,
    ):
        """Generate a value in the context of the conversation.

@@ -1038,7 +1050,7 @@ async def generate_value(
        :param llm: Custom LLM to use for generating the value.
        """
        # Use the action-specific LLM if one is registered, else fall back to the main LLM.
-       generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+       generation_llm: Optional[LLMModel] = llm if llm else self.llm

        last_event = events[-1]
        assert last_event["type"] == "StartInternalSystemAction"
@@ -1073,11 +1085,13 @@ async def generate_value(
        # Initialize the LLMCallInfo object
        llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_VALUE.value))

-       result = await llm_call(
-           generation_llm,
-           prompt,
-           llm_params={"temperature": self.config.lowest_temperature},
-       )
+       result = (
+           await llm_call(
+               generation_llm,
+               prompt,
+               llm_params={"temperature": self.config.lowest_temperature},
+           )
+       ).content

        # Parse the output using the associated parser
        result = self.llm_task_manager.parse_task_output(Task.GENERATE_VALUE, output=result)
@@ -1103,7 +1117,7 @@ async def generate_intent_steps_message(
    async def generate_intent_steps_message(
        self,
        events: List[dict],
-       llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+       llm: Optional[LLMModel] = None,
        kb: Optional[KnowledgeBase] = None,
    ):
        """Generate all three main Guardrails phases with a single LLM call.
@@ -1121,7 +1135,7 @@ async def generate_intent_steps_message(
            "Cannot generate user intent from this event type."
        )
        # Use the action-specific LLM if one is registered, else fall back to the main LLM.
-       generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+       generation_llm: Optional[LLMModel] = llm if llm else self.llm

        streaming_handler = streaming_handler_var.get()

@@ -1285,7 +1299,7 @@ async def generate_intent_steps_message(
            **llm_params,
            "temperature": self.config.lowest_temperature,
        }
-       result = await llm_call(generation_llm, prompt, llm_params=additional_params)
+       result = (await llm_call(generation_llm, prompt, llm_params=additional_params)).content

        # Parse the output using the associated parser
        result = self.llm_task_manager.parse_task_output(Task.GENERATE_INTENT_STEPS_MESSAGE, output=result)
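A side note on the hunk above: `additional_params` lists `"temperature"` after `**llm_params`, and in a Python dict literal the later key wins, so the deterministic temperature overrides any value the caller supplied. A quick illustration:

```python
# Later keys override earlier ones in a dict literal, so the forced
# temperature beats whatever the caller passed in llm_params.
llm_params = {"temperature": 0.9, "max_tokens": 64}
additional_params = {
    **llm_params,
    "temperature": 0.0,
}
assert additional_params == {"temperature": 0.0, "max_tokens": 64}
```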
@@ -1352,7 +1366,7 @@ async def generate_intent_steps_message(
        # We make this call with temperature 0 to have it as deterministic as possible.
        gen_options: Optional[GenerationOptions] = generation_options_var.get()
        llm_params = (gen_options and gen_options.llm_params) or {}
-       result = await llm_call(generation_llm, prompt, llm_params=llm_params)
+       result = (await llm_call(generation_llm, prompt, llm_params=llm_params)).content

        result = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
        text = result.strip()