Skip to content

Commit c06fec3

Browse files
committed
fix: enforce fn-head nested parens at parse time
I dove way too deep into the rabbit hole with these tests. They surfaced for me through proptests, I used LLMs to find all these variations, and I fixed them by having an LLM port Elixir's parser logic.

1. **`stab_expr` fn-clause head shapes** Ported concept: separate and normalize bare heads vs. parenthesized heads before building `->` clauses. Source: [`elixir_parser.yrl` `stab_expr` rules](https://github.com/elixir-lang/elixir/blob/v1.18.2/lib/elixir/src/elixir_parser.yrl#L340-L353)
2. **`when` vs `<-` / `\\` precedence behavior in fn heads** Ported concept: preserve the operator precedence model, and lower guard parsing precedence only for specific simple fn-head forms. Source: [precedence table](https://github.com/elixir-lang/elixir/blob/v1.18.2/lib/elixir/src/elixir_parser.yrl#L59-L63)
3. **`unwrap_when`-style reassociation** Ported concept: reassociate a trailing `when` in head args so that non-`when` operators bind to the last head arg/guard the same way the compiler's parser does. Source: [`unwrap_when/1`](https://github.com/elixir-lang/elixir/blob/v1.18.2/lib/elixir/src/elixir_parser.yrl#L1135-L1141)
4. **Nested-parens rejection semantics (`unexpected parentheses`)** Ported concept: reject nested parenthesized fn-head arg forms during fn-head construction, but keep Spitfire's error recovery by recording errors and continuing AST construction.
1 parent 19f2c88 commit c06fec3

2 files changed

Lines changed: 223 additions & 47 deletions

File tree

lib/spitfire.ex

Lines changed: 160 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,52 +1283,147 @@ defmodule Spitfire do
12831283
meta
12841284
end
12851285

1286-
lhs =
1287-
case lhs do
1288-
{:__block__, _, []} ->
1289-
[]
1286+
{lhs, parser} = normalize_stab_lhs(lhs, parser)
12901287

1291-
{:__block__, [{:parens, _} | _], [[{key, _} | _] = kw]} when is_atom(key) ->
1292-
[kw]
1288+
ast =
1289+
{token, meta, [lhs, rhs]}
12931290

1294-
{:__block__, [{:parens, _} | _], [[{{_, _, _}, _} | _] = kw]} ->
1295-
[kw]
1291+
parser = Map.put(parser, :nesting, old_nesting)
12961292

1297-
{:comma, _, lhs} ->
1298-
lhs
1293+
{ast, eat_eoe(parser)}
1294+
end
1295+
end
1296+
end
12991297

1300-
{:when, [{:parens, _} | when_meta], when_args} ->
1301-
[{:when, when_meta, when_args}]
1298+
# Normalize stab clause LHS forms into the shape expected by `->` construction.
1299+
# When parsing fn heads, this also performs grammar-level validation for
1300+
# disallowed nested parentheses while keeping recoverable AST output.
1301+
defp normalize_stab_lhs(lhs, parser) do
1302+
in_fn_head_context? = Map.get(parser, :fn_head_context?, false)
13021303

1303-
{:__block__, [{:parens, _} = paren_meta | _], exprs} ->
1304-
case exprs do
1305-
[[{key, _} | _] = kw] when is_atom(key) ->
1306-
[{:__block__, [paren_meta], [kw]}]
1304+
case lhs do
1305+
{:__block__, _, []} ->
1306+
{[], parser}
13071307

1308-
[[{{_, _, _}, _} | _] = kw] ->
1309-
[{:__block__, [paren_meta], [kw]}]
1308+
{:comma, comma_meta, comma_args} ->
1309+
parser =
1310+
if in_fn_head_context? do
1311+
parser
1312+
|> maybe_error_nested_top_parens(comma_meta)
1313+
|> maybe_error_invalid_fn_head_args(comma_args)
1314+
else
1315+
parser
1316+
end
13101317

1311-
[expr] ->
1312-
[{:__block__, [paren_meta], [expr]}]
1318+
{comma_args, parser}
13131319

1314-
_ ->
1315-
lhs
1316-
end
1320+
{:when, [{:parens, _} | when_meta], when_args} ->
1321+
{[{:when, when_meta, when_args}], parser}
13171322

1318-
lhs ->
1319-
[lhs]
1320-
end
1323+
{:__block__, [{:parens, _} = paren_meta | _], exprs} ->
1324+
case {parenthesized_kw_head(lhs), exprs} do
1325+
{{:ok, block_meta, _paren_meta, kw}, _} ->
1326+
parser =
1327+
if in_fn_head_context? do
1328+
maybe_error_nested_top_parens(parser, block_meta)
1329+
else
1330+
parser
1331+
end
13211332

1322-
ast =
1323-
{token, meta, [lhs, rhs]}
1333+
{[kw], parser}
13241334

1325-
parser = Map.put(parser, :nesting, old_nesting)
1335+
{:error, [expr]} ->
1336+
{[{:__block__, [paren_meta], [expr]}], parser}
13261337

1327-
{ast, eat_eoe(parser)}
1338+
{:error, _} ->
1339+
{lhs, parser}
1340+
end
1341+
1342+
lhs ->
1343+
{[lhs], parser}
1344+
end
1345+
end
1346+
1347+
# In fn-head context, only certain LHS shapes should lower `when` precedence.
1348+
# This allows `<-` and `\\` to be consumed by the guard expression when the
1349+
# head is explicitly grouped (empty parens, parenthesized comma args, or
1350+
# parenthesized keyword heads).
1351+
defp fn_head_simple_for_when_precedence?(lhs) do
1352+
match?({:__block__, _, []}, lhs) or
1353+
match?({:comma, [{:parens, _} | _], _}, lhs) or
1354+
match?({:ok, _block_meta, _paren_meta, _kw}, parenthesized_kw_head(lhs))
1355+
end
1356+
1357+
# Validate raw fn-head LHS before `when` normalization can flatten argument
1358+
# structure and hide nested-parens evidence.
1359+
defp maybe_error_invalid_fn_head_lhs(parser, lhs) do
1360+
case lhs do
1361+
{:comma, comma_meta, comma_args} ->
1362+
parser
1363+
|> maybe_error_nested_top_parens(comma_meta)
1364+
|> maybe_error_invalid_fn_head_args(comma_args)
1365+
1366+
lhs ->
1367+
case parenthesized_kw_head(lhs) do
1368+
{:ok, block_meta, _paren_meta, _kw} ->
1369+
maybe_error_nested_top_parens(parser, block_meta)
1370+
1371+
:error ->
1372+
parser
1373+
end
1374+
end
1375+
end
1376+
1377+
# Reject nested grouped tuple/keyword arguments inside fn heads, for example
1378+
# `fn (a, (b, c)) -> ... end` and `fn ((a, b), c) -> ... end`.
1379+
defp maybe_error_invalid_fn_head_args(parser, args) do
1380+
Enum.reduce(args, parser, fn arg, parser ->
1381+
case arg do
1382+
{:comma, [{:parens, parens_meta} | _], _} ->
1383+
put_error(parser, {nested_parens_error_meta(parens_meta), "unexpected parentheses"})
1384+
1385+
_ ->
1386+
case parenthesized_kw_head(arg) do
1387+
{:ok, _block_meta, {:parens, parens_meta}, _kw} ->
1388+
put_error(parser, {nested_parens_error_meta(parens_meta), "unexpected parentheses"})
1389+
1390+
:error ->
1391+
parser
1392+
end
13281393
end
1394+
end)
1395+
end
1396+
1397+
# Reject top-level fn-head wrappers that carry more than one parens entry,
1398+
# such as `((a, b))` and `((a: 1))`.
1399+
defp maybe_error_nested_top_parens(parser, meta) do
1400+
case Enum.at(Keyword.get_values(meta, :parens), 1) do
1401+
nil ->
1402+
parser
1403+
1404+
inner_parens ->
1405+
put_error(parser, {nested_parens_error_meta(inner_parens), "unexpected parentheses"})
13291406
end
13301407
end
13311408

1409+
# Prefer the closing paren location when available to match parser diagnostics.
1410+
defp nested_parens_error_meta(parens_meta) do
1411+
Keyword.take(parens_meta[:closing] || parens_meta, [:line, :column])
1412+
end
1413+
1414+
# Extract parenthesized keyword-list heads like `(a: 1)` or `('x': 1)` from
1415+
# grouped AST nodes so fn-head normalization can preserve paren metadata.
1416+
defp parenthesized_kw_head({:__block__, [{:parens, _} = paren_meta | _] = block_meta, [[{key, _} | _] = kw]})
1417+
when is_atom(key) do
1418+
{:ok, block_meta, paren_meta, kw}
1419+
end
1420+
1421+
defp parenthesized_kw_head({:__block__, [{:parens, _} = paren_meta | _] = block_meta, [[{{_, _, _}, _} | _] = kw]}) do
1422+
{:ok, block_meta, paren_meta, kw}
1423+
end
1424+
1425+
defp parenthesized_kw_head(_), do: :error
1426+
13321427
# Widen stab_state when outer expression is more complete than when `->` was first detected.
13331428
defp maybe_widen_stab_state(parser, ast) do
13341429
case {Map.get(parser, :stab_state), ast} do
@@ -1396,11 +1491,17 @@ defmodule Spitfire do
13961491
# e.g., `() when bar 1, 2, 3 -> foo()` should parse `bar 1, 2, 3` as the guard
13971492
{rhs, parser} =
13981493
if token == :when do
1399-
# Check if when has simple LHS (empty block or comma args).
1494+
# Check if when has simple LHS (empty block or parenthesized comma args).
14001495
# If so and we're in fn context, use lower precedence to allow <- in guard.
1401-
in_fn_context = Map.get(parser, :stop_before_stab_op?, false)
1402-
simple_lhs = match?({:__block__, _, []}, lhs) or match?({:comma, _, _}, lhs)
1403-
when_precedence = if in_fn_context and simple_lhs, do: @list_comma, else: effective_precedence
1496+
# For bare comma args (no parens), keep normal precedence so `<-` and
1497+
# `\\` bind to the trailing head argument instead of becoming part of
1498+
# the `when` guard.
1499+
# Also validate the raw fn-head lhs here before `when` flattening.
1500+
in_fn_context = Map.get(parser, :fn_head_context?, false)
1501+
parser = if in_fn_context, do: maybe_error_invalid_fn_head_lhs(parser, lhs), else: parser
1502+
1503+
when_precedence =
1504+
if in_fn_context and fn_head_simple_for_when_precedence?(lhs), do: @list_comma, else: effective_precedence
14041505

14051506
{rhs, parser} =
14061507
with_context(parser, %{stop_before_stab_op?: true}, fn parser ->
@@ -1448,26 +1549,37 @@ defmodule Spitfire do
14481549
# Empty block without parens
14491550
{token, newlines ++ meta, [rhs]}
14501551

1451-
{:__block__, [{:parens, _} = paren_meta | _], [[{key, _} | _] = kw]} when is_atom(key) ->
1452-
# (a: 1) when ... - preserve parens meta for stab
1453-
{token, [paren_meta | newlines ++ meta], [kw, rhs]}
1454-
1455-
{:__block__, [{:parens, _} = paren_meta | _], [[{{_, _, _}, _} | _] = kw]} ->
1456-
# Parenthesized kw list with interpolated key
1457-
{token, [paren_meta | newlines ++ meta], [kw, rhs]}
1458-
14591552
{:comma, [{:parens, _} = paren_meta | _], args} ->
14601553
{token, [paren_meta | newlines ++ meta], args ++ [rhs]}
14611554

14621555
{:comma, _, args} ->
14631556
{token, newlines ++ meta, args ++ [rhs]}
14641557

14651558
_ ->
1466-
{token, newlines ++ meta, [lhs, rhs]}
1559+
case parenthesized_kw_head(lhs) do
1560+
{:ok, _block_meta, paren_meta, kw} ->
1561+
# (a: 1) when ... - preserve parens meta for stab
1562+
{token, [paren_meta | newlines ++ meta], [kw, rhs]}
1563+
1564+
:error ->
1565+
{token, newlines ++ meta, [lhs, rhs]}
1566+
end
14671567
end
14681568

14691569
_ ->
1470-
{token, newlines ++ meta, [lhs, rhs]}
1570+
case lhs do
1571+
{:comma, comma_meta, args} when args != [] ->
1572+
{leading, [last]} = Enum.split(args, -1)
1573+
{:comma, comma_meta, leading ++ [{token, newlines ++ meta, [last, rhs]}]}
1574+
1575+
{:when, when_meta, when_args} when length(when_args) > 2 ->
1576+
{leading, [second_last, guard]} = Enum.split(when_args, -2)
1577+
when_node = {:when, when_meta, [second_last, guard]}
1578+
{:comma, [], leading ++ [{token, newlines ++ meta, [when_node, rhs]}]}
1579+
1580+
_ ->
1581+
{token, newlines ++ meta, [lhs, rhs]}
1582+
end
14711583
end
14721584

14731585
{ast, parser}
@@ -1491,7 +1603,8 @@ defmodule Spitfire do
14911603
(unmatched_expr?(lhs) and rhs_has_bare_comma?(rhs_parser)) do
14921604
# When the RHS of `|` has low-precedence operators (::, when, <-, \\) or
14931605
# the LHS is an unmatched_expr (do-end) and the RHS has no-parens commas,
1494-
# treat `|` as a regular pipe operator (matching Elixir's LALR grammar).
1606+
# treat `|` as a regular infix operator so RHS parsing completes before
1607+
# map-update pair extraction.
14951608
parse_infix_expression(parser, lhs)
14961609
else
14971610
{pairs, pairs_parser} = parse_map_update_pairs(rhs_parser)
@@ -2112,11 +2225,11 @@ defmodule Spitfire do
21122225
newlines = get_newlines(parser)
21132226
parser = parser |> next_token() |> eat_eoe()
21142227

2115-
# fn creates its own stab scope
2228+
# fn creates its own stab scope and enables fn-head specific validation.
21162229
parser = Map.delete(parser, :stab_state)
21172230

21182231
{exprs, parser} =
2119-
with_context(parser, %{stop_before_stab_op?: true}, fn parser ->
2232+
with_context(parser, %{stop_before_stab_op?: true, fn_head_context?: true}, fn parser ->
21202233
while2 current_token(parser) not in [:end, :eof] <- parser do
21212234
{ast, parser} =
21222235
case Map.get(parser, :stab_state) do

test/spitfire_test.exs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,42 @@ defmodule SpitfireTest do
14831483
end
14841484
end
14851485

1486+
test "fn args with parenthesized heads and low-precedence operators" do
1487+
codes = [
1488+
"fn a, u\\\\c -> :ok end",
1489+
"fn (a, 0<-c) -> :ok end",
1490+
"fn (a, b<-c<-d) -> :ok end",
1491+
"fn (a, b\\\\c\\\\d) -> :ok end",
1492+
"fn (a, b<-c) when is_integer(c) -> :ok end",
1493+
"fn (a, b\\\\c) when is_integer(c) -> :ok end",
1494+
"fn (a, b, d<-c) when is_integer(c) -> :ok end",
1495+
"fn (a, b, d\\\\c) when is_integer(c) -> :ok end",
1496+
"fn (a, b<-c) when c<-c -> :ok end",
1497+
"fn (a, b\\\\c) when c\\\\c -> :ok end",
1498+
"fn (a, b<-c<-d) when c<-c -> :ok end",
1499+
"fn (a, b\\\\c\\\\d) when c\\\\c -> :ok end",
1500+
"fn a, (b<-c) when c<-c -> :ok end",
1501+
"fn a, (b\\\\c) when c\\\\c -> :ok end",
1502+
"fn a, b<-c when c<-c -> :ok end",
1503+
"fn a, b\\\\c when c\\\\c -> :ok end",
1504+
"fn a, b when c<-c -> :ok end",
1505+
"fn a, b when c\\\\c -> :ok end",
1506+
"fn a, b when c<-c<-d -> :ok end",
1507+
"fn (a: 1) when c<-c -> :ok end",
1508+
"fn (a: 1) when c<-c<-d -> :ok end",
1509+
"fn (a: 1) when c\\\\c -> :ok end",
1510+
"fn (a: 1) when c\\\\c\\\\d -> :ok end",
1511+
"fn (x: 1, y: 2) when c<-c -> :ok end",
1512+
"fn (x: 1, y: 2) when c\\\\c -> :ok end",
1513+
"fn ('x': 1, y: 2) when c<-c -> :ok end",
1514+
"fn (x: 1, 'y': 2) when c\\\\c -> :ok end"
1515+
]
1516+
1517+
for code <- codes do
1518+
assert Spitfire.parse(code) == s2q(code)
1519+
end
1520+
end
1521+
14861522
test "capture operator" do
14871523
codes = [
14881524
~s'''
@@ -2662,6 +2698,33 @@ defmodule SpitfireTest do
26622698
}
26632699
end
26642700

2701+
test "rejects nested parenthesized fn args" do
2702+
codes = [
2703+
# whole arg list double/triple-wrapped
2704+
"fn ((a, b)) -> :ok end",
2705+
"fn (((a, b))) -> :ok end",
2706+
"fn ((a, b)) when true -> :ok end",
2707+
"fn ((a, b<-c)) -> :ok end",
2708+
"fn ((a, b\\\\c)) -> :ok end",
2709+
"fn (((a, b<-c))) -> :ok end",
2710+
# keyword list double-wrapped
2711+
"fn ((a: 1)) -> :ok end",
2712+
"fn ((a: 1)) when true -> :ok end",
2713+
# individual args as parenthesized tuples
2714+
"fn ((a, b), c) -> :ok end",
2715+
"fn (a, (b, c)) -> :ok end",
2716+
"fn ((a, b), (c, d)) -> :ok end",
2717+
"fn ((a, b), (c, d)) when true -> :ok end",
2718+
"fn ((a, (b<-c))) -> :ok end",
2719+
"fn ((a, (b\\\\c))) -> :ok end"
2720+
]
2721+
2722+
for code <- codes do
2723+
assert {:error, _} = s2q(code)
2724+
assert {:error, _} = Spitfire.parse(code)
2725+
end
2726+
end
2727+
26652728
test "example from github issue" do
26662729
code = ~S'''
26672730
defmodule Foo do

0 commit comments

Comments
 (0)