Skip to content

Commit fff97fd

Browse files
committed
Refactor shared entry points for args, match, generators
1 parent 53a5675 commit fff97fd

3 files changed

Lines changed: 45 additions & 69 deletions

File tree

lib/elixir/lib/module/types.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ defmodule Module.Types do
324324
if stack.mode == :traversal do
325325
expected
326326
else
327-
Pattern.of_domain(trees, expected, context)
327+
Pattern.of_domain(trees, context)
328328
end
329329

330330
{type_index, inferred} =

lib/elixir/lib/module/types/expr.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ defmodule Module.Types.Expr do
337337
{acc, context} =
338338
of_clauses_fun(clauses, domain, @pending, nil, :fn, stack, context, [], fn
339339
trees, body, context, acc ->
340-
args = Pattern.of_domain(trees, domain, context)
340+
args = Pattern.of_domain(trees, context)
341341
add_inferred(acc, args, body)
342342
end)
343343

lib/elixir/lib/module/types/pattern.ex

Lines changed: 43 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ defmodule Module.Types.Pattern do
2929
is refined, we restart at step 2.
3030
3131
"""
32-
def of_head(patterns, _guards, _expected, _tag, _meta, %{mode: :traversal}, context) do
32+
def of_head(patterns, _guards, expected, _tag, _meta, %{mode: :traversal}, context) do
3333
term = term()
34-
{Enum.map(patterns, &{&1, term}), context}
34+
{Enum.zip_with(patterns, expected, &{term, &2, &1}), context}
3535
end
3636

3737
def of_head(patterns, guards, expected, tag, meta, stack, context) do
@@ -45,11 +45,11 @@ defmodule Module.Types.Pattern do
4545
@doc """
4646
Computes the domain from the pattern tree and expected types.
4747
"""
48-
def of_domain([{_pattern, tree} | trees], [type | expected], context) do
49-
[intersection(of_pattern_tree(tree, context), type) | of_domain(trees, expected, context)]
48+
def of_domain([{tree, expected, _pattern} | trees], context) do
49+
[intersection(of_pattern_tree(tree, context), expected) | of_domain(trees, context)]
5050
end
5151

52-
def of_domain([], [], _context) do
52+
def of_domain([], _context) do
5353
[]
5454
end
5555

@@ -59,50 +59,27 @@ defmodule Module.Types.Pattern do
5959

6060
defp of_pattern_args(patterns, expected, tag, stack, context) do
6161
context = init_pattern_info(context)
62-
{trees, context} = of_pattern_args_index(patterns, 0, [], stack, context)
62+
{trees, context} = of_pattern_args_zip(patterns, expected, 0, [], stack, context)
6363
{pattern_info, context} = pop_pattern_info(context)
6464

6565
context =
66-
case of_pattern_args_tree(trees, expected, 0, [], tag, stack, context) do
67-
{:ok, types, context} ->
68-
of_pattern_recur(types, tag, pattern_info, stack, context)
69-
70-
{:error, context} ->
71-
error_vars(pattern_info, context)
66+
case of_pattern_intersect(trees, 0, [], pattern_info, tag, stack, context) do
67+
{:ok, _types, context} -> context
68+
{:error, context} -> context
7269
end
7370

7471
{trees, context}
7572
end
7673

77-
defp of_pattern_args_index([pattern | tail], index, acc, stack, context) do
74+
defp of_pattern_args_zip([pattern | tail], [expected | types], index, acc, stack, context) do
7875
{tree, context} = of_pattern(pattern, [%{root: {:arg, index}, expr: pattern}], stack, context)
79-
acc = [{pattern, tree} | acc]
80-
of_pattern_args_index(tail, index + 1, acc, stack, context)
76+
acc = [{tree, expected, pattern} | acc]
77+
of_pattern_args_zip(tail, types, index + 1, acc, stack, context)
8178
end
8279

83-
defp of_pattern_args_index([], _index, acc, _stack, context),
80+
defp of_pattern_args_zip([], _types, _index, acc, _stack, context),
8481
do: {Enum.reverse(acc), context}
8582

86-
defp of_pattern_args_tree(
87-
[{pattern, tree} | tail],
88-
[type | expected_types],
89-
index,
90-
acc,
91-
tag,
92-
stack,
93-
context
94-
) do
95-
with {:ok, type, context} <-
96-
of_pattern_intersect(tree, type, pattern, index, tag, stack, context) do
97-
acc = [type | acc]
98-
of_pattern_args_tree(tail, expected_types, index + 1, acc, tag, stack, context)
99-
end
100-
end
101-
102-
defp of_pattern_args_tree([], [], _index, acc, _tag, _stack, context) do
103-
{:ok, Enum.reverse(acc), context}
104-
end
105-
10683
@doc """
10784
Handles the match operator.
10885
"""
@@ -117,12 +94,14 @@ defmodule Module.Types.Pattern do
11794
{tree, context} = of_pattern(pattern, [%{root: {:arg, 0}, expr: expr}], stack, context)
11895
{pattern_info, context} = pop_pattern_info(context)
11996
{expected, context} = expected_fun.(of_pattern_tree(tree, context), context)
120-
tag = {:match, expected}
12197

122-
{[type], context} =
123-
of_single_pattern_recur(expected, tag, tree, pattern_info, expr, stack, context)
98+
args = [{tree, expected, expr}]
99+
tag = {:match, expected}
124100

125-
{type, context}
101+
case of_pattern_intersect(args, 0, [], pattern_info, tag, stack, context) do
102+
{:ok, [type], context} -> {type, context}
103+
{:error, context} -> {expected, context}
104+
end
126105
end
127106

128107
@doc """
@@ -138,26 +117,34 @@ defmodule Module.Types.Pattern do
138117
context = init_pattern_info(context)
139118
{tree, context} = of_pattern(pattern, [%{root: {:arg, 0}, expr: expr}], stack, context)
140119
{pattern_info, context} = pop_pattern_info(context)
120+
args = [{tree, expected, pattern}]
141121

142-
{_, context} =
143-
of_single_pattern_recur(expected, tag, tree, pattern_info, expr, stack, context)
122+
context =
123+
case of_pattern_intersect(args, 0, [], pattern_info, tag, stack, context) do
124+
{:ok, _types, context} -> context
125+
{:error, context} -> context
126+
end
144127

145128
{_, context} = Enum.map_reduce(guards, context, &of_guard(&1, @guard, &1, stack, &2))
146129
context
147130
end
148131

149-
defp of_single_pattern_recur(expected, tag, tree, pattern_info, expr, stack, context) do
150-
case of_pattern_intersect(tree, expected, expr, 0, tag, stack, context) do
151-
{:ok, type, context} ->
152-
{[type], of_pattern_recur([type], tag, pattern_info, stack, context)}
132+
defp of_pattern_intersect([head | tail], index, acc, pattern_info, tag, stack, context) do
133+
{tree, expected, pattern} = head
134+
actual = of_pattern_tree(tree, context)
135+
type = intersection(actual, expected)
153136

154-
{:error, context} ->
155-
{[expected], error_vars(pattern_info, context)}
137+
if empty?(type) do
138+
context = badpattern_error(pattern, index, tag, stack, context)
139+
{:error, error_vars(pattern_info, context)}
140+
else
141+
of_pattern_intersect(tail, index + 1, [type | acc], pattern_info, tag, stack, context)
156142
end
157143
end
158144

159-
defp of_pattern_recur(types, tag, pattern_info, stack, context) do
145+
defp of_pattern_intersect([], _index, acc, pattern_info, tag, stack, context) do
160146
{args_paths, vars_paths, vars_deps} = pattern_info
147+
types = Enum.reverse(acc)
161148

162149
try do
163150
args_paths
@@ -193,35 +180,35 @@ defmodule Module.Types.Pattern do
193180
{Map.merge(changed, Map.get(vars_deps, version, %{})), context}
194181
end)
195182
catch
196-
context -> error_vars(pattern_info, context)
183+
context -> {:error, error_vars(pattern_info, context)}
197184
else
198185
{changed, context} ->
199-
of_pattern_var_deps(changed, vars_paths, vars_deps, stack, context)
186+
{:ok, types, of_pattern_recur(changed, vars_paths, vars_deps, stack, context)}
200187
end
201188
end
202189

203-
defp of_pattern_var_deps(changed, _vars_paths, _vars_deps, _stack, context)
190+
defp of_pattern_recur(changed, _vars_paths, _vars_deps, _stack, context)
204191
when changed == %{} do
205192
context
206193
end
207194

208-
defp of_pattern_var_deps(previous_changed, vars_paths, vars_deps, stack, context) do
195+
defp of_pattern_recur(previous_changed, vars_paths, vars_deps, stack, context) do
209196
{changed, context} =
210197
previous_changed
211198
|> Map.keys()
212199
|> Enum.reduce({%{}, context}, fn version, {changed, context} ->
213-
{var_changed?, context} = of_pattern_var_dep(vars_paths, version, stack, context)
200+
{var_changed?, context} = of_pattern_recur_var(vars_paths, version, stack, context)
214201

215202
case var_changed? do
216203
false -> {changed, context}
217204
true -> {Map.merge(changed, Map.get(vars_deps, version, %{})), context}
218205
end
219206
end)
220207

221-
of_pattern_var_deps(changed, vars_paths, vars_deps, stack, context)
208+
of_pattern_recur(changed, vars_paths, vars_deps, stack, context)
222209
end
223210

224-
defp of_pattern_var_dep(vars_paths, version, stack, context) do
211+
defp of_pattern_recur_var(vars_paths, version, stack, context) do
225212
paths = Map.get(vars_paths, version, [])
226213

227214
case context.vars do
@@ -303,17 +290,6 @@ defmodule Module.Types.Pattern do
303290
end
304291
end
305292

306-
defp of_pattern_intersect(tree, expected, expr, index, tag, stack, context) do
307-
actual = of_pattern_tree(tree, context)
308-
type = intersection(actual, expected)
309-
310-
if empty?(type) do
311-
{:error, badpattern_error(expr, index, tag, stack, context)}
312-
else
313-
{:ok, type, context}
314-
end
315-
end
316-
317293
defp of_pattern_var([], type, _context) do
318294
{:ok, type}
319295
end

0 commit comments

Comments
 (0)