11"""@private"""
22
33from abc import ABC , abstractmethod
4- from typing import Optional , TypedDict , Any , Dict , Union , List
4+ from typing import Optional , TypedDict , Any , Dict , Union , List , Tuple
55import re
66
77from langfuse .api .resources .commons .types .dataset import (
@@ -54,6 +54,72 @@ class ChatMessageDict(TypedDict):
5454 content : str
5555
5656
57+ class TemplateParser :
58+ OPENING = "{{"
59+ CLOSING = "}}"
60+
61+ @staticmethod
62+ def _parse_next_variable (
63+ content : str , start_idx : int
64+ ) -> Optional [Tuple [str , int , int ]]:
65+ """Returns (variable_name, start_pos, end_pos) or None if no variable found"""
66+ var_start = content .find (TemplateParser .OPENING , start_idx )
67+ if var_start == - 1 :
68+ return None
69+
70+ var_end = content .find (TemplateParser .CLOSING , var_start )
71+ if var_end == - 1 :
72+ return None
73+
74+ variable_name = content [
75+ var_start + len (TemplateParser .OPENING ) : var_end
76+ ].strip ()
77+ return (variable_name , var_start , var_end + len (TemplateParser .CLOSING ))
78+
79+ @staticmethod
80+ def find_variable_names (content : str ) -> List [str ]:
81+ names = []
82+ curr_idx = 0
83+
84+ while curr_idx < len (content ):
85+ result = TemplateParser ._parse_next_variable (content , curr_idx )
86+ if not result :
87+ break
88+ names .append (result [0 ])
89+ curr_idx = result [2 ]
90+
91+ return names
92+
93+ @staticmethod
94+ def compile_template (content : str , data : Optional [Dict [str , Any ]] = None ) -> str :
95+ if data is None :
96+ return content
97+
98+ result_list = []
99+ curr_idx = 0
100+
101+ while curr_idx < len (content ):
102+ result = TemplateParser ._parse_next_variable (content , curr_idx )
103+
104+ if not result :
105+ result_list .append (content [curr_idx :])
106+ break
107+
108+ variable_name , var_start , var_end = result
109+ result_list .append (content [curr_idx :var_start ])
110+
111+ if variable_name in data :
112+ result_list .append (
113+ str (data [variable_name ]) if data [variable_name ] is not None else ""
114+ )
115+ else :
116+ result_list .append (content [var_start :var_end ])
117+
118+ curr_idx = var_end
119+
120+ return "" .join (result_list )
121+
122+
57123class BasePromptClient (ABC ):
58124 name : str
59125 version : int
@@ -73,6 +139,11 @@ def __init__(self, prompt: Prompt, is_fallback: bool = False):
73139 def compile (self , ** kwargs ) -> Union [str , List [ChatMessage ]]:
74140 pass
75141
142+ @property
143+ @abstractmethod
144+ def variables (self ) -> List [str ]:
145+ pass
146+
76147 @abstractmethod
77148 def __eq__ (self , other ):
78149 pass
@@ -85,55 +156,19 @@ def get_langchain_prompt(self):
85156 def _get_langchain_prompt_string (content : str ):
86157 return re .sub (r"{{\s*(\w+)\s*}}" , r"{\g<1>}" , content )
87158
88- @staticmethod
89- def _compile_template_string (content : str , data : Dict [str , Any ] = {}) -> str :
90- opening = "{{"
91- closing = "}}"
92-
93- result_list = []
94- curr_idx = 0
95-
96- while curr_idx < len (content ):
97- # Find the next opening tag
98- var_start = content .find (opening , curr_idx )
99-
100- if var_start == - 1 :
101- result_list .append (content [curr_idx :])
102- break
103-
104- # Find the next closing tag
105- var_end = content .find (closing , var_start )
106-
107- if var_end == - 1 :
108- result_list .append (content [curr_idx :])
109- break
110-
111- # Append the content before the variable
112- result_list .append (content [curr_idx :var_start ])
113-
114- # Extract the variable name
115- variable_name = content [var_start + len (opening ) : var_end ].strip ()
116-
117- # Append the variable value
118- if variable_name in data :
119- result_list .append (
120- str (data [variable_name ]) if data [variable_name ] is not None else ""
121- )
122- else :
123- result_list .append (content [var_start : var_end + len (closing )])
124-
125- curr_idx = var_end + len (closing )
126-
127- return "" .join (result_list )
128-
129159
130160class TextPromptClient (BasePromptClient ):
131161 def __init__ (self , prompt : Prompt_Text , is_fallback : bool = False ):
132162 super ().__init__ (prompt , is_fallback )
133163 self .prompt = prompt .prompt
134164
135165 def compile (self , ** kwargs ) -> str :
136- return self ._compile_template_string (self .prompt , kwargs )
166+ return TemplateParser .compile_template (self .prompt , kwargs )
167+
168+ @property
169+ def variables (self ) -> List [str ]:
170+ """Return all the variable names in the prompt template."""
171+ return TemplateParser .find_variable_names (self .prompt )
137172
138173 def __eq__ (self , other ):
139174 if isinstance (self , other .__class__ ):
@@ -160,7 +195,7 @@ def get_langchain_prompt(self, **kwargs) -> str:
160195 str: The string that can be plugged into Langchain's PromptTemplate.
161196 """
162197 prompt = (
163- self . _compile_template_string (self .prompt , kwargs )
198+ TemplateParser . compile_template (self .prompt , kwargs )
164199 if kwargs
165200 else self .prompt
166201 )
@@ -178,12 +213,23 @@ def __init__(self, prompt: Prompt_Chat, is_fallback: bool = False):
178213 def compile (self , ** kwargs ) -> List [ChatMessageDict ]:
179214 return [
180215 ChatMessageDict (
181- content = self ._compile_template_string (chat_message ["content" ], kwargs ),
216+ content = TemplateParser .compile_template (
217+ chat_message ["content" ], kwargs
218+ ),
182219 role = chat_message ["role" ],
183220 )
184221 for chat_message in self .prompt
185222 ]
186223
224+ @property
225+ def variables (self ) -> List [str ]:
226+ """Return all the variable names in the chat prompt template."""
227+ return [
228+ variable
229+ for chat_message in self .prompt
230+ for variable in TemplateParser .find_variable_names (chat_message ["content" ])
231+ ]
232+
187233 def __eq__ (self , other ):
188234 if isinstance (self , other .__class__ ):
189235 return (
@@ -215,7 +261,7 @@ def get_langchain_prompt(self, **kwargs):
215261 (
216262 msg ["role" ],
217263 self ._get_langchain_prompt_string (
218- self . _compile_template_string (msg ["content" ], kwargs )
264+ TemplateParser . compile_template (msg ["content" ], kwargs )
219265 if kwargs
220266 else msg ["content" ]
221267 ),
0 commit comments