diff --git a/integrations/dspy/src/haystack_integrations/components/generators/dspy/prompt_template_adapter.py b/integrations/dspy/src/haystack_integrations/components/generators/dspy/prompt_template_adapter.py new file mode 100644 index 0000000000..672040c9b4 --- /dev/null +++ b/integrations/dspy/src/haystack_integrations/components/generators/dspy/prompt_template_adapter.py @@ -0,0 +1,26 @@ +import typing + + +class PromptTemplateAdapter: + """Makes DSPy base prompts better with custom model templates, for an enhanced text output.""" + + templates: typing.ClassVar = { + "mistral": "[INST] {prompt} [/INST]", + "llama": "[INST] {prompt} [/INST]", + "chatml": "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant", + "default": "{prompt}", + } + # Templates currently made for only 3 models as samples. This can be extended. + # (function written below + + def __init__(self, model_family: str = "default"): + if model_family not in self.templates: + self.template = self.templates["default"] + else: + self.template = self.templates[model_family] + + self.model_family = model_family + + def wrap(self, prompt: str) -> str: + """Wqrap the prompt according to the model family template.""" + return self.template.format(prompt=prompt) diff --git a/integrations/dspy/tests/test_prompt_template_adapter.py b/integrations/dspy/tests/test_prompt_template_adapter.py new file mode 100644 index 0000000000..ee829d3a6a --- /dev/null +++ b/integrations/dspy/tests/test_prompt_template_adapter.py @@ -0,0 +1,24 @@ +import sys + +sys.path.insert(0, "src/haystack_integrations/components/generators/dspy") +from prompt_template_adapter import PromptTemplateAdapter + + +def test_mistral_wraps_correctly(): + adapter = PromptTemplateAdapter(model_family="mistral") + assert adapter.wrap("Hello") == "[INST] Hello [/INST]" + + +def test_llama_wraps_correctly(): + adapter = PromptTemplateAdapter(model_family="llama") + assert adapter.wrap("Hello") == "[INST] Hello [/INST]" + + +def test_default_is_passthrough(): + adapter = PromptTemplateAdapter() + assert adapter.wrap("Hello") == "Hello" + + +def test_unknown_model_falls_back_to_default(): + adapter = PromptTemplateAdapter(model_family="unknown_model") + assert adapter.wrap("Hello") == "Hello"