Skip to content

Commit 4945318

Browse files
hassiebprhighs
andauthored
feat(prompts): add util for variable name extraction (#1046)
Co-authored-by: Roberto Montalti <37136851+rhighs@users.noreply.github.com>
1 parent 2324237 commit 4945318

File tree

3 files changed

+209
-105
lines changed

3 files changed

+209
-105
lines changed

langfuse/model.py

Lines changed: 92 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""@private"""
22

33
from abc import ABC, abstractmethod
4-
from typing import Optional, TypedDict, Any, Dict, Union, List
4+
from typing import Optional, TypedDict, Any, Dict, Union, List, Tuple
55
import re
66

77
from langfuse.api.resources.commons.types.dataset import (
@@ -54,6 +54,72 @@ class ChatMessageDict(TypedDict):
5454
content: str
5555

5656

57+
class TemplateParser:
58+
OPENING = "{{"
59+
CLOSING = "}}"
60+
61+
@staticmethod
62+
def _parse_next_variable(
63+
content: str, start_idx: int
64+
) -> Optional[Tuple[str, int, int]]:
65+
"""Returns (variable_name, start_pos, end_pos) or None if no variable found"""
66+
var_start = content.find(TemplateParser.OPENING, start_idx)
67+
if var_start == -1:
68+
return None
69+
70+
var_end = content.find(TemplateParser.CLOSING, var_start)
71+
if var_end == -1:
72+
return None
73+
74+
variable_name = content[
75+
var_start + len(TemplateParser.OPENING) : var_end
76+
].strip()
77+
return (variable_name, var_start, var_end + len(TemplateParser.CLOSING))
78+
79+
@staticmethod
80+
def find_variable_names(content: str) -> List[str]:
81+
names = []
82+
curr_idx = 0
83+
84+
while curr_idx < len(content):
85+
result = TemplateParser._parse_next_variable(content, curr_idx)
86+
if not result:
87+
break
88+
names.append(result[0])
89+
curr_idx = result[2]
90+
91+
return names
92+
93+
@staticmethod
94+
def compile_template(content: str, data: Optional[Dict[str, Any]] = None) -> str:
95+
if data is None:
96+
return content
97+
98+
result_list = []
99+
curr_idx = 0
100+
101+
while curr_idx < len(content):
102+
result = TemplateParser._parse_next_variable(content, curr_idx)
103+
104+
if not result:
105+
result_list.append(content[curr_idx:])
106+
break
107+
108+
variable_name, var_start, var_end = result
109+
result_list.append(content[curr_idx:var_start])
110+
111+
if variable_name in data:
112+
result_list.append(
113+
str(data[variable_name]) if data[variable_name] is not None else ""
114+
)
115+
else:
116+
result_list.append(content[var_start:var_end])
117+
118+
curr_idx = var_end
119+
120+
return "".join(result_list)
121+
122+
57123
class BasePromptClient(ABC):
58124
name: str
59125
version: int
@@ -73,6 +139,11 @@ def __init__(self, prompt: Prompt, is_fallback: bool = False):
73139
def compile(self, **kwargs) -> Union[str, List[ChatMessage]]:
74140
pass
75141

142+
@property
143+
@abstractmethod
144+
def variables(self) -> List[str]:
145+
pass
146+
76147
@abstractmethod
77148
def __eq__(self, other):
78149
pass
@@ -85,55 +156,19 @@ def get_langchain_prompt(self):
85156
def _get_langchain_prompt_string(content: str):
86157
return re.sub(r"{{\s*(\w+)\s*}}", r"{\g<1>}", content)
87158

88-
@staticmethod
89-
def _compile_template_string(content: str, data: Dict[str, Any] = {}) -> str:
90-
opening = "{{"
91-
closing = "}}"
92-
93-
result_list = []
94-
curr_idx = 0
95-
96-
while curr_idx < len(content):
97-
# Find the next opening tag
98-
var_start = content.find(opening, curr_idx)
99-
100-
if var_start == -1:
101-
result_list.append(content[curr_idx:])
102-
break
103-
104-
# Find the next closing tag
105-
var_end = content.find(closing, var_start)
106-
107-
if var_end == -1:
108-
result_list.append(content[curr_idx:])
109-
break
110-
111-
# Append the content before the variable
112-
result_list.append(content[curr_idx:var_start])
113-
114-
# Extract the variable name
115-
variable_name = content[var_start + len(opening) : var_end].strip()
116-
117-
# Append the variable value
118-
if variable_name in data:
119-
result_list.append(
120-
str(data[variable_name]) if data[variable_name] is not None else ""
121-
)
122-
else:
123-
result_list.append(content[var_start : var_end + len(closing)])
124-
125-
curr_idx = var_end + len(closing)
126-
127-
return "".join(result_list)
128-
129159

130160
class TextPromptClient(BasePromptClient):
131161
def __init__(self, prompt: Prompt_Text, is_fallback: bool = False):
132162
super().__init__(prompt, is_fallback)
133163
self.prompt = prompt.prompt
134164

135165
def compile(self, **kwargs) -> str:
136-
return self._compile_template_string(self.prompt, kwargs)
166+
return TemplateParser.compile_template(self.prompt, kwargs)
167+
168+
@property
169+
def variables(self) -> List[str]:
170+
"""Return all the variable names in the prompt template."""
171+
return TemplateParser.find_variable_names(self.prompt)
137172

138173
def __eq__(self, other):
139174
if isinstance(self, other.__class__):
@@ -160,7 +195,7 @@ def get_langchain_prompt(self, **kwargs) -> str:
160195
str: The string that can be plugged into Langchain's PromptTemplate.
161196
"""
162197
prompt = (
163-
self._compile_template_string(self.prompt, kwargs)
198+
TemplateParser.compile_template(self.prompt, kwargs)
164199
if kwargs
165200
else self.prompt
166201
)
@@ -178,12 +213,23 @@ def __init__(self, prompt: Prompt_Chat, is_fallback: bool = False):
178213
def compile(self, **kwargs) -> List[ChatMessageDict]:
179214
return [
180215
ChatMessageDict(
181-
content=self._compile_template_string(chat_message["content"], kwargs),
216+
content=TemplateParser.compile_template(
217+
chat_message["content"], kwargs
218+
),
182219
role=chat_message["role"],
183220
)
184221
for chat_message in self.prompt
185222
]
186223

224+
@property
225+
def variables(self) -> List[str]:
226+
"""Return all the variable names in the chat prompt template."""
227+
return [
228+
variable
229+
for chat_message in self.prompt
230+
for variable in TemplateParser.find_variable_names(chat_message["content"])
231+
]
232+
187233
def __eq__(self, other):
188234
if isinstance(self, other.__class__):
189235
return (
@@ -215,7 +261,7 @@ def get_langchain_prompt(self, **kwargs):
215261
(
216262
msg["role"],
217263
self._get_langchain_prompt_string(
218-
self._compile_template_string(msg["content"], kwargs)
264+
TemplateParser.compile_template(msg["content"], kwargs)
219265
if kwargs
220266
else msg["content"]
221267
),

tests/test_prompt.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,3 +994,100 @@ def test_do_not_link_observation_if_fallback():
994994

995995
assert len(trace.observations) == 1
996996
assert trace.observations[0].prompt_id is None
997+
998+
999+
def test_variable_names_on_content_with_variable_names():
1000+
langfuse = Langfuse()
1001+
1002+
prompt_client = langfuse.create_prompt(
1003+
name="test_variable_names_1",
1004+
prompt="test prompt with var names {{ var1 }} {{ var2 }}",
1005+
is_active=True,
1006+
type="text",
1007+
)
1008+
1009+
second_prompt_client = langfuse.get_prompt("test_variable_names_1")
1010+
1011+
assert prompt_client.name == second_prompt_client.name
1012+
assert prompt_client.version == second_prompt_client.version
1013+
assert prompt_client.prompt == second_prompt_client.prompt
1014+
assert prompt_client.labels == ["production", "latest"]
1015+
1016+
var_names = second_prompt_client.variables
1017+
1018+
assert var_names == ["var1", "var2"]
1019+
1020+
1021+
def test_variable_names_on_content_with_no_variable_names():
1022+
langfuse = Langfuse()
1023+
1024+
prompt_client = langfuse.create_prompt(
1025+
name="test_variable_names_2",
1026+
prompt="test prompt with no var names",
1027+
is_active=True,
1028+
type="text",
1029+
)
1030+
1031+
second_prompt_client = langfuse.get_prompt("test_variable_names_2")
1032+
1033+
assert prompt_client.name == second_prompt_client.name
1034+
assert prompt_client.version == second_prompt_client.version
1035+
assert prompt_client.prompt == second_prompt_client.prompt
1036+
assert prompt_client.labels == ["production", "latest"]
1037+
1038+
var_names = second_prompt_client.variables
1039+
1040+
assert var_names == []
1041+
1042+
1043+
def test_variable_names_on_content_with_variable_names_chat_messages():
1044+
langfuse = Langfuse()
1045+
1046+
prompt_client = langfuse.create_prompt(
1047+
name="test_variable_names_3",
1048+
prompt=[
1049+
{
1050+
"role": "system",
1051+
"content": "test prompt with template vars {{ var1 }} {{ var2 }}",
1052+
},
1053+
{"role": "user", "content": "test prompt 2 with template vars {{ var3 }}"},
1054+
],
1055+
is_active=True,
1056+
type="chat",
1057+
)
1058+
1059+
second_prompt_client = langfuse.get_prompt("test_variable_names_3")
1060+
1061+
assert prompt_client.name == second_prompt_client.name
1062+
assert prompt_client.version == second_prompt_client.version
1063+
assert prompt_client.prompt == second_prompt_client.prompt
1064+
assert prompt_client.labels == ["production", "latest"]
1065+
1066+
var_names = second_prompt_client.variables
1067+
1068+
assert var_names == ["var1", "var2", "var3"]
1069+
1070+
1071+
def test_variable_names_on_content_with_no_variable_names_chat_messages():
1072+
langfuse = Langfuse()
1073+
1074+
prompt_client = langfuse.create_prompt(
1075+
name="test_variable_names_4",
1076+
prompt=[
1077+
{"role": "system", "content": "test prompt with no template vars"},
1078+
{"role": "user", "content": "test prompt 2 with no template vars"},
1079+
],
1080+
is_active=True,
1081+
type="chat",
1082+
)
1083+
1084+
second_prompt_client = langfuse.get_prompt("test_variable_names_4")
1085+
1086+
assert prompt_client.name == second_prompt_client.name
1087+
assert prompt_client.version == second_prompt_client.version
1088+
assert prompt_client.prompt == second_prompt_client.prompt
1089+
assert prompt_client.labels == ["production", "latest"]
1090+
1091+
var_names = second_prompt_client.variables
1092+
1093+
assert var_names == []

0 commit comments

Comments
 (0)