Skip to content

Commit 3a90cc8

Browse files
fix(bigframes): improve error message when unescaped { are found in SQL cells (googleapis#17346)
Hints to the user that they may need to escape `{` and `}` characters by doubling them, and includes context as to where to correct such errors. Fixes internal issue b/517909919 🦕 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 384724c commit 3a90cc8

2 files changed

Lines changed: 312 additions & 7 deletions

File tree

packages/bigframes/bigframes/core/pyformat.py

Lines changed: 182 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,160 @@ def _parse_fields(sql_template: str) -> list[str]:
162162
]
163163

164164

165+
def _is_escaped_open_brace(sql_template: str, idx: int, literal_char: str) -> bool:
166+
"""Checks if the character at idx in sql_template is an escaped open brace '{{'."""
167+
return sql_template[idx : idx + 2] == "{{" and literal_char == "{"
168+
169+
170+
def _is_escaped_close_brace(sql_template: str, idx: int, literal_char: str) -> bool:
171+
"""Checks if the character at idx in sql_template is an escaped close brace '}}'."""
172+
return sql_template[idx : idx + 2] == "}}" and literal_char == "}"
173+
174+
175+
def _consume_literal(sql_template: str, current_idx: int, literal_text: str) -> int:
176+
"""Advances current_idx past literal_text in sql_template, accounting for escaped braces.
177+
178+
A **literal** (or literal text) is the static part of the template string that
179+
does not contain formatting placeholders. The string.Formatter parser resolves
180+
escaped braces ('{{' and '}}') into single braces ('{' and '}') in its output
181+
literal_text.
182+
183+
This function aligns the resolved literal_text back to the original
184+
sql_template by consuming 2 characters from sql_template ('{{' or '}}') for
185+
every single escaped brace character in literal_text, and 1 character for
186+
everything else.
187+
188+
Returns:
189+
int: the advanced current_idx in sql_template.
190+
"""
191+
lit_idx = 0
192+
while lit_idx < len(literal_text):
193+
if _is_escaped_open_brace(sql_template, current_idx, literal_text[lit_idx]):
194+
current_idx += 2
195+
lit_idx += 1
196+
elif _is_escaped_close_brace(sql_template, current_idx, literal_text[lit_idx]):
197+
current_idx += 2
198+
lit_idx += 1
199+
elif (
200+
current_idx < len(sql_template)
201+
and sql_template[current_idx] == literal_text[lit_idx]
202+
):
203+
current_idx += 1
204+
lit_idx += 1
205+
else:
206+
raise RuntimeError(
207+
"Internal error: failed to align parsed SQL template with original query. "
208+
f"Expected {literal_text[lit_idx]!r} at position {current_idx} in template, "
209+
f"but found {sql_template[current_idx : current_idx + 2]!r}."
210+
)
211+
return current_idx
212+
213+
214+
def _is_escaped_brace(sql_template: str, idx: int) -> bool:
215+
"""Checks if the template has an escaped brace ('{{' or '}}') at the given index."""
216+
return sql_template[idx : idx + 2] in ("{{", "}}")
217+
218+
219+
def _advance_past_field(sql_template: str, current_idx: int) -> int:
220+
"""Advances current_idx past the format field starting at current_idx.
221+
222+
A **field** (or replacement field) is a placeholder in the template enclosed
223+
in braces (e.g., "{my_var}" or "{json_col: { "val": 1 } }").
224+
225+
This function assumes current_idx points to the opening '{' of a field.
226+
It parses forward, tracking nested braces to find the matching closing '}'
227+
that terminates the field, while ignoring escaped braces ('{{' and '}}')
228+
which do not affect the nesting level.
229+
230+
Returns:
231+
int: the index immediately after the closing '}' of the field.
232+
"""
233+
assert sql_template[current_idx] == "{"
234+
brace_count = 1
235+
current_idx += 1 # past '{'
236+
237+
while brace_count > 0 and current_idx < len(sql_template):
238+
if _is_escaped_brace(sql_template, current_idx):
239+
current_idx += 2
240+
elif sql_template[current_idx] == "{":
241+
brace_count += 1
242+
current_idx += 1
243+
elif sql_template[current_idx] == "}":
244+
brace_count -= 1
245+
current_idx += 1
246+
else:
247+
current_idx += 1
248+
249+
return current_idx
250+
251+
252+
def _find_all_field_positions(sql_template: str) -> dict[tuple[str, int], int]:
253+
"""Finds the character positions of all fields in the sql_template.
254+
255+
Returns:
256+
dict: a dict mapping (field_name, occurrence_idx) to character index.
257+
"""
258+
formatter = string.Formatter()
259+
current_idx = 0
260+
seen_counts: dict[str, int] = {}
261+
positions: dict[tuple[str, int], int] = {}
262+
263+
for literal_text, field_name, _, _ in formatter.parse(sql_template):
264+
current_idx = _consume_literal(sql_template, current_idx, literal_text)
265+
266+
if field_name is not None:
267+
occurrence_idx = seen_counts.get(field_name, 0)
268+
seen_counts[field_name] = occurrence_idx + 1
269+
270+
positions[(field_name, occurrence_idx)] = current_idx
271+
272+
current_idx = _advance_past_field(sql_template, current_idx)
273+
274+
return positions
275+
276+
277+
def get_error_context_at_pos(sql_template: str, pos: int) -> str:
278+
"""Create a helpful 'pointer' to where the problematic position is
279+
in the original SQL.
280+
281+
This should make the error message a lot friendlier, by providing more
282+
context towards the problematic syntax.
283+
"""
284+
if pos == -1:
285+
return ""
286+
287+
lines = sql_template.splitlines(keepends=True)
288+
289+
char_count = 0
290+
target_line_idx = -1
291+
for i, line in enumerate(lines):
292+
if char_count <= pos < char_count + len(line):
293+
target_line_idx = i
294+
break
295+
char_count += len(line)
296+
297+
if target_line_idx == -1:
298+
return ""
299+
300+
col_offset = pos - char_count
301+
302+
context_lines = []
303+
start_line = max(0, target_line_idx - 2)
304+
end_line = min(len(lines), target_line_idx + 3)
305+
306+
for i in range(start_line, end_line):
307+
line_num = i + 1
308+
line_content = lines[i].rstrip("\r\n")
309+
if i == target_line_idx:
310+
context_lines.append(f"{line_num:4d}: {line_content}")
311+
indent = 6 + col_offset
312+
context_lines.append(" " * indent + "^")
313+
else:
314+
context_lines.append(f"{line_num:4d}: {line_content}")
315+
316+
return "\n".join(context_lines)
317+
318+
165319
def pyformat(
166320
sql_template: str,
167321
*,
@@ -185,13 +339,36 @@ def pyformat(
185339
186340
Raises:
187341
TypeError: if a referenced variable is not of a supported type.
188-
KeyError: if a referenced variable is not found.
342+
ValueError:
343+
if a referenced variable is not found (KeyError is caught and raised
344+
as ValueError with context).
189345
"""
190-
fields = _parse_fields(sql_template)
191-
192-
format_kwargs = {}
346+
try:
347+
fields = _parse_fields(sql_template)
348+
except ValueError as e:
349+
raise ValueError(
350+
"Failed to parse SQL template. "
351+
"Did you mean to escape '{' and '}' by doubling them?\n"
352+
f"Error details: {e}"
353+
) from e
354+
355+
format_kwargs: dict[str, str] = {}
356+
seen_counts: dict[str, int] = {}
193357
for name in fields:
194-
value = pyformat_args[name]
358+
seen_counts[name] = seen_counts.get(name, 0) + 1
359+
try:
360+
value = pyformat_args[name]
361+
except KeyError as e:
362+
positions = _find_all_field_positions(sql_template)
363+
occurrence_idx = seen_counts[name] - 1
364+
pos = positions.get((name, occurrence_idx), -1)
365+
context = get_error_context_at_pos(sql_template, pos)
366+
raise ValueError(
367+
f"Undetected variable {name!r} in SQL template. "
368+
"Did you mean to escape '{' and '}' by doubling them?\n"
369+
f"{context}"
370+
) from e
371+
195372
format_kwargs[name] = _field_to_template_value(
196373
name, value, session=session, dry_run=dry_run
197374
)

packages/bigframes/tests/unit/core/test_pyformat.py

Lines changed: 130 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,72 @@ def test_parse_fields(sql_template: str, expected: List[str]):
6262
assert fields == expected
6363

6464

65+
def test_get_error_context_at_pos_invalid_pos():
66+
assert pyformat.get_error_context_at_pos("SELECT 1", -1) == ""
67+
assert pyformat.get_error_context_at_pos("SELECT 1", 100) == ""
68+
69+
70+
def test_get_error_context_at_pos_single_line():
71+
sql = "SELECT {foo}"
72+
# pos of '{' is 7
73+
context = pyformat.get_error_context_at_pos(sql, 7)
74+
expected = " 1: SELECT {foo}\n ^"
75+
assert context == expected
76+
77+
78+
def test_get_error_context_at_pos_multi_line():
79+
sql = "SELECT 1\nFROM my_table\nWHERE col = {foo}\nAND active = True\nLIMIT 10"
80+
# Lines:
81+
# 1: SELECT 1 (len 9 including \n)
82+
# 2: FROM my_table (len 14 including \n) -> total 23
83+
# 3: WHERE col = {foo} -> '{' is at 23 + 12 = 35
84+
85+
context = pyformat.get_error_context_at_pos(sql, 35)
86+
expected = (
87+
" 1: SELECT 1\n"
88+
" 2: FROM my_table\n"
89+
" 3: WHERE col = {foo}\n"
90+
" ^\n"
91+
" 4: AND active = True\n"
92+
" 5: LIMIT 10"
93+
)
94+
assert context == expected
95+
96+
97+
def test_get_error_context_at_pos_multi_line_limits():
98+
# Test that it only shows at most 2 lines before and 2 lines after
99+
sql = (
100+
"LINE 1\n"
101+
"LINE 2\n"
102+
"LINE 3\n"
103+
"LINE 4\n"
104+
"LINE 5\n"
105+
"TARGET {foo}\n"
106+
"LINE 7\n"
107+
"LINE 8\n"
108+
"LINE 9\n"
109+
"LINE 10"
110+
)
111+
# Line lengths:
112+
# LINE 1\n (7)
113+
# LINE 2\n (7) -> 14
114+
# LINE 3\n (7) -> 21
115+
# LINE 4\n (7) -> 28
116+
# LINE 5\n (7) -> 35
117+
# TARGET {foo}\n -> '{' is at 35 + 7 = 42
118+
119+
context = pyformat.get_error_context_at_pos(sql, 42)
120+
expected = (
121+
" 4: LINE 4\n"
122+
" 5: LINE 5\n"
123+
" 6: TARGET {foo}\n"
124+
" ^\n"
125+
" 7: LINE 7\n"
126+
" 8: LINE 8"
127+
)
128+
assert context == expected
129+
130+
65131
def test_pyformat_with_unsupported_type_raises_typeerror(session):
66132
pyformat_args = {"my_object": object()}
67133
sql = "SELECT {my_object}"
@@ -70,13 +136,75 @@ def test_pyformat_with_unsupported_type_raises_typeerror(session):
70136
pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session)
71137

72138

73-
def test_pyformat_with_missing_variable_raises_keyerror(session):
139+
def test_pyformat_with_missing_variable_raises_valueerror(session):
74140
pyformat_args: Dict[str, Any] = {}
75141
sql = "SELECT {my_object}"
76142

77-
with pytest.raises(KeyError, match="my_object"):
143+
with pytest.raises(ValueError) as exc_info:
78144
pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session)
79145

146+
err_msg = str(exc_info.value)
147+
assert "Undetected variable 'my_object' in SQL template" in err_msg
148+
assert "Did you mean to escape '{' and '}'" in err_msg
149+
assert " 1: SELECT {my_object}" in err_msg
150+
assert " ^" in err_msg
151+
152+
153+
def test_pyformat_with_unescaped_braces_raises_valueerror_with_context(session):
154+
pyformat_args = {"active": True}
155+
sql = """SELECT * FROM my_table
156+
WHERE json_col = { "generation_config": { "temperature": 0.9 } }
157+
AND active = {active}
158+
"""
159+
160+
with pytest.raises(ValueError) as exc_info:
161+
pyformat.pyformat(sql, pyformat_args=pyformat_args, session=session)
162+
163+
err_msg = str(exc_info.value)
164+
assert "Undetected variable ' \"generation_config\"' in SQL template" in err_msg
165+
assert "Did you mean to escape '{' and '}'" in err_msg
166+
# The triple quote string starts with SELECT immediately, so lines are:
167+
# 1: SELECT * FROM my_table
168+
# 2: WHERE json_col = { "generation_config": { "temperature": 0.9 } }
169+
# 3: AND active = {active}
170+
assert " 1: SELECT * FROM my_table" in err_msg
171+
assert (
172+
' 2: WHERE json_col = { "generation_config": { "temperature": 0.9 } }'
173+
in err_msg
174+
)
175+
assert " ^" in err_msg
176+
assert " 3: AND active = {active}" in err_msg
177+
178+
179+
@pytest.mark.parametrize(
180+
("sql_template", "expected_error"),
181+
(
182+
pytest.param(
183+
"SELECT {foo",
184+
"expected '}' before end of string",
185+
id="missing_closing_brace",
186+
),
187+
pytest.param(
188+
"SELECT foo}",
189+
"Single '}' encountered in format string",
190+
id="missing_opening_brace",
191+
),
192+
),
193+
)
194+
def test_pyformat_with_malformed_template_raises_valueerror(
195+
session, sql_template: str, expected_error: str
196+
):
197+
pyformat_args: Dict[str, Any] = {}
198+
199+
# Case 1: Single '{' (unmatched)
200+
with pytest.raises(ValueError) as exc_info:
201+
pyformat.pyformat(sql_template, pyformat_args=pyformat_args, session=session)
202+
203+
error_message = str(exc_info.value)
204+
assert "Failed to parse SQL template" in error_message
205+
assert "Did you mean to escape '{' and '}'" in error_message
206+
assert expected_error in error_message
207+
80208

81209
def test_pyformat_with_no_variables(session):
82210
pyformat_args: Dict[str, Any] = {}

0 commit comments

Comments
 (0)