From 47f29fc55e58f88b0344442f2414cce90d6ca8b5 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Fri, 20 Jun 2025 18:17:01 +0200
Subject: [PATCH 1/4] llmmessagesrouter - draft

---
 .../components/routers/llm_messages_router.py | 138 ++++++++++++++++++
 haystack/utils/hf.py                          |   2 +-
 2 files changed, 139 insertions(+), 1 deletion(-)
 create mode 100644 haystack/components/routers/llm_messages_router.py

diff --git a/haystack/components/routers/llm_messages_router.py b/haystack/components/routers/llm_messages_router.py
new file mode 100644
index 0000000000..af2decc0cb
--- /dev/null
+++ b/haystack/components/routers/llm_messages_router.py
@@ -0,0 +1,138 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from typing import Dict, List, Optional, Union
+
+from haystack import component
+from haystack.components.generators.chat.types import ChatGenerator
+from haystack.dataclasses import ChatMessage, ChatRole
+
+
+@component
+class LLMMessagesRouter:
+    """
+    Routes Chat Messages to different connections, using a generative Language Model to perform classification.
+
+    ### Usage example
+    ```python
+    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
+    from haystack.components.routers.llm_messages_router import LLMMessagesRouter
+    from haystack.dataclasses import ChatMessage
+
+    # initialize a Chat Generator with a generative model for moderation
+    chat_generator = HuggingFaceAPIChatGenerator(
+        api_type="serverless_inference_api",
+        api_params={"model": "meta-llama/Llama-Guard-4-12B", "provider": "groq"},
+    )
+
+    router = LLMMessagesRouter(chat_generator=chat_generator,
+                               output_names=["unsafe", "safe"],
+                               output_patterns=["unsafe", "safe"])
+
+
+    print(router.run([ChatMessage.from_user("How to rob a bank?")]))
+
+    # {
+    #     'router_text': 'unsafe\nS2',
+    #     'unsafe': [
+    #         ChatMessage(
+    #             _role=<ChatRole.USER: 'user'>,
+    #             _content=[TextContent(text='How to rob a bank?')],
+    #             _name=None,
+    #             _meta={}
+    #         )
+    #     ]
+    # }
+    """
+
+    def __init__(
+        self,
+        chat_generator: ChatGenerator,
+        output_names: List[str],
+        output_patterns: List[str],
+        system_prompt: Optional[str] = None,
+    ):
+        """
+        Initialize the LLMMessagesRouter component.
+
+        :param chat_generator: a ChatGenerator instance which represents the LLM.
+        :param output_names: list of names of the output connections of the router. These names can be used to connect
+            the router to other components.
+        :param output_patterns: list of regular expressions to be matched against the output of the LLM. Each of them
+            corresponds to an output name. Matching is executed in the order of the output_patterns list.
+        :param system_prompt: system prompt to customize the behavior of the LLM.
+
+        :return: a LLMMessagesRouter instance.
+
+        :raises ValueError: if output_names and output_patterns are not non-empty lists of the same length.
+        """
+        if not output_names or not output_patterns or len(output_names) != len(output_patterns):
+            raise ValueError("output_names and output_patterns must be non-empty lists of the same length")
+
+        self._chat_generator = chat_generator
+        self._system_prompt = system_prompt
+        self._output_names = output_names
+        self._output_patterns = output_patterns
+
+        self._compiled_patterns = [re.compile(pattern) for pattern in output_patterns]
+        self._is_warmed_up = False
+
+        component.set_output_types(
+            self, **{"router_text": str, **dict.fromkeys(output_names + ["unmatched"], List[ChatMessage])}
+        )
+
+    def warm_up(self):
+        """
+        Warm up the underlying LLM.
+        """
+        if not self._is_warmed_up:
+            if hasattr(self._chat_generator, "warm_up"):
+                self._chat_generator.warm_up()
+            self._is_warmed_up = True
+
+    def run(self, messages: List[ChatMessage]) -> Dict[str, Union[str, List[ChatMessage]]]:
+        """
+        Use the LLM to classify the messages and route them to the appropriate output connection.
+
+        :param messages: list of ChatMessages to route. Only user and assistant messages are supported.
+
+        :returns: A dictionary with the following keys:
+            - "router_text": the text output of the LLM (for debugging purposes).
+            - "unmatched": the messages that did not match any of the output patterns.
+            - other keys are the output names, and the values are the messages that matched the corresponding output
+              pattern.
+
+        :raises ValueError: if messages is an empty list.
+        :raises RuntimeError: if the component is not warmed up and the ChatGenerator has a warm_up method.
+        """
+        if not messages:
+            raise ValueError("messages must be a non-empty list.")
+        if not all(message.is_from(ChatRole.USER) or message.is_from(ChatRole.ASSISTANT) for message in messages):
+            msg = (
+                "messages must contain only user and assistant messages. To customize the behavior of the "
+                "chat_generator, you can use the `system_prompt` parameter."
+            )
+            raise ValueError(msg)
+
+        if not self._is_warmed_up and hasattr(self._chat_generator, "warm_up"):
+            raise RuntimeError("The component is not warmed up. Please call the `warm_up` method first.")
+
+        messages_for_inference = []
+        if self._system_prompt:
+            messages_for_inference.append(ChatMessage.from_system(self._system_prompt))
+        messages_for_inference.extend(messages)
+
+        llm_response = self._chat_generator.run(messages=messages_for_inference)["replies"][0].text
+
+        output = {"router_text": llm_response}
+
+        for output_name, pattern in zip(self._output_names, self._compiled_patterns):
+            if pattern.search(llm_response):
+                output[output_name] = messages
+                break
+        else:
+            output["unmatched"] = messages
+
+        return output
diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py
index bafe244e33..1da03b0926 100644
--- a/haystack/utils/hf.py
+++ b/haystack/utils/hf.py
@@ -241,7 +241,7 @@ def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Se
         allowed_model = model_info.pipeline_tag in ["sentence-similarity", "feature-extraction"]
         error_msg = f"Model {model_id} is not a embedding model. Please provide a embedding model."
     elif model_type == HFModelType.GENERATION:
-        allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
+        allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation", "image-text-to-text"]
         error_msg = f"Model {model_id} is not a text generation model. Please provide a text generation model."
     else:
         allowed_model = False
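A note on the matching semantics this first draft establishes: patterns are tried in list order and the first `re.search` hit wins, so more specific labels must precede labels that are substrings of them (for example, `"safe"` also matches inside `"unsafe"`). A standalone sketch of that behavior, with names and verdict strings mirroring the docstring example:

```python
import re

output_names = ["unsafe", "safe"]
compiled = [re.compile(p) for p in ["unsafe", "safe"]]

def route(llm_text: str) -> str:
    # Mirrors the loop in LLMMessagesRouter.run: the first pattern that
    # matches anywhere in the LLM reply decides the output edge;
    # no match at all falls through to "unmatched".
    for name, pattern in zip(output_names, compiled):
        if pattern.search(llm_text):
            return name
    return "unmatched"

assert route("unsafe\nS2") == "unsafe"  # Llama Guard-style verdict plus category
assert route("safe") == "safe"
assert route("No verdict produced") == "unmatched"
```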
From 4623b9c08050775b0e97ee30e3cf05244ed476c9 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Mon, 23 Jun 2025 11:39:03 +0200
Subject: [PATCH 2/4] serde methods

---
 .../components/routers/llm_messages_router.py | 122 ++++++++++++------
 1 file changed, 79 insertions(+), 43 deletions(-)

diff --git a/haystack/components/routers/llm_messages_router.py b/haystack/components/routers/llm_messages_router.py
index af2decc0cb..d89abe2f64 100644
--- a/haystack/components/routers/llm_messages_router.py
+++ b/haystack/components/routers/llm_messages_router.py
@@ -3,11 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import re
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
-from haystack import component
+from haystack import component, default_from_dict, default_to_dict
 from haystack.components.generators.chat.types import ChatGenerator
+from haystack.core.serialization import component_to_dict
 from haystack.dataclasses import ChatMessage, ChatRole
+from haystack.utils import deserialize_chatgenerator_inplace
 
 
 @component
@@ -15,36 +17,39 @@ class LLMMessagesRouter:
     """
     Routes Chat Messages to different connections, using a generative Language Model to perform classification.
 
+    This component can be used with general-purpose LLMs and with specialized LLMs for moderation like Llama Guard.
+
     ### Usage example
     ```python
-    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
-    from haystack.components.routers.llm_messages_router import LLMMessagesRouter
-    from haystack.dataclasses import ChatMessage
-
-    # initialize a Chat Generator with a generative model for moderation
-    chat_generator = HuggingFaceAPIChatGenerator(
-        api_type="serverless_inference_api",
-        api_params={"model": "meta-llama/Llama-Guard-4-12B", "provider": "groq"},
-    )
-
-    router = LLMMessagesRouter(chat_generator=chat_generator,
-                               output_names=["unsafe", "safe"],
-                               output_patterns=["unsafe", "safe"])
-
-
-    print(router.run([ChatMessage.from_user("How to rob a bank?")]))
-
-    # {
-    #     'router_text': 'unsafe\nS2',
-    #     'unsafe': [
-    #         ChatMessage(
-    #             _role=<ChatRole.USER: 'user'>,
-    #             _content=[TextContent(text='How to rob a bank?')],
-    #             _name=None,
-    #             _meta={}
-    #         )
-    #     ]
-    # }
+    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
+    from haystack.components.routers.llm_messages_router import LLMMessagesRouter
+    from haystack.dataclasses import ChatMessage
+
+    # initialize a Chat Generator with a generative model for moderation
+    chat_generator = HuggingFaceAPIChatGenerator(
+        api_type="serverless_inference_api",
+        api_params={"model": "meta-llama/Llama-Guard-4-12B", "provider": "groq"},
+    )
+
+    router = LLMMessagesRouter(chat_generator=chat_generator,
+                               output_names=["unsafe", "safe"],
+                               output_patterns=["unsafe", "safe"])
+
+
+    print(router.run([ChatMessage.from_user("How to rob a bank?")]))
+
+    # {
+    #     'router_text': 'unsafe\nS2',
+    #     'unsafe': [
+    #         ChatMessage(
+    #             _role=<ChatRole.USER: 'user'>,
+    #             _content=[TextContent(text='How to rob a bank?')],
+    #             _name=None,
+    #             _meta={}
+    #         )
+    #     ]
+    # }
+    ```
     """
 
     def __init__(
@@ -58,11 +63,13 @@ def __init__(
         """
         Initialize the LLMMessagesRouter component.
 
         :param chat_generator: a ChatGenerator instance which represents the LLM.
-        :param output_names: list of names of the output connections of the router. These names can be used to connect
-            the router to other components.
-        :param output_patterns: list of regular expressions to be matched against the output of the LLM. Each of them
-            corresponds to an output name. Matching is executed in the order of the output_patterns list.
-        :param system_prompt: system prompt to customize the behavior of the LLM.
+        :param output_names: list of output connection names. These can be used to connect the router to other
+            components.
+        :param output_patterns: list of regular expressions to be matched against the output of the LLM. Each pattern
+            corresponds to an output name. Patterns are evaluated in order.
+            When using moderation models, refer to the model card to understand the expected outputs.
+        :param system_prompt: optional system prompt to customize the behavior of the LLM.
+            For moderation models, refer to the model card for supported customization options.
 
         :return: a LLMMessagesRouter instance.
 
         :raises ValueError: if output_names and output_patterns are not non-empty lists of the same length.
@@ -94,15 +101,14 @@ def warm_up(self):
 
     def run(self, messages: List[ChatMessage]) -> Dict[str, Union[str, List[ChatMessage]]]:
         """
-        Use the LLM to classify the messages and route them to the appropriate output connection.
+        Classify the messages based on LLM output and route them to the appropriate output connection.
 
-        :param messages: list of ChatMessages to route. Only user and assistant messages are supported.
+        :param messages: list of ChatMessages to be routed. Only user and assistant messages are supported.
 
         :returns: A dictionary with the following keys:
-            - "router_text": the text output of the LLM (for debugging purposes).
+            - "llm_text": the text output of the LLM, useful for debugging.
+            - output names: each contains the list of messages that matched the corresponding pattern.
             - "unmatched": the messages that did not match any of the output patterns.
-            - other keys are the output names, and the values are the messages that matched the corresponding output
-              pattern.
 
         :raises ValueError: if messages is an empty list.
         :raises RuntimeError: if the component is not warmed up and the ChatGenerator has a warm_up method.
@@ -124,15 +130,45 @@ def run(self, messages: List[ChatMessage]) -> Dict[str, Union[str, List[ChatMess
             messages_for_inference.append(ChatMessage.from_system(self._system_prompt))
         messages_for_inference.extend(messages)
 
-        llm_response = self._chat_generator.run(messages=messages_for_inference)["replies"][0].text
+        llm_text = self._chat_generator.run(messages=messages_for_inference)["replies"][0].text
 
-        output = {"router_text": llm_response}
+        output = {"router_text": llm_text}
 
         for output_name, pattern in zip(self._output_names, self._compiled_patterns):
-            if pattern.search(llm_response):
+            if pattern.search(llm_text):
                 output[output_name] = messages
                 break
         else:
             output["unmatched"] = messages
 
         return output
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+
+        :returns:
+            The serialized component as a dictionary.
+        """
+        return default_to_dict(
+            self,
+            chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"),
+            output_names=self._output_names,
+            output_patterns=self._output_patterns,
+            system_prompt=self._system_prompt,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "LLMMessagesRouter":
+        """
+        Deserialize this component from a dictionary.
+
+        :param data:
+            The dictionary representation of this component.
+        :returns:
+            The deserialized component instance.
+        """
+        if data["init_parameters"].get("chat_generator"):
+            deserialize_chatgenerator_inplace(data["init_parameters"], key="chat_generator")
+
+        return default_from_dict(cls, data)
From 57bbdab0e466e82a6f994e59684c7a2bb928c15d Mon Sep 17 00:00:00 2001
From: anakin87
Date: Mon, 23 Jun 2025 18:07:42 +0200
Subject: [PATCH 3/4] refinements, tests and release note

---
 docs/pydoc/config/routers_api.yml             |   1 +
 haystack/components/routers/__init__.py       |   2 +
 .../components/routers/llm_messages_router.py |  20 +-
 .../llm-messages-router-bc0ee4e1d3a707a0.yaml |  26 +++
 .../routers/test_llm_messages_router.py       | 213 ++++++++++++++++++
 5 files changed, 251 insertions(+), 11 deletions(-)
 create mode 100644 releasenotes/notes/llm-messages-router-bc0ee4e1d3a707a0.yaml
 create mode 100644 test/components/routers/test_llm_messages_router.py

diff --git a/docs/pydoc/config/routers_api.yml b/docs/pydoc/config/routers_api.yml
index 126d08d7ea..834ad8ac4c 100644
--- a/docs/pydoc/config/routers_api.yml
+++ b/docs/pydoc/config/routers_api.yml
@@ -5,6 +5,7 @@ loaders:
       [
         "conditional_router",
         "file_type_router",
+        "llm_messages_router",
         "metadata_router",
         "text_language_router",
         "transformers_text_router",
diff --git a/haystack/components/routers/__init__.py b/haystack/components/routers/__init__.py
index 8722134994..4886da54ae 100644
--- a/haystack/components/routers/__init__.py
+++ b/haystack/components/routers/__init__.py
@@ -10,6 +10,7 @@
 _import_structure = {
     "conditional_router": ["ConditionalRouter"],
     "file_type_router": ["FileTypeRouter"],
+    "llm_messages_router": ["LLMMessagesRouter"],
     "metadata_router": ["MetadataRouter"],
     "text_language_router": ["TextLanguageRouter"],
     "transformers_text_router": ["TransformersTextRouter"],
@@ -19,6 +20,7 @@
 if TYPE_CHECKING:
     from .conditional_router import ConditionalRouter as ConditionalRouter
     from .file_type_router import FileTypeRouter as FileTypeRouter
+    from .llm_messages_router import LLMMessagesRouter as LLMMessagesRouter
     from .metadata_router import MetadataRouter as MetadataRouter
     from .text_language_router import TextLanguageRouter as TextLanguageRouter
     from .transformers_text_router import TransformersTextRouter as TransformersTextRouter
diff --git a/haystack/components/routers/llm_messages_router.py b/haystack/components/routers/llm_messages_router.py
index d89abe2f64..34439b218a 100644
--- a/haystack/components/routers/llm_messages_router.py
+++ b/haystack/components/routers/llm_messages_router.py
@@ -39,7 +39,7 @@ class LLMMessagesRouter:
     print(router.run([ChatMessage.from_user("How to rob a bank?")]))
 
     # {
-    #     'router_text': 'unsafe\nS2',
+    #     'chat_generator_text': 'unsafe\nS2',
     #     'unsafe': [
     #         ChatMessage(
    #             _role=<ChatRole.USER: 'user'>,
@@ -71,12 +71,10 @@ def __init__(
         :param system_prompt: optional system prompt to customize the behavior of the LLM.
             For moderation models, refer to the model card for supported customization options.
 
-        :return: a LLMMessagesRouter instance.
-
         :raises ValueError: if output_names and output_patterns are not non-empty lists of the same length.
""" if not output_names or not output_patterns or len(output_names) != len(output_patterns): - raise ValueError("output_names and output_patterns must be non-empty lists of the same length") + raise ValueError("`output_names` and `output_patterns` must be non-empty lists of the same length") self._chat_generator = chat_generator self._system_prompt = system_prompt @@ -87,7 +85,7 @@ def __init__( self._is_warmed_up = False component.set_output_types( - self, **{"router_text": str, **dict.fromkeys(output_names + ["unmatched"], List[ChatMessage])} + self, **{"chat_generator_text": str, **dict.fromkeys(output_names + ["unmatched"], List[ChatMessage])} ) def warm_up(self): @@ -106,18 +104,18 @@ def run(self, messages: List[ChatMessage]) -> Dict[str, Union[str, List[ChatMess :param messages: list of ChatMessages to be routed. Only user and assistant messages are supported. :returns: A dictionary with the following keys: - - "llm_text": the text output of the LLM, useful for debugging. + - "chat_generator_text": the text output of the LLM, useful for debugging. - output names: each contains the list of messages that matched the corresponding pattern. - "unmatched": the messages that did not match any of the output patterns. - :raises ValueError: if messages is an empty list. + :raises ValueError: if messages is an empty list or contains messages with unsupported roles. :raises RuntimeError: if the component is not warmed up and the ChatGenerator has a warm_up method. """ if not messages: raise ValueError("messages must be a non-empty list.") if not all(message.is_from(ChatRole.USER) or message.is_from(ChatRole.ASSISTANT) for message in messages): msg = ( - "messages must contain only user and assistant messages. To customize the behavior of the " + "`messages` must contain only user and assistant messages. To customize the behavior of the " "chat_generator, you can use the `system_prompt` parameter." ) raise ValueError(msg) @@ -130,12 +128,12 @@ def run(self, messages: List[ChatMessage]) -> Dict[str, Union[str, List[ChatMess messages_for_inference.append(ChatMessage.from_system(self._system_prompt)) messages_for_inference.extend(messages) - llm_text = self._chat_generator.run(messages=messages_for_inference)["replies"][0].text + chat_generator_text = self._chat_generator.run(messages=messages_for_inference)["replies"][0].text - output = {"router_text": llm_text} + output = {"chat_generator_text": chat_generator_text} for output_name, pattern in zip(self._output_names, self._compiled_patterns): - if pattern.search(llm_text): + if pattern.search(chat_generator_text): output[output_name] = messages break else: diff --git a/releasenotes/notes/llm-messages-router-bc0ee4e1d3a707a0.yaml b/releasenotes/notes/llm-messages-router-bc0ee4e1d3a707a0.yaml new file mode 100644 index 0000000000..3966726f0f --- /dev/null +++ b/releasenotes/notes/llm-messages-router-bc0ee4e1d3a707a0.yaml @@ -0,0 +1,26 @@ +--- +features: + - | + We introduced the `LLMMessagesRouter` component, that routes Chat Messages to different connections, using a + generative Language Model to perform classification. + + This component can be used with general-purpose LLMs and with specialized LLMs for moderation like Llama Guard. 
+
+    Usage example:
+    ```python
+    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
+    from haystack.components.routers.llm_messages_router import LLMMessagesRouter
+    from haystack.dataclasses import ChatMessage
+
+    # initialize a Chat Generator with a generative model for moderation
+    chat_generator = HuggingFaceAPIChatGenerator(
+        api_type="serverless_inference_api",
+        api_params={"model": "meta-llama/Llama-Guard-4-12B", "provider": "groq"},
+    )
+
+    router = LLMMessagesRouter(chat_generator=chat_generator,
+                               output_names=["unsafe", "safe"],
+                               output_patterns=["unsafe", "safe"])
+
+    print(router.run([ChatMessage.from_user("How to rob a bank?")]))
+    ```
diff --git a/test/components/routers/test_llm_messages_router.py b/test/components/routers/test_llm_messages_router.py
new file mode 100644
index 0000000000..abd71012e3
--- /dev/null
+++ b/test/components/routers/test_llm_messages_router.py
@@ -0,0 +1,213 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+import pytest
+from typing import Any, Dict, List
+from unittest.mock import Mock
+from haystack.components.routers.llm_messages_router import LLMMessagesRouter
+from haystack.dataclasses import ChatMessage
+from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.core.serialization import default_to_dict, default_from_dict
+import os
+
+
+class MockChatGenerator:
+    def __init__(self, return_text: str = "safe"):
+        self.return_text = return_text
+
+    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
+        return {"replies": [ChatMessage.from_assistant(self.return_text)]}
+
+    def to_dict(self) -> Dict[str, Any]:
+        return default_to_dict(self, return_text=self.return_text)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGenerator":
+        return default_from_dict(cls, data)
+
+
+class TestLLMMessagesRouter:
+    def test_init(self):
+        system_prompt = "Classify the messages as safe or unsafe."
+        chat_generator = MockChatGenerator()
+
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator,
+            system_prompt=system_prompt,
+            output_names=["safe", "unsafe"],
+            output_patterns=["safe", "unsafe"],
+        )
+
+        assert router._chat_generator is chat_generator
+        assert router._system_prompt == system_prompt
+        assert router._output_names == ["safe", "unsafe"]
+        assert router._output_patterns == ["safe", "unsafe"]
+        assert router._compiled_patterns == [re.compile(pattern) for pattern in ["safe", "unsafe"]]
+        assert router._is_warmed_up is False
+
+    def test_init_errors(self):
+        chat_generator = MockChatGenerator()
+
+        with pytest.raises(ValueError):
+            LLMMessagesRouter(chat_generator=chat_generator, output_names=[], output_patterns=["pattern1", "pattern2"])
+
+        with pytest.raises(ValueError):
+            LLMMessagesRouter(chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=[])
+
+        with pytest.raises(ValueError):
+            LLMMessagesRouter(
+                chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=["pattern1"]
+            )
+
+    def test_warm_up_with_unwarmable_chat_generator(self):
+        chat_generator = MockChatGenerator()
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+        router.warm_up()
+        assert router._is_warmed_up is True
+
+    def test_warm_up_with_warmable_chat_generator(self):
+        chat_generator = Mock()
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+        router.warm_up()
+        assert router._is_warmed_up is True
+        assert router._chat_generator.warm_up.call_count == 1
+
+    def test_run_input_errors(self):
+        router = LLMMessagesRouter(
+            chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+
+        with pytest.raises(ValueError):
+            router.run([])
+
+        with pytest.raises(ValueError):
+            router.run([ChatMessage.from_system("You are a helpful assistant.")])
+
+    def test_run_no_warm_up_with_unwarmable_chat_generator(self):
+        router = LLMMessagesRouter(
+            chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+
+        router.run([ChatMessage.from_user("Hello")])
+
+    def test_run_no_warm_up_with_warmable_chat_generator(self):
+        chat_generator = Mock()
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+
+        with pytest.raises(RuntimeError):
+            router.run([ChatMessage.from_user("Hello")])
+
+    def test_run(self):
+        router = LLMMessagesRouter(
+            chat_generator=MockChatGenerator(return_text="safe"),
+            output_names=["safe", "unsafe"],
+            output_patterns=["safe", "unsafe"],
+        )
+
+        messages = [ChatMessage.from_user("Hello")]
+        result = router.run(messages)
+
+        assert result["chat_generator_text"] == "safe"
+        assert result["safe"] == messages
+        assert "unsafe" not in result
+        assert "unmatched" not in result
+
+    def test_run_with_system_prompt(self):
+        chat_generator = Mock()
+        chat_generator.run.return_value = {"replies": [ChatMessage.from_assistant("safe")]}
+
+        system_prompt = "Classify the messages as safe or unsafe."
+
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator,
+            output_names=["safe", "unsafe"],
+            output_patterns=["safe", "unsafe"],
+            system_prompt=system_prompt,
+        )
+        router.warm_up()
+
+        messages = [ChatMessage.from_user("Hello")]
+        router.run(messages)
+
+        chat_generator.run.assert_called_once_with(messages=[ChatMessage.from_system(system_prompt)] + messages)
+
+    def test_run_unmatched_output(self):
+        router = LLMMessagesRouter(
+            chat_generator=MockChatGenerator(return_text="irrelevant"),
+            output_names=["safe", "unsafe"],
+            output_patterns=["safe", "unsafe"],
+        )
+
+        messages = [ChatMessage.from_user("Hello")]
+        result = router.run(messages)
+
+        assert result["chat_generator_text"] == "irrelevant"
+        assert result["unmatched"] == messages
+        assert "safe" not in result
+        assert "unsafe" not in result
+
+    def test_to_dict(self):
+        chat_generator = MockChatGenerator(return_text="safe")
+
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+
+        result = router.to_dict()
+
+        assert result["type"] == "haystack.components.routers.llm_messages_router.LLMMessagesRouter"
+        assert result["init_parameters"]["chat_generator"] == chat_generator.to_dict()
+        assert result["init_parameters"]["output_names"] == ["safe", "unsafe"]
+        assert result["init_parameters"]["output_patterns"] == ["safe", "unsafe"]
+        assert result["init_parameters"]["system_prompt"] is None
+
+    def test_from_dict(self):
+        chat_generator = MockChatGenerator(return_text="safe")
+
+        data = {
+            "type": "haystack.components.routers.llm_messages_router.LLMMessagesRouter",
+            "init_parameters": {
+                "chat_generator": chat_generator.to_dict(),
+                "output_names": ["safe", "unsafe"],
+                "output_patterns": ["safe", "unsafe"],
+                "system_prompt": None,
+            },
+        }
+
+        router = LLMMessagesRouter.from_dict(data)
+
+        assert router._chat_generator.to_dict() == chat_generator.to_dict()
+        assert router._output_names == ["safe", "unsafe"]
+        assert router._output_patterns == ["safe", "unsafe"]
+        assert router._system_prompt is None
+
+    @pytest.mark.integration
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    def test_live_run(self):
+        system_prompt = "Classify the messages into safe or unsafe. Respond with the label only, no other text."
+        router = LLMMessagesRouter(
+            chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
+            system_prompt=system_prompt,
+            output_names=["safe", "unsafe"],
+            output_patterns=[r"(?i)safe", r"(?i)unsafe"],
+        )
+
+        messages = [ChatMessage.from_user("Hello")]
+        result = router.run(messages)
+        print(result)
+
+        assert result["safe"] == messages
+        assert result["chat_generator_text"].lower() == "safe"
+        assert "unsafe" not in result
+        assert "unmatched" not in result
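With the tests in place, the component can also be exercised inside a pipeline. A sketch of a moderation gate in front of a downstream generator, assuming an `OPENAI_API_KEY` in the environment (component names, model, and system prompt are illustrative; note that `unsafe` is listed before `safe` so the substring `safe` cannot shadow an `unsafe` verdict):

```python
from haystack import Pipeline
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.routers.llm_messages_router import LLMMessagesRouter
from haystack.dataclasses import ChatMessage

pipe = Pipeline()
pipe.add_component(
    "moderation_router",
    LLMMessagesRouter(
        chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
        system_prompt="Classify the messages into safe or unsafe. Respond with the label only, no other text.",
        output_names=["unsafe", "safe"],
        output_patterns=[r"(?i)unsafe", r"(?i)safe"],
    ),
)
pipe.add_component("llm", OpenAIChatGenerator(model="gpt-4.1-mini"))

# only messages classified as safe reach the downstream generator;
# unsafe messages leave the pipeline on the "unsafe" edge instead
pipe.connect("moderation_router.safe", "llm.messages")

result = pipe.run({"moderation_router": {"messages": [ChatMessage.from_user("Hello!")]}})
print(result)
```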
From c2a575bba64d80abc0e6f7e3464677bcd95d3a64 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Tue, 24 Jun 2025 14:07:08 +0200
Subject: [PATCH 4/4] Apply suggestions from code review

Co-authored-by: Daria Fokina
---
 .../components/routers/llm_messages_router.py | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/haystack/components/routers/llm_messages_router.py b/haystack/components/routers/llm_messages_router.py
index 34439b218a..8af86be39f 100644
--- a/haystack/components/routers/llm_messages_router.py
+++ b/haystack/components/routers/llm_messages_router.py
@@ -15,7 +15,7 @@
 @component
 class LLMMessagesRouter:
     """
-    Routes Chat Messages to different connections, using a generative Language Model to perform classification.
+    Routes Chat Messages to different connections using a generative Language Model to perform classification.
 
     This component can be used with general-purpose LLMs and with specialized LLMs for moderation like Llama Guard.
 
@@ -62,16 +62,16 @@ def __init__(
         """
         Initialize the LLMMessagesRouter component.
 
-        :param chat_generator: a ChatGenerator instance which represents the LLM.
-        :param output_names: list of output connection names. These can be used to connect the router to other
+        :param chat_generator: A ChatGenerator instance which represents the LLM.
+        :param output_names: A list of output connection names. These can be used to connect the router to other
             components.
-        :param output_patterns: list of regular expressions to be matched against the output of the LLM. Each pattern
+        :param output_patterns: A list of regular expressions to be matched against the output of the LLM. Each pattern
             corresponds to an output name. Patterns are evaluated in order.
             When using moderation models, refer to the model card to understand the expected outputs.
-        :param system_prompt: optional system prompt to customize the behavior of the LLM.
+        :param system_prompt: An optional system prompt to customize the behavior of the LLM.
             For moderation models, refer to the model card for supported customization options.
 
-        :raises ValueError: if output_names and output_patterns are not non-empty lists of the same length.
+        :raises ValueError: If output_names and output_patterns are not non-empty lists of the same length.
         """
         if not output_names or not output_patterns or len(output_names) != len(output_patterns):
             raise ValueError("`output_names` and `output_patterns` must be non-empty lists of the same length")
@@ -101,22 +101,22 @@ def run(self, messages: List[ChatMessage]) -> Dict[str, Union[str, List[ChatMess
         """
         Classify the messages based on LLM output and route them to the appropriate output connection.
 
-        :param messages: list of ChatMessages to be routed. Only user and assistant messages are supported.
+        :param messages: A list of ChatMessages to be routed. Only user and assistant messages are supported.
 
         :returns: A dictionary with the following keys:
-            - "chat_generator_text": the text output of the LLM, useful for debugging.
+            - "chat_generator_text": The text output of the LLM, useful for debugging.
-            - output names: each contains the list of messages that matched the corresponding pattern.
-            - "unmatched": the messages that did not match any of the output patterns.
+            - output names: Each contains the list of messages that matched the corresponding pattern.
+            - "unmatched": The messages that did not match any of the output patterns.
 
-        :raises ValueError: if messages is an empty list or contains messages with unsupported roles.
-        :raises RuntimeError: if the component is not warmed up and the ChatGenerator has a warm_up method.
+        :raises ValueError: If messages is an empty list or contains messages with unsupported roles.
+        :raises RuntimeError: If the component is not warmed up and the ChatGenerator has a warm_up method.
         """
         if not messages:
-            raise ValueError("messages must be a non-empty list.")
+            raise ValueError("`messages` must be a non-empty list.")
         if not all(message.is_from(ChatRole.USER) or message.is_from(ChatRole.ASSISTANT) for message in messages):
             msg = (
                 "`messages` must contain only user and assistant messages. To customize the behavior of the "
-                "chat_generator, you can use the `system_prompt` parameter."
+                "`chat_generator`, you can use the `system_prompt` parameter."
            )
             raise ValueError(msg)
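The reviewed component is not limited to safety labels: with a custom `system_prompt`, the same mechanism routes by arbitrary categories. A sketch under that assumption (labels, prompt, and model are illustrative):

```python
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.routers.llm_messages_router import LLMMessagesRouter
from haystack.dataclasses import ChatMessage

# Route user questions by topic instead of safety. The labels the LLM may
# emit are constrained by the system prompt, then matched by the patterns.
router = LLMMessagesRouter(
    chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"),
    system_prompt=(
        "Classify the user message as either 'technical' or 'billing'. "
        "Respond with the label only, no other text."
    ),
    output_names=["technical", "billing"],
    output_patterns=[r"(?i)technical", r"(?i)billing"],
)

result = router.run([ChatMessage.from_user("My invoice shows the wrong amount.")])
# Expected: the message appears under "billing"; the result also carries
# "chat_generator_text" with the raw label, which is useful for debugging.
print(result)
```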