Skip to content

Commit 3ce048a

Browse files
committed
add tests
1 parent 5b2a370 commit 3ce048a

2 files changed

Lines changed: 55 additions & 3 deletions

File tree

clarifai_datautils/text/prompt_factory.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
# This was taken from litellm/litellm_core_utils/prompt_templates/factory.py
1+
# This was taken from litellm
22

33
from enum import Enum
44
from typing import Any, Optional
55
import json
6-
from jinja2 import Environment
6+
from jinja2.sandbox import ImmutableSandboxedEnvironment
77

88
import requests
99

@@ -415,7 +415,7 @@ def raise_exception(message):
415415
raise Exception(f"Error message - {message}")
416416

417417
# Create a template object from the template text
418-
env = Environment()
418+
env = ImmutableSandboxedEnvironment()
419419
env.globals["raise_exception"] = raise_exception
420420
try:
421421
template = env.from_string(chat_template)
@@ -555,6 +555,11 @@ def prompt_factory(
555555
try:
556556
if "meta-llama/llama-2" in model and "chat" in model:
557557
return llama_2_chat_pt(messages=messages)
558+
elif "llama3" in model and "instruct" in model:
559+
return hf_chat_template(
560+
model="meta-llama/Meta-Llama-3-8B-Instruct",
561+
messages=messages,
562+
)
558563
elif (
559564
"tiiuae/falcon" in model
560565
): # Note: for the instruct models, it's best to use a User: .., Assistant:.. approach in your prompt template.

tests/test_prompt_factory.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# This was taken from litellm
2+
3+
from clarifai_datautils.text.prompt_factory import claude_2_1_pt, llama_2_chat_pt
4+
5+
6+
def test_codellama_prompt_format():
    """llama_2_chat_pt wraps a system + user conversation in Llama-2 [INST]/<<SYS>> markup."""
    conversation = [
        {"role": "system", "content": "You are a good bot"},
        {"role": "user", "content": "Hey, how's it going?"},
    ]
    # Same expected string as before, split with implicit concatenation for readability.
    expected = (
        "<s>[INST] <<SYS>>\nYou are a good bot\n<</SYS>>\n [/INST]\n"
        "[INST] Hey, how's it going? [/INST]\n"
    )
    assert llama_2_chat_pt(conversation) == expected
13+
14+
def test_claude_2_1_pt_formatting():
    """claude_2_1_pt renders conversations in the Claude 2.1 Human/Assistant text format."""
    cases = [
        # User only: a trailing "Assistant: " turn is appended for the model's reply.
        (
            [{"role": "user", "content": "Hello"}],
            "\n\nHuman: Hello\n\nAssistant: ",
        ),
        # System, user, then an assistant "pre-fill": the pre-fill is returned verbatim.
        (
            [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": 'Please return "Hello World" as a JSON object.'},
                {"role": "assistant", "content": "{"},
            ],
            'You are a helpful assistant.\n\nHuman: Please return "Hello World" as a JSON object.\n\nAssistant: {',
        ),
        # System followed directly by an assistant pre-fill: a blank Human turn is
        # inserted before the Assistant turn.
        (
            [
                {"role": "system", "content": "You are a storyteller."},
                {"role": "assistant", "content": "Once upon a time, there "},
            ],
            "You are a storyteller.\n\nHuman: \n\nAssistant: Once upon a time, there ",
        ),
        # System then user: reply turn appended as in the user-only case.
        (
            [
                {"role": "system", "content": "System reboot"},
                {"role": "user", "content": "Is everything okay?"},
            ],
            "System reboot\n\nHuman: Is everything okay?\n\nAssistant: ",
        ),
    ]
    for messages, expected in cases:
        assert claude_2_1_pt(messages) == expected

0 commit comments

Comments
 (0)