Skip to content

Commit 40df9bc

Browse files
Update tests
1 parent e11de23 commit 40df9bc

18 files changed

Lines changed: 582 additions & 124 deletions

nemoguardrails/actions/llm/generation.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -887,9 +887,22 @@ async def generate_bot_message(
887887
# of course, it does not work when passed as context in `run_output_rails_in_streaming`
888888
# streaming_handler is set when stream_async method is used
889889

890+
has_streaming_handler = streaming_handler is not None
891+
has_output_streaming = (
892+
self.config.rails.output.streaming
893+
and self.config.rails.output.streaming.enabled
894+
)
895+
log.info(
896+
f"generate_bot_message: streaming_handler={has_streaming_handler}, output.streaming={has_output_streaming}"
897+
)
890898
# if streaming_handler and len(self.config.rails.output.flows) > 0:
891-
if streaming_handler and self.config.rails.output.streaming.enabled:
899+
if streaming_handler and has_output_streaming:
900+
log.info("Setting skip_output_rails = True")
892901
context_updates["skip_output_rails"] = True
902+
else:
903+
log.info(
904+
f"NOT setting skip_output_rails: streaming_handler={has_streaming_handler}, output.streaming={has_output_streaming}"
905+
)
893906

894907
if bot_intent in self.config.bot_messages:
895908
# Choose a message randomly from self.config.bot_messages[bot_message]
@@ -970,7 +983,9 @@ async def generate_bot_message(
970983
new_event_dict("BotMessage", text=text)
971984
)
972985

973-
return ActionResult(events=output_events)
986+
return ActionResult(
987+
events=output_events, context_updates=context_updates
988+
)
974989
else:
975990
if streaming_handler:
976991
await streaming_handler.push_chunk(
@@ -987,7 +1002,9 @@ async def generate_bot_message(
9871002
)
9881003
output_events.append(bot_message_event)
9891004

990-
return ActionResult(events=output_events)
1005+
return ActionResult(
1006+
events=output_events, context_updates=context_updates
1007+
)
9911008

9921009
# If we are in passthrough mode, we just use the input for prompting
9931010
if self.config.passthrough:

nemoguardrails/colang/v1_0/runtime/flows.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""A simplified modeling of the CoFlows engine."""
1717

18+
import logging
1819
import uuid
1920
from dataclasses import dataclass, field
2021
from enum import Enum
@@ -25,6 +26,8 @@
2526
from nemoguardrails.colang.v1_0.runtime.sliding import slide
2627
from nemoguardrails.utils import new_event_dict, new_uuid
2728

29+
log = logging.getLogger(__name__)
30+
2831

2932
@dataclass
3033
class FlowConfig:
@@ -356,7 +359,15 @@ def compute_next_state(state: State, event: dict) -> State:
356359
if event["type"] == "ContextUpdate":
357360
# TODO: add support to also remove keys from the context.
358361
# maybe with a special context key e.g. "__remove__": ["key1", "key2"]
362+
if "skip_output_rails" in event["data"]:
363+
log.info(
364+
f"ContextUpdate setting skip_output_rails={event['data']['skip_output_rails']}"
365+
)
359366
state.context.update(event["data"])
367+
if "skip_output_rails" in state.context:
368+
log.info(
369+
f"After update, context skip_output_rails={state.context.get('skip_output_rails')}"
370+
)
360371
state.context_updates = {}
361372
state.next_step = None
362373
return state
@@ -415,6 +426,12 @@ def compute_next_state(state: State, event: dict) -> State:
415426
_record_next_step(new_state, flow_state, flow_config, priority_modifier=0.9)
416427
continue
417428

429+
# Debug logging for BotMessage event and skip_output_rails
430+
if event["type"] == "BotMessage":
431+
log.info(
432+
f"BotMessage event processing for flow '{flow_config.id}', skip_output_rails in context: {flow_state.context.get('skip_output_rails', 'NOT SET')}"
433+
)
434+
418435
# If we're at a branching point, we look at all individual heads.
419436
matching_head = None
420437

nemoguardrails/colang/v1_0/runtime/runtime.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,9 @@ async def _run_output_rails_in_parallel_streaming(
523523
flows_with_params: Dictionary mapping flow_id to {"action_name": str, "params": dict}
524524
events: The events list for context
525525
"""
526+
# Compute context from events so actions can access bot_message
527+
context = compute_context(events)
528+
526529
tasks = []
527530

528531
async def run_single_rail(flow_id: str, action_info: dict) -> tuple:
@@ -532,8 +535,11 @@ async def run_single_rail(flow_id: str, action_info: dict) -> tuple:
532535
action_name = action_info["action_name"]
533536
params = action_info["params"]
534537

538+
# Merge context into params so actions have access to bot_message
539+
params_with_context = {**params, "context": context}
540+
535541
result_tuple = await self.action_dispatcher.execute_action(
536-
action_name, params
542+
action_name, params_with_context
537543
)
538544
result, status = result_tuple
539545

@@ -731,10 +737,19 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
731737
return_events = []
732738
context_updates = {}
733739

740+
if action_name == "generate_bot_message":
741+
log.info(
742+
f"DEBUG: generate_bot_message returned, isinstance(ActionResult)={isinstance(result, ActionResult)}"
743+
)
744+
734745
if isinstance(result, ActionResult):
735746
return_value = result.return_value
736747
return_events = result.events
737748
context_updates.update(result.context_updates)
749+
if action_name == "generate_bot_message":
750+
log.info(
751+
f"generate_bot_message ActionResult: context_updates={context_updates}, skip_output_rails={'skip_output_rails' in context_updates}"
752+
)
738753

739754
# If we have an action result key, we also record the update.
740755
if action_result_key:

nemoguardrails/rails/llm/event_translator.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
"""Event translator for converting between messages and events."""
217

318
import logging
@@ -24,7 +39,7 @@ def __init__(self, config: RailsConfig):
2439
self.events_history_cache: Dict[str, List[dict]] = {}
2540

2641
def messages_to_events(
27-
self, messages: List[dict], state: Optional[Any] = None
42+
self, messages: List[dict[str, Any]], state: Optional[Any] = None
2843
) -> List[dict]:
2944
"""Convert messages to events.
3045
@@ -156,7 +171,7 @@ def _messages_to_events_v1(self, messages: List[dict]) -> List[dict]:
156171
return events
157172

158173
def _messages_to_events_v2(
159-
self, messages: List[dict], state: Optional[Any]
174+
self, messages: List[dict[str, Any]], state: Optional[Any]
160175
) -> List[dict]:
161176
"""Convert messages to events for Colang 2.x.
162177

0 commit comments

Comments
 (0)