@@ -162,6 +162,160 @@ def _parse_fields(sql_template: str) -> list[str]:
162162 ]
163163
164164
165+ def _is_escaped_open_brace (sql_template : str , idx : int , literal_char : str ) -> bool :
166+ """Checks if the character at idx in sql_template is an escaped open brace '{{'."""
167+ return sql_template [idx : idx + 2 ] == "{{" and literal_char == "{"
168+
169+
170+ def _is_escaped_close_brace (sql_template : str , idx : int , literal_char : str ) -> bool :
171+ """Checks if the character at idx in sql_template is an escaped close brace '}}'."""
172+ return sql_template [idx : idx + 2 ] == "}}" and literal_char == "}"
173+
174+
175+ def _consume_literal (sql_template : str , current_idx : int , literal_text : str ) -> int :
176+ """Advances current_idx past literal_text in sql_template, accounting for escaped braces.
177+
178+ A **literal** (or literal text) is the static part of the template string that
179+ does not contain formatting placeholders. The string.Formatter parser resolves
180+ escaped braces ('{{' and '}}') into single braces ('{' and '}') in its output
181+ literal_text.
182+
183+ This function aligns the resolved literal_text back to the original
184+ sql_template by consuming 2 characters from sql_template ('{{' or '}}') for
185+ every single escaped brace character in literal_text, and 1 character for
186+ everything else.
187+
188+ Returns:
189+ int: the advanced current_idx in sql_template.
190+ """
191+ lit_idx = 0
192+ while lit_idx < len (literal_text ):
193+ if _is_escaped_open_brace (sql_template , current_idx , literal_text [lit_idx ]):
194+ current_idx += 2
195+ lit_idx += 1
196+ elif _is_escaped_close_brace (sql_template , current_idx , literal_text [lit_idx ]):
197+ current_idx += 2
198+ lit_idx += 1
199+ elif (
200+ current_idx < len (sql_template )
201+ and sql_template [current_idx ] == literal_text [lit_idx ]
202+ ):
203+ current_idx += 1
204+ lit_idx += 1
205+ else :
206+ raise RuntimeError (
207+ "Internal error: failed to align parsed SQL template with original query. "
208+ f"Expected { literal_text [lit_idx ]!r} at position { current_idx } in template, "
209+ f"but found { sql_template [current_idx : current_idx + 2 ]!r} ."
210+ )
211+ return current_idx
212+
213+
214+ def _is_escaped_brace (sql_template : str , idx : int ) -> bool :
215+ """Checks if the template has an escaped brace ('{{' or '}}') at the given index."""
216+ return sql_template [idx : idx + 2 ] in ("{{" , "}}" )
217+
218+
219+ def _advance_past_field (sql_template : str , current_idx : int ) -> int :
220+ """Advances current_idx past the format field starting at current_idx.
221+
222+ A **field** (or replacement field) is a placeholder in the template enclosed
223+ in braces (e.g., "{my_var}" or "{json_col: { "val": 1 } }").
224+
225+ This function assumes current_idx points to the opening '{' of a field.
226+ It parses forward, tracking nested braces to find the matching closing '}'
227+ that terminates the field, while ignoring escaped braces ('{{' and '}}')
228+ which do not affect the nesting level.
229+
230+ Returns:
231+ int: the index immediately after the closing '}' of the field.
232+ """
233+ assert sql_template [current_idx ] == "{"
234+ brace_count = 1
235+ current_idx += 1 # past '{'
236+
237+ while brace_count > 0 and current_idx < len (sql_template ):
238+ if _is_escaped_brace (sql_template , current_idx ):
239+ current_idx += 2
240+ elif sql_template [current_idx ] == "{" :
241+ brace_count += 1
242+ current_idx += 1
243+ elif sql_template [current_idx ] == "}" :
244+ brace_count -= 1
245+ current_idx += 1
246+ else :
247+ current_idx += 1
248+
249+ return current_idx
250+
251+
252+ def _find_all_field_positions (sql_template : str ) -> dict [tuple [str , int ], int ]:
253+ """Finds the character positions of all fields in the sql_template.
254+
255+ Returns:
256+ dict: a dict mapping (field_name, occurrence_idx) to character index.
257+ """
258+ formatter = string .Formatter ()
259+ current_idx = 0
260+ seen_counts : dict [str , int ] = {}
261+ positions : dict [tuple [str , int ], int ] = {}
262+
263+ for literal_text , field_name , _ , _ in formatter .parse (sql_template ):
264+ current_idx = _consume_literal (sql_template , current_idx , literal_text )
265+
266+ if field_name is not None :
267+ occurrence_idx = seen_counts .get (field_name , 0 )
268+ seen_counts [field_name ] = occurrence_idx + 1
269+
270+ positions [(field_name , occurrence_idx )] = current_idx
271+
272+ current_idx = _advance_past_field (sql_template , current_idx )
273+
274+ return positions
275+
276+
277+ def get_error_context_at_pos (sql_template : str , pos : int ) -> str :
278+ """Create a helpful 'pointer' to where the problematic position is
279+ in the original SQL.
280+
281+ This should make the error message a lot friendlier, by providing more
282+ context towards the problematic syntax.
283+ """
284+ if pos == - 1 :
285+ return ""
286+
287+ lines = sql_template .splitlines (keepends = True )
288+
289+ char_count = 0
290+ target_line_idx = - 1
291+ for i , line in enumerate (lines ):
292+ if char_count <= pos < char_count + len (line ):
293+ target_line_idx = i
294+ break
295+ char_count += len (line )
296+
297+ if target_line_idx == - 1 :
298+ return ""
299+
300+ col_offset = pos - char_count
301+
302+ context_lines = []
303+ start_line = max (0 , target_line_idx - 2 )
304+ end_line = min (len (lines ), target_line_idx + 3 )
305+
306+ for i in range (start_line , end_line ):
307+ line_num = i + 1
308+ line_content = lines [i ].rstrip ("\r \n " )
309+ if i == target_line_idx :
310+ context_lines .append (f"{ line_num :4d} : { line_content } " )
311+ indent = 6 + col_offset
312+ context_lines .append (" " * indent + "^" )
313+ else :
314+ context_lines .append (f"{ line_num :4d} : { line_content } " )
315+
316+ return "\n " .join (context_lines )
317+
318+
165319def pyformat (
166320 sql_template : str ,
167321 * ,
@@ -185,13 +339,36 @@ def pyformat(
185339
186340 Raises:
187341 TypeError: if a referenced variable is not of a supported type.
188- KeyError: if a referenced variable is not found.
342+ ValueError:
343+ if a referenced variable is not found (KeyError is caught and raised
344+ as ValueError with context).
189345 """
190- fields = _parse_fields (sql_template )
191-
192- format_kwargs = {}
346+ try :
347+ fields = _parse_fields (sql_template )
348+ except ValueError as e :
349+ raise ValueError (
350+ "Failed to parse SQL template. "
351+ "Did you mean to escape '{' and '}' by doubling them?\n "
352+ f"Error details: { e } "
353+ ) from e
354+
355+ format_kwargs : dict [str , str ] = {}
356+ seen_counts : dict [str , int ] = {}
193357 for name in fields :
194- value = pyformat_args [name ]
358+ seen_counts [name ] = seen_counts .get (name , 0 ) + 1
359+ try :
360+ value = pyformat_args [name ]
361+ except KeyError as e :
362+ positions = _find_all_field_positions (sql_template )
363+ occurrence_idx = seen_counts [name ] - 1
364+ pos = positions .get ((name , occurrence_idx ), - 1 )
365+ context = get_error_context_at_pos (sql_template , pos )
366+ raise ValueError (
367+ f"Undetected variable { name !r} in SQL template. "
368+ "Did you mean to escape '{' and '}' by doubling them?\n "
369+ f"{ context } "
370+ ) from e
371+
195372 format_kwargs [name ] = _field_to_template_value (
196373 name , value , session = session , dry_run = dry_run
197374 )
0 commit comments