 from dataclasses import asdict
 from functools import lru_cache
 from time import time
-from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Awaitable, Callable, Dict, List, Optional, cast
 
 from jinja2 import meta
 from jinja2.sandbox import SandboxedEnvironment
-from langchain_core.language_models import BaseChatModel, BaseLLM
 
 from nemoguardrails.actions.actions import ActionResult, action
 from nemoguardrails.actions.llm.utils import (
@@ ... @@
 from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
 from nemoguardrails.rails.llm.options import GenerationOptions
 from nemoguardrails.streaming import StreamingHandler
+from nemoguardrails.types import LLMModel
 from nemoguardrails.utils import (
     new_event_dict,
     new_uuid,
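A recurring change in this diff swaps the `Union[BaseLLM, BaseChatModel]` annotation for a single `LLMModel` alias imported from `nemoguardrails.types`. The alias definition itself is not part of this diff; a plausible sketch, assuming it is a plain type alias over the two LangChain base classes that were previously imported here:

```python
# Hypothetical sketch of nemoguardrails/types.py (assumption -- the
# module's actual contents are not shown in this diff):
from typing import Union

from langchain_core.language_models import BaseChatModel, BaseLLM

# One name for "anything callable as an LLM", replacing the
# Union[BaseLLM, BaseChatModel] previously spelled out at every call site.
LLMModel = Union[BaseLLM, BaseChatModel]
```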
@@ -79,7 +79,7 @@ class LLMGenerationActions:
     def __init__(
         self,
         config: RailsConfig,
-        llm: Optional[Union[BaseLLM, BaseChatModel]],
+        llm: Optional[LLMModel],
         llm_task_manager: LLMTaskManager,
         get_embedding_search_provider_instance: Callable[[Optional[EmbeddingSearchProvider]], EmbeddingsIndex],
         verbose: bool = False,
@@ -350,7 +350,7 @@ async def generate_user_intent(
         events: List[dict],
         context: dict,
         config: RailsConfig,
-        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+        llm: Optional[LLMModel] = None,
         kb: Optional[KnowledgeBase] = None,
     ):
         """Generate the canonical form for what the user said i.e. user intent."""
@@ -369,7 +369,7 @@ async def generate_user_intent(
 
         # Use action specific llm if registered else fallback to main llm
         # This can be None as some code-paths use embedding lookups rather than LLM generation
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         streaming_handler = streaming_handler_var.get()
 
@@ -424,11 +424,13 @@ async def generate_user_intent(
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_USER_INTENT.value))
 
         # We make this call with temperature 0 to have it as deterministic as possible.
-        result = await llm_call(
-            generation_llm,
-            prompt,
-            llm_params={"temperature": self.config.lowest_temperature},
-        )
+        result = (
+            await llm_call(
+                generation_llm,
+                prompt,
+                llm_params={"temperature": self.config.lowest_temperature},
+            )
+        ).content
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(Task.GENERATE_USER_INTENT, output=result)
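The other recurring change: every `llm_call(...)` result is now unwrapped with `.content`, which implies `llm_call` returns a message-like object (LangChain chat models return an `AIMessage`, whose text lives in `.content`) rather than a plain string. A minimal runnable sketch of the new call-site pattern, using LangChain's fake chat model; the `llm_call` below is an illustrative stand-in, not the real helper from `nemoguardrails.actions.llm.utils`:

```python
# Sketch only: assumes llm_call returns a message object whose .content
# field holds the generated text, as with LangChain's AIMessage.
import asyncio
from typing import Optional

from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import AIMessage


async def llm_call(llm, prompt: str, llm_params: Optional[dict] = None) -> AIMessage:
    # Stand-in helper: invoke the model and return the whole message
    # object instead of a plain string.
    return await llm.ainvoke(prompt, **(llm_params or {}))


async def main() -> None:
    llm = FakeListChatModel(responses=["express greeting"])
    # Call-site pattern from this diff: unwrap the generated text via .content.
    result = (await llm_call(llm, "User said: hello")).content
    print(result)  # express greeting


asyncio.run(main())
```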
@@ -501,12 +503,14 @@ async def generate_user_intent(
 
             streaming_handler: Optional[StreamingHandler] = streaming_handler_var.get()
 
-            text = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                llm_params=llm_params,
-            )
+            text = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    llm_params=llm_params,
+                )
+            ).content
             text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)
 
         else:
@@ -530,13 +534,15 @@ async def generate_user_intent(
             generation_options: Optional[GenerationOptions] = generation_options_var.get()
             llm_params = (generation_options and generation_options.llm_params) or {}
 
-            result = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                stop=["User:"],
-                llm_params=llm_params,
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    stop=["User:"],
+                    llm_params=llm_params,
+                )
+            ).content
 
             text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
             text = text.strip()
@@ -588,15 +594,15 @@ async def _search_flows_index(self, text, max_results):
         return final_results[0:max_results]
 
     @action(is_system_action=True)
-    async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] = None):
+    async def generate_next_steps(self, events: List[dict], llm: Optional[LLMModel] = None):
         """Generate the next step in the current conversation flow.
 
         Currently, only generates a next step after a user intent.
         """
         log.info("Phase 2 :: Generating next step ...")
 
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         # The last event should be the "StartInternalSystemAction" and the one before it the "UserIntent".
         event = get_last_user_intent_event(events)
@@ -633,11 +639,13 @@ async def generate_next_steps(self, events: List[dict], llm: Optional[BaseLLM] =
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_NEXT_STEPS.value))
 
         # We use temperature 0 for next step prediction as well
-        result = await llm_call(
-            generation_llm,
-            prompt,
-            llm_params={"temperature": self.config.lowest_temperature},
-        )
+        result = (
+            await llm_call(
+                generation_llm,
+                prompt,
+                llm_params={"temperature": self.config.lowest_temperature},
+            )
+        ).content
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(Task.GENERATE_NEXT_STEPS, output=result)
@@ -743,12 +751,12 @@ def _render_string(
         return template.render(render_context)
 
     @action(is_system_action=True)
-    async def generate_bot_message(self, events: List[dict], context: dict, llm: Optional[BaseLLM] = None):
+    async def generate_bot_message(self, events: List[dict], context: dict, llm: Optional[LLMModel] = None):
         """Generate a bot message based on the desired bot intent."""
         log.info("Phase 3 :: Generating bot message ...")
 
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         # The last event should be the "StartInternalSystemAction" and the one before it the "BotIntent".
         event = get_last_bot_intent_event(events)
@@ -894,12 +902,14 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt
 
             if not prompt:
                 raise RuntimeError("No prompt found to generate bot message")
-            result = await llm_call(
-                generation_llm,
-                prompt,
-                streaming_handler=streaming_handler,
-                llm_params=llm_params,
-            )
+            result = (
+                await llm_call(
+                    generation_llm,
+                    prompt,
+                    streaming_handler=streaming_handler,
+                    llm_params=llm_params,
+                )
+            ).content
 
             result = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
 
@@ -948,12 +958,14 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt
         generation_options: Optional[GenerationOptions] = generation_options_var.get()
         llm_params = (generation_options and generation_options.llm_params) or {}
 
-        result = await llm_call(
-            generation_llm,
-            prompt,
-            streaming_handler=streaming_handler,
-            llm_params=llm_params,
-        )
+        result = (
+            await llm_call(
+                generation_llm,
+                prompt,
+                streaming_handler=streaming_handler,
+                llm_params=llm_params,
+            )
+        ).content
 
         log.info(
             "--- :: LLM Bot Message Generation call took %.2f seconds",
@@ -1016,7 +1028,7 @@ async def generate_value(
         instructions: str,
         events: List[dict],
         var_name: Optional[str] = None,
-        llm: Optional[BaseLLM] = None,
+        llm: Optional[LLMModel] = None,
     ):
         """Generate a value in the context of the conversation.
 
@@ -1027,7 +1039,7 @@ async def generate_value(
         :param llm: Custom llm model to generate_value
         """
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         last_event = events[-1]
         assert last_event["type"] == "StartInternalSystemAction"
@@ -1062,11 +1074,13 @@ async def generate_value(
         # Initialize the LLMCallInfo object
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERATE_VALUE.value))
 
-        result = await llm_call(
-            generation_llm,
-            prompt,
-            llm_params={"temperature": self.config.lowest_temperature},
-        )
+        result = (
+            await llm_call(
+                generation_llm,
+                prompt,
+                llm_params={"temperature": self.config.lowest_temperature},
+            )
+        ).content
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(Task.GENERATE_VALUE, output=result)
@@ -1092,7 +1106,7 @@ async def generate_value(
     async def generate_intent_steps_message(
         self,
         events: List[dict],
-        llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
+        llm: Optional[LLMModel] = None,
         kb: Optional[KnowledgeBase] = None,
     ):
         """Generate all three main Guardrails phases with a single LLM call.
@@ -1110,7 +1124,7 @@ async def generate_intent_steps_message(
                 "Cannot generate user intent from this event type."
             )
         # Use action specific llm if registered else fallback to main llm
-        generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm
+        generation_llm: Optional[LLMModel] = llm if llm else self.llm
 
         streaming_handler = streaming_handler_var.get()
 
@@ -1274,7 +1288,7 @@ async def generate_intent_steps_message(
             **llm_params,
             "temperature": self.config.lowest_temperature,
         }
-        result = await llm_call(generation_llm, prompt, llm_params=additional_params)
+        result = (await llm_call(generation_llm, prompt, llm_params=additional_params)).content
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(Task.GENERATE_INTENT_STEPS_MESSAGE, output=result)
@@ -1341,7 +1355,7 @@ async def generate_intent_steps_message(
             # We make this call with temperature 0 to have it as deterministic as possible.
             gen_options: Optional[GenerationOptions] = generation_options_var.get()
             llm_params = (gen_options and gen_options.llm_params) or {}
-            result = await llm_call(generation_llm, prompt, llm_params=llm_params)
+            result = (await llm_call(generation_llm, prompt, llm_params=llm_params)).content
 
             result = self.llm_task_manager.parse_task_output(Task.GENERAL, output=result)
             text = result.strip()