diff --git a/docs/pydoc/config/routers_api.yml b/docs/pydoc/config/routers_api.yml
index 126d08d7ea..834ad8ac4c 100644
--- a/docs/pydoc/config/routers_api.yml
+++ b/docs/pydoc/config/routers_api.yml
@@ -5,6 +5,7 @@ loaders:
       [
         "conditional_router",
         "file_type_router",
+        "llm_messages_router",
         "metadata_router",
         "text_language_router",
         "transformers_text_router",
diff --git a/haystack/components/routers/__init__.py b/haystack/components/routers/__init__.py
index 8722134994..4886da54ae 100644
--- a/haystack/components/routers/__init__.py
+++ b/haystack/components/routers/__init__.py
@@ -10,6 +10,7 @@ _import_structure = {
     "conditional_router": ["ConditionalRouter"],
     "file_type_router": ["FileTypeRouter"],
+    "llm_messages_router": ["LLMMessagesRouter"],
     "metadata_router": ["MetadataRouter"],
     "text_language_router": ["TextLanguageRouter"],
     "transformers_text_router": ["TransformersTextRouter"],
@@ -19,6 +20,7 @@ if TYPE_CHECKING:
     from .conditional_router import ConditionalRouter as ConditionalRouter
     from .file_type_router import FileTypeRouter as FileTypeRouter
+    from .llm_messages_router import LLMMessagesRouter as LLMMessagesRouter
     from .metadata_router import MetadataRouter as MetadataRouter
     from .text_language_router import TextLanguageRouter as TextLanguageRouter
     from .transformers_text_router import TransformersTextRouter as TransformersTextRouter
diff --git a/haystack/components/routers/llm_messages_router.py b/haystack/components/routers/llm_messages_router.py
new file mode 100644
index 0000000000..8af86be39f
--- /dev/null
+++ b/haystack/components/routers/llm_messages_router.py
@@ -0,0 +1,172 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from typing import Any, Dict, List, Optional, Union
+
+from haystack import component, default_from_dict, default_to_dict
+from haystack.components.generators.chat.types import ChatGenerator
+from haystack.core.serialization import component_to_dict
+from haystack.dataclasses import ChatMessage, ChatRole
+from haystack.utils import deserialize_chatgenerator_inplace
+
+
+@component
+class LLMMessagesRouter:
+    """
+    Routes Chat Messages to different connections using a generative Language Model to perform classification.
+
+    This component can be used with general-purpose LLMs and with specialized LLMs for moderation like Llama Guard.
+
+    ### Usage example
+    ```python
+    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
+    from haystack.components.routers.llm_messages_router import LLMMessagesRouter
+    from haystack.dataclasses import ChatMessage
+
+    # initialize a Chat Generator with a generative model for moderation
+    chat_generator = HuggingFaceAPIChatGenerator(
+        api_type="serverless_inference_api",
+        api_params={"model": "meta-llama/Llama-Guard-4-12B", "provider": "groq"},
+    )
+
+    router = LLMMessagesRouter(
+        chat_generator=chat_generator,
+        output_names=["unsafe", "safe"],
+        output_patterns=["unsafe", "safe"],
+    )
+
+    print(router.run([ChatMessage.from_user("How to rob a bank?")]))
+
+    # {
+    #     'chat_generator_text': 'unsafe\nS2',
+    #     'unsafe': [
+    #         ChatMessage(
+    #             _role=<ChatRole.USER: 'user'>,
+    #             _content=[TextContent(text='How to rob a bank?')],
+    #             _name=None,
+    #             _meta={}
+    #         )
+    #     ]
+    # }
+    ```
+    """
+
+    def __init__(
+        self,
+        chat_generator: ChatGenerator,
+        output_names: List[str],
+        output_patterns: List[str],
+        system_prompt: Optional[str] = None,
+    ):
+        """
+        Initialize the LLMMessagesRouter component.
+
+        :param chat_generator: A ChatGenerator instance which represents the LLM.
+        :param output_names: A list of output connection names. These can be used to connect the router to other
+            components.
+        :param output_patterns: A list of regular expressions to be matched against the output of the LLM. Each pattern
+            corresponds to an output name. Patterns are evaluated in order.
+            When using moderation models, refer to the model card to understand the expected outputs.
+        :param system_prompt: An optional system prompt to customize the behavior of the LLM.
+            For moderation models, refer to the model card for supported customization options.
+
+        :raises ValueError: If `output_names` and `output_patterns` are not non-empty lists of the same length.
+        """
+        if not output_names or not output_patterns or len(output_names) != len(output_patterns):
+            raise ValueError("`output_names` and `output_patterns` must be non-empty lists of the same length")
+
+        self._chat_generator = chat_generator
+        self._system_prompt = system_prompt
+        self._output_names = output_names
+        self._output_patterns = output_patterns
+
+        self._compiled_patterns = [re.compile(pattern) for pattern in output_patterns]
+        self._is_warmed_up = False
+
+        component.set_output_types(
+            self, **{"chat_generator_text": str, **dict.fromkeys(output_names + ["unmatched"], List[ChatMessage])}
+        )
+
+    def warm_up(self):
+        """
+        Warm up the underlying LLM.
+        """
+        if not self._is_warmed_up:
+            if hasattr(self._chat_generator, "warm_up"):
+                self._chat_generator.warm_up()
+            self._is_warmed_up = True
+
+    def run(self, messages: List[ChatMessage]) -> Dict[str, Union[str, List[ChatMessage]]]:
+        """
+        Classify the messages based on LLM output and route them to the appropriate output connection.
+
+        :param messages: A list of ChatMessages to be routed. Only user and assistant messages are supported.
+
+        :returns: A dictionary with the following keys:
+            - "chat_generator_text": The text output of the LLM, useful for debugging.
+            - "output_names": Each output name contains the list of messages that matched the corresponding pattern.
+            - "unmatched": The messages that did not match any of the output patterns.
+
+        :raises ValueError: If `messages` is an empty list or contains messages with unsupported roles.
+        :raises RuntimeError: If the component is not warmed up and the ChatGenerator has a `warm_up` method.
+        """
+        if not messages:
+            raise ValueError("`messages` must be a non-empty list.")
+        if not all(message.is_from(ChatRole.USER) or message.is_from(ChatRole.ASSISTANT) for message in messages):
+            msg = (
+                "`messages` must contain only user and assistant messages. To customize the behavior of the "
+                "`chat_generator`, you can use the `system_prompt` parameter."
+            )
+            raise ValueError(msg)
+
+        if not self._is_warmed_up and hasattr(self._chat_generator, "warm_up"):
+            raise RuntimeError("The component is not warmed up. Please call the `warm_up` method first.")
+
+        messages_for_inference = []
+        if self._system_prompt:
+            messages_for_inference.append(ChatMessage.from_system(self._system_prompt))
+        messages_for_inference.extend(messages)
+
+        chat_generator_text = self._chat_generator.run(messages=messages_for_inference)["replies"][0].text
+
+        output: Dict[str, Union[str, List[ChatMessage]]] = {"chat_generator_text": chat_generator_text}
+
+        # for/else: the else branch runs only if no pattern matched any of the LLM output
+        for output_name, pattern in zip(self._output_names, self._compiled_patterns):
+            if pattern.search(chat_generator_text):
+                output[output_name] = messages
+                break
+        else:
+            output["unmatched"] = messages
+
+        return output
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+
+        :returns:
+            The serialized component as a dictionary.
+        """
+        return default_to_dict(
+            self,
+            chat_generator=component_to_dict(obj=self._chat_generator, name="chat_generator"),
+            output_names=self._output_names,
+            output_patterns=self._output_patterns,
+            system_prompt=self._system_prompt,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "LLMMessagesRouter":
+        """
+        Deserialize this component from a dictionary.
+
+        :param data:
+            The dictionary representation of this component.
+        :returns:
+            The deserialized component instance.
+        """
+        if data["init_parameters"].get("chat_generator"):
+            deserialize_chatgenerator_inplace(data["init_parameters"], key="chat_generator")
+
+        return default_from_dict(cls, data)
diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py
index bafe244e33..1da03b0926 100644
--- a/haystack/utils/hf.py
+++ b/haystack/utils/hf.py
@@ -241,7 +241,7 @@ def check_valid_model(model_id: str, model_type: HFModelType, token: Optional[Se
         allowed_model = model_info.pipeline_tag in ["sentence-similarity", "feature-extraction"]
         error_msg = f"Model {model_id} is not a embedding model. Please provide a embedding model."
     elif model_type == HFModelType.GENERATION:
-        allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation"]
+        allowed_model = model_info.pipeline_tag in ["text-generation", "text2text-generation", "image-text-to-text"]
         error_msg = f"Model {model_id} is not a text generation model. Please provide a text generation model."
     else:
         allowed_model = False
diff --git a/releasenotes/notes/llm-messages-router-bc0ee4e1d3a707a0.yaml b/releasenotes/notes/llm-messages-router-bc0ee4e1d3a707a0.yaml
new file mode 100644
index 0000000000..3966726f0f
--- /dev/null
+++ b/releasenotes/notes/llm-messages-router-bc0ee4e1d3a707a0.yaml
@@ -0,0 +1,26 @@
+---
+features:
+  - |
+    We introduced the `LLMMessagesRouter` component, which routes Chat Messages to different connections using a
+    generative Language Model to perform classification.
+
+    This component can be used with general-purpose LLMs and with specialized LLMs for moderation like Llama Guard.
+
+    Usage example:
+    ```python
+    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
+    from haystack.components.routers.llm_messages_router import LLMMessagesRouter
+    from haystack.dataclasses import ChatMessage
+
+    # initialize a Chat Generator with a generative model for moderation
+    chat_generator = HuggingFaceAPIChatGenerator(
+        api_type="serverless_inference_api",
+        api_params={"model": "meta-llama/Llama-Guard-4-12B", "provider": "groq"},
+    )
+
+    router = LLMMessagesRouter(
+        chat_generator=chat_generator,
+        output_names=["unsafe", "safe"],
+        output_patterns=["unsafe", "safe"],
+    )
+
+    print(router.run([ChatMessage.from_user("How to rob a bank?")]))
+    ```
diff --git a/test/components/routers/test_llm_messages_router.py b/test/components/routers/test_llm_messages_router.py
new file mode 100644
index 0000000000..abd71012e3
--- /dev/null
+++ b/test/components/routers/test_llm_messages_router.py
@@ -0,0 +1,213 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import re
+from typing import Any, Dict, List
+from unittest.mock import Mock
+
+import pytest
+
+from haystack.components.generators.chat.openai import OpenAIChatGenerator
+from haystack.components.routers.llm_messages_router import LLMMessagesRouter
+from haystack.core.serialization import default_from_dict, default_to_dict
+from haystack.dataclasses import ChatMessage
+
+
+class MockChatGenerator:
+    def __init__(self, return_text: str = "safe"):
+        self.return_text = return_text
+
+    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
+        return {"replies": [ChatMessage.from_assistant(self.return_text)]}
+
+    def to_dict(self) -> Dict[str, Any]:
+        return default_to_dict(self, return_text=self.return_text)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "MockChatGenerator":
+        return default_from_dict(cls, data)
+
+
+class TestLLMMessagesRouter:
+    def test_init(self):
+        system_prompt = "Classify the messages as safe or unsafe."
+        chat_generator = MockChatGenerator()
+
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator,
+            system_prompt=system_prompt,
+            output_names=["safe", "unsafe"],
+            output_patterns=["safe", "unsafe"],
+        )
+
+        assert router._chat_generator is chat_generator
+        assert router._system_prompt == system_prompt
+        assert router._output_names == ["safe", "unsafe"]
+        assert router._output_patterns == ["safe", "unsafe"]
+        assert router._compiled_patterns == [re.compile(pattern) for pattern in ["safe", "unsafe"]]
+        assert router._is_warmed_up is False
+
+    def test_init_errors(self):
+        chat_generator = MockChatGenerator()
+
+        with pytest.raises(ValueError):
+            LLMMessagesRouter(chat_generator=chat_generator, output_names=[], output_patterns=["pattern1", "pattern2"])
+
+        with pytest.raises(ValueError):
+            LLMMessagesRouter(chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=[])
+
+        with pytest.raises(ValueError):
+            LLMMessagesRouter(
+                chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=["pattern1"]
+            )
+
+    def test_warm_up_with_unwarmable_chat_generator(self):
+        chat_generator = MockChatGenerator()
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+        router.warm_up()
+        assert router._is_warmed_up is True
+
+    def test_warm_up_with_warmable_chat_generator(self):
+        chat_generator = Mock()
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+        router.warm_up()
+        assert router._is_warmed_up is True
+        assert router._chat_generator.warm_up.call_count == 1
+
+    def test_run_input_errors(self):
+        router = LLMMessagesRouter(
+            chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+
+        with pytest.raises(ValueError):
+            router.run([])
+
+        with pytest.raises(ValueError):
+            router.run([ChatMessage.from_system("You are a helpful assistant.")])
+
+    def test_run_no_warm_up_with_unwarmable_chat_generator(self):
+        router = LLMMessagesRouter(
+            chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+
+        router.run([ChatMessage.from_user("Hello")])
+
+    def test_run_no_warm_up_with_warmable_chat_generator(self):
+        chat_generator = Mock()
+        router = LLMMessagesRouter(
+            chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"]
+        )
+
+        with pytest.raises(RuntimeError):
+            router.run([ChatMessage.from_user("Hello")])
+
+    def test_run(self):
+        router = LLMMessagesRouter(
+            chat_generator=MockChatGenerator(return_text="safe"),
+            output_names=["safe", "unsafe"],
+            output_patterns=["safe", "unsafe"],
+        )
+
+        messages = [ChatMessage.from_user("Hello")]
+        result = router.run(messages)
+
+        assert result["chat_generator_text"] == "safe"
+        assert result["safe"] == messages
+        assert "unsafe" not in result
+        assert "unmatched" not in result
+
+    def test_run_with_system_prompt(self):
+        chat_generator = Mock()
+        chat_generator.run.return_value = {"replies": [ChatMessage.from_assistant("safe")]}
+
+        system_prompt = "Classify the messages as safe or unsafe."
+ + router = LLMMessagesRouter( + chat_generator=chat_generator, + output_names=["safe", "unsafe"], + output_patterns=["safe", "unsafe"], + system_prompt=system_prompt, + ) + router.warm_up() + + messages = [ChatMessage.from_user("Hello")] + router.run(messages) + + chat_generator.run.assert_called_once_with(messages=[ChatMessage.from_system(system_prompt)] + messages) + + def test_run_unmatched_output(self): + router = LLMMessagesRouter( + chat_generator=MockChatGenerator(return_text="irrelevant"), + output_names=["safe", "unsafe"], + output_patterns=["safe", "unsafe"], + ) + + messages = [ChatMessage.from_user("Hello")] + result = router.run(messages) + + assert result["chat_generator_text"] == "irrelevant" + assert result["unmatched"] == messages + assert "safe" not in result + assert "unsafe" not in result + + def test_to_dict(self): + chat_generator = MockChatGenerator(return_text="safe") + + router = LLMMessagesRouter( + chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"] + ) + + result = router.to_dict() + + assert result["type"] == "haystack.components.routers.llm_messages_router.LLMMessagesRouter" + assert result["init_parameters"]["chat_generator"] == chat_generator.to_dict() + assert result["init_parameters"]["output_names"] == ["safe", "unsafe"] + assert result["init_parameters"]["output_patterns"] == ["safe", "unsafe"] + assert result["init_parameters"]["system_prompt"] is None + + def test_from_dict(self): + chat_generator = MockChatGenerator(return_text="safe") + + data = { + "type": "haystack.components.routers.llm_messages_router.LLMMessagesRouter", + "init_parameters": { + "chat_generator": chat_generator.to_dict(), + "output_names": ["safe", "unsafe"], + "output_patterns": ["safe", "unsafe"], + "system_prompt": None, + }, + } + + router = LLMMessagesRouter.from_dict(data) + + assert router._chat_generator.to_dict() == chat_generator.to_dict() + assert router._output_names == ["safe", "unsafe"] + assert router._output_patterns == ["safe", "unsafe"] + assert router._system_prompt is None + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + def test_live_run(self): + system_prompt = "Classify the messages into safe or unsafe. Respond with the label only, no other text." + router = LLMMessagesRouter( + chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"), + system_prompt=system_prompt, + output_names=["safe", "unsafe"], + output_patterns=[r"(?i)safe", r"(?i)unsafe"], + ) + + messages = [ChatMessage.from_user("Hello")] + result = router.run(messages) + print(result) + + assert result["safe"] == messages + assert result["chat_generator_text"].lower() == "safe" + assert "unsafe" not in result + assert "unmatched" not in result
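For reviewers who want to try the new component in context, here is a minimal pipeline sketch. It is not part of this PR: `Pipeline`, `OpenAIChatGenerator`, and `connect` are existing Haystack APIs, while the model name, system prompt, and component names are illustrative assumptions.

```python
# A minimal sketch of wiring LLMMessagesRouter into a Pipeline, assuming
# OPENAI_API_KEY is set. Model name, prompt, and component names are
# illustrative assumptions, not part of this PR.
from haystack import Pipeline
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.routers.llm_messages_router import LLMMessagesRouter
from haystack.dataclasses import ChatMessage

router = LLMMessagesRouter(
    chat_generator=OpenAIChatGenerator(model="gpt-4o-mini"),
    system_prompt="Classify the messages as safe or unsafe. Respond with one word only.",
    # "unsafe" is checked first: patterns are evaluated in order, and a bare
    # re.search for "safe" would also match inside the string "unsafe".
    output_names=["unsafe", "safe"],
    output_patterns=[r"(?i)unsafe", r"(?i)safe"],
)

pipeline = Pipeline()
pipeline.add_component("moderator", router)
pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini"))
# Only messages routed to "safe" reach the downstream LLM; messages routed to
# "unsafe" or "unmatched" surface directly as pipeline outputs.
pipeline.connect("moderator.safe", "llm.messages")

result = pipeline.run({"moderator": {"messages": [ChatMessage.from_user("Hello!")]}})
print(result)
```

The `["unsafe", "safe"]` ordering here mirrors the PR's own docstring and release-note examples, which rely on the same in-order pattern evaluation.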