Skip to content

Commit e13d813

Browse files
committed
test: add prompt_template_adapter
1 parent 45a4834 commit e13d813

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed
Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,26 @@
1+
import typing
2+
3+
14
class PromptTemplateAdapter:
2-
""" Makes DSPy base prompts better with custom model templates, for an enhanced text output. """
5+
"""Makes DSPy base prompts better with custom model templates, for an enhanced text output."""
36

4-
templates = {
7+
templates: typing.ClassVar = {
58
"mistral": "[INST] {prompt} [/INST]",
6-
"llama": "<s>[INST] {prompt} [/INST]",
7-
"chatml": "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant",
9+
"llama": "<s>[INST] {prompt} [/INST]",
10+
"chatml": "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant",
811
"default": "{prompt}",
912
}
1013
# Templates currently made for only 3 models as samples. This can be extended.
1114
# (function written below
1215

13-
1416
def __init__(self, model_family: str = "default"):
1517
if model_family not in self.templates:
1618
self.template = self.templates["default"]
1719
else:
1820
self.template = self.templates[model_family]
19-
20-
self.model_family = model_family
2121

22+
self.model_family = model_family
2223

2324
def wrap(self, prompt: str) -> str:
25+
"""Wqrap the prompt according to the model family template."""
2426
return self.template.format(prompt=prompt)
25-
26-
27-
# Makes addition of models and the templates possible.
28-
@classmethod
29-
def register_template(cls, name: str, template: str) -> None:
30-
"""Allow users to add custom model families."""
31-
cls.templates[name] = template
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import sys
2+
3+
sys.path.insert(0, "src/haystack_integrations/components/generators/dspy")
4+
from prompt_template_adapter import PromptTemplateAdapter
5+
6+
7+
def test_mistral_wraps_correctly():
8+
adapter = PromptTemplateAdapter(model_family="mistral")
9+
assert adapter.wrap("Hello") == "[INST] Hello [/INST]"
10+
11+
12+
def test_llama_wraps_correctly():
13+
adapter = PromptTemplateAdapter(model_family="llama")
14+
assert adapter.wrap("Hello") == "<s>[INST] Hello [/INST]"
15+
16+
17+
def test_default_is_passthrough():
18+
adapter = PromptTemplateAdapter()
19+
assert adapter.wrap("Hello") == "Hello"
20+
21+
22+
def test_unknown_model_falls_back_to_default():
23+
adapter = PromptTemplateAdapter(model_family="unknown_model")
24+
assert adapter.wrap("Hello") == "Hello"

0 commit comments

Comments
 (0)