Skip to content

Commit 8a2c815

Browse files
committed
Propagate hd/tl checks across clauses, closes #15358
1 parent 0b9b5b9 commit 8a2c815

2 files changed

Lines changed: 49 additions & 15 deletions

File tree

lib/elixir/lib/module/types/pattern.ex

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -413,9 +413,11 @@ defmodule Module.Types.Pattern do
413413

414414
defp match_var do
415415
version = make_ref()
416-
{version, {:match, [version: version], __MODULE__}}
416+
{version, match_var(version)}
417417
end
418418

419+
defp match_var(version), do: {:match, [version: version], __MODULE__}
420+
419421
defp match_error?({:match, _, __MODULE__}, _type), do: true
420422
defp match_error?(_var, type), do: empty?(type)
421423

@@ -1021,8 +1023,8 @@ defmodule Module.Types.Pattern do
10211023
guard_context: :andalso,
10221024
parent_version: nil,
10231025
vars: vars,
1024-
changed: %{},
1025-
subpatterns: %{}
1026+
subpatterns_vars: %{},
1027+
changed: %{}
10261028
})
10271029

10281030
{precise?, context} = of_guards(guards, stack, context)
@@ -1047,7 +1049,12 @@ defmodule Module.Types.Pattern do
10471049
end)
10481050

10491051
expr = Enum.reduce(guards, {:_, [], []}, &{:when, [], [&2, &1]})
1050-
context = %{context | vars: vars, conditional_vars: conditional_vars}
1052+
1053+
context = %{
1054+
context
1055+
| vars: Map.merge(vars, context.pattern_info.subpatterns_vars),
1056+
conditional_vars: conditional_vars
1057+
}
10511058

10521059
{precise? and Of.all_same_conditional_vars?(vars_conds),
10531060
Of.reduce_conditional_vars(vars_conds, expr, stack, context)}
@@ -1068,7 +1075,7 @@ defmodule Module.Types.Pattern do
10681075
{false, context}
10691076

10701077
{true, maybe_or_always} ->
1071-
{maybe_or_always == :always and context.pattern_info.subpatterns == %{}, context}
1078+
{maybe_or_always == :always, context}
10721079

10731080
_false_tuple_or_none ->
10741081
error = {:badguard, type, guard, context}
@@ -1327,7 +1334,7 @@ defmodule Module.Types.Pattern do
13271334
call,
13281335
expected,
13291336
stack,
1330-
%{pattern_info: %{subpatterns: subpatterns}} = context
1337+
context
13311338
)
13321339
when fun in [:hd, :tl] do
13331340
arg_key =
@@ -1338,18 +1345,38 @@ defmodule Module.Types.Pattern do
13381345

13391346
subpattern_key = {fun, arg_key}
13401347

1341-
{var, context} =
1342-
case subpatterns do
1343-
%{^subpattern_key => var} ->
1344-
{var, context}
1348+
{found?, var_version, context} =
1349+
case context.subpatterns do
1350+
%{^subpattern_key => var_version} ->
1351+
case context.vars do
1352+
%{^var_version => _} -> {true, var_version, context}
1353+
%{} -> {false, var_version, context}
1354+
end
13451355

13461356
%{} ->
1347-
{type, context} =
1348-
Apply.remote(:erlang, fun, [arg], expected, call, stack, context, &of_guard/5)
1357+
var_version = make_ref()
1358+
{false, var_version, put_in(context.subpatterns[subpattern_key], var_version)}
1359+
end
1360+
1361+
var = match_var(var_version)
13491362

1350-
{_, var} = match_var()
1351-
context = Of.declare_var(var, type, context)
1352-
{var, put_in(context.pattern_info.subpatterns[subpattern_key], var)}
1363+
context =
1364+
if found? do
1365+
context
1366+
else
1367+
{type, context} =
1368+
Apply.remote(:erlang, fun, [arg], expected, call, stack, context, &of_guard/5)
1369+
1370+
context = Of.declare_var(var, type, context)
1371+
1372+
update_in(context.pattern_info, fn %{subpatterns_vars: subpatterns_vars} = pattern_info ->
1373+
%{
1374+
pattern_info
1375+
| vars: false,
1376+
subpatterns_vars:
1377+
Map.put(subpatterns_vars, var_version, Map.fetch!(context.vars, var_version))
1378+
}
1379+
end)
13531380
end
13541381

13551382
of_guard(var, expected, call, stack, context)

lib/elixir/test/elixir/module/types/pattern_test.exs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,13 @@ defmodule Module.Types.PatternTest do
773773

774774
assert typecheck!([x, y], is_binary(x) when is_atom(y), {x, y}) ==
775775
dynamic(tuple([term(), term()]))
776+
777+
# with annotated hd/tl
778+
assert typecheck!([x], is_binary(x) when is_atom(hd(x)), x) ==
779+
dynamic(union(binary(), non_empty_list(term(), term())))
780+
781+
assert typecheck!([x], is_binary(hd(x)) when is_atom(hd(x)), x) ==
782+
dynamic(non_empty_list(term(), term()))
776783
end
777784

778785
test "conditional checks (andalso/orelse)" do

0 commit comments

Comments
 (0)