Skip to content

Commit ad7559f

Browse files
committed
Infer types from length, map_size, tuple_size equality checks
1 parent 9815517 commit ad7559f

3 files changed

Lines changed: 162 additions & 25 deletions

File tree

lib/elixir/lib/module/types/apply.ex

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ defmodule Module.Types.Apply do
1515
# reduce the computation cost of inferred code.
1616
@max_clauses 16
1717

18+
@atom_true atom([true])
19+
@atom_false atom([false])
20+
1821
## Signatures
1922

2023
# Define strong arrows found in the standard library.
@@ -359,25 +362,15 @@ defmodule Module.Types.Apply do
359362

360363
# The functions implemented with custom do_remote functions work on values,
361364
# rather on types, hence the custom behaviour.
362-
defp do_remote(:erlang, name, [left, right] = args, _expected, expr, stack, context, of_fun)
365+
defp do_remote(:erlang, name, [left, right], expected, expr, stack, context, of_fun)
363366
when name in [:==, :"/=", :"=:=", :"=/="] do
364-
{left_type, context} = of_fun.(left, term(), expr, stack, context)
365-
{right_type, context} = of_fun.(right, term(), expr, stack, context)
366-
result = return(boolean(), [left_type, right_type], stack)
367-
368-
cond do
369-
not is_warning(stack) or Macro.quoted_literal?(args) ->
370-
{result, context}
371-
372-
name in [:==, :"/="] and number_type?(left_type) and number_type?(right_type) ->
373-
{result, context}
367+
left_literal? = Macro.quoted_literal?(left)
368+
right_literal? = Macro.quoted_literal?(right)
374369

375-
disjoint?(left_type, right_type) ->
376-
error = {:mismatched_comparison, left_type, right_type}
377-
remote_error(error, :erlang, name, 2, expr, stack, context)
378-
379-
true ->
380-
{result, context}
370+
case {left_literal?, right_literal?} do
371+
{true, false} -> custom_compare(name, right, left, expected, expr, stack, context, of_fun)
372+
{false, true} -> custom_compare(name, left, right, expected, expr, stack, context, of_fun)
373+
{literal?, _} -> compare(name, left, right, literal?, expr, stack, context, of_fun)
381374
end
382375
end
383376

@@ -479,6 +472,74 @@ defmodule Module.Types.Apply do
479472
remote_domain(mod, fun, args, expected, elem(expr, 1), stack, context)
480473
end
481474

475+
@empty_list empty_list()
476+
@non_empty_list non_empty_list(term())
477+
@empty_map empty_map()
478+
@non_empty_map difference(open_map(), empty_map())
479+
480+
defp custom_compare(
481+
name,
482+
{{:., _, [:erlang, fun]}, _, [arg]} = left,
483+
literal,
484+
expected,
485+
expr,
486+
stack,
487+
context,
488+
of_fun
489+
)
490+
when fun in [:length, :map_size, :tuple_size] and is_integer(literal) and literal >= 0 do
491+
case booleaness(expected) do
492+
:undefined ->
493+
compare(name, left, literal, false, expr, stack, context, of_fun)
494+
495+
boolean ->
496+
{polarity, return} =
497+
case boolean do
498+
:always_true -> {name in [:==, :"=:="], @atom_true}
499+
:always_false -> {name in [:"/=", :"=/="], @atom_false}
500+
end
501+
502+
expected =
503+
case fun do
504+
:length when :erlang.xor(polarity, literal > 0) -> @empty_list
505+
:length -> @non_empty_list
506+
:map_size when :erlang.xor(polarity, literal > 0) -> @empty_map
507+
:map_size -> @non_empty_map
508+
:tuple_size when polarity -> tuple(List.duplicate(term(), literal))
509+
:tuple_size -> difference(open_tuple([]), tuple(List.duplicate(term(), literal)))
510+
end
511+
512+
{actual, context} = of_fun.(arg, expected, expr, stack, context)
513+
result = if compatible?(actual, expected), do: return, else: boolean()
514+
{result, context}
515+
end
516+
end
517+
518+
defp custom_compare(name, left, right, _expected, expr, stack, context, of_fun) do
519+
compare(name, left, right, false, expr, stack, context, of_fun)
520+
end
521+
522+
defp compare(name, left, right, literal?, expr, stack, context, of_fun) do
523+
{left_type, context} = of_fun.(left, term(), expr, stack, context)
524+
{right_type, context} = of_fun.(right, term(), expr, stack, context)
525+
result = return(boolean(), [left_type, right_type], stack)
526+
527+
cond do
528+
literal? or not is_warning(stack) ->
529+
{result, context}
530+
531+
name in [:==, :"/="] and number_type?(left_type) and number_type?(right_type) ->
532+
{result, context}
533+
534+
disjoint?(left_type, right_type) ->
535+
error = {:mismatched_comparison, left_type, right_type}
536+
remote_error(error, :erlang, name, 2, expr, stack, context)
537+
538+
true ->
539+
{result, context}
540+
end
541+
end
542+
482543
@doc """
483544
Returns the domain of an unknown module.
484545

lib/elixir/lib/module/types/descr.ex

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,9 @@ defmodule Module.Types.Descr do
919919
:sets.from_list([false], version: 2)
920920
]
921921

922+
@false_atoms :sets.from_list([false], version: 2)
923+
@true_atoms :sets.from_list([true], version: 2)
924+
922925
@doc """
923926
Returns true if the type can never be true.
924927
"""
@@ -937,6 +940,30 @@ defmodule Module.Types.Descr do
937940
end
938941
end
939942

943+
@doc """
944+
Compute the booleaness of an element.
945+
946+
It is either :undefined, :always_true, or :always_false.
947+
"""
948+
def booleaness(:term), do: :undefined
949+
950+
def booleaness(%{} = descr) do
951+
descr = Map.get(descr, :dynamic, descr)
952+
953+
case descr do
954+
%{atom: {:union, set}}
955+
when map_size(descr) == 1 and set == @false_atoms ->
956+
:always_false
957+
958+
%{atom: {:union, set}}
959+
when map_size(descr) == 1 and set == @true_atoms ->
960+
:always_true
961+
962+
_ ->
963+
:undefined
964+
end
965+
end
966+
940967
@doc """
941968
Compute the truthiness of an element.
942969

lib/elixir/test/elixir/module/types/pattern_test.exs

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -723,34 +723,34 @@ defmodule Module.Types.PatternTest do
723723

724724
test "domain checks" do
725725
# Regular domain check
726-
assert typecheck!([x], length(x) == 3, x) == dynamic(list(term()))
726+
assert typecheck!([x, z], length(x) == z, x) == dynamic(list(term()))
727727

728728
# erlang-or propagates
729-
assert typecheck!([x, y], :erlang.or(length(x) == 3, map_size(y) == 1), {x, y}) ==
729+
assert typecheck!([x, y, z], :erlang.or(length(x) == z, map_size(y) == z), {x, y}) ==
730730
dynamic(tuple([list(term()), open_map()]))
731731

732732
# erlang-and propagates
733-
assert typecheck!([x, y], :erlang.and(length(x) == 3, map_size(y) == 1), {x, y}) ==
733+
assert typecheck!([x, y, z], :erlang.and(length(x) == z, map_size(y) == z), {x, y}) ==
734734
dynamic(tuple([list(term()), open_map()]))
735735

736736
# or with mixed checks
737-
assert typecheck!([x], length(x) == 3 or is_map(x), x) ==
737+
assert typecheck!([x, z], length(x) == z or is_map(x), x) ==
738738
dynamic(list(term()))
739739

740740
# or does not propagate
741-
assert typecheck!([x, y], length(x) == 3 or map_size(y) == 1, {x, y}) ==
741+
assert typecheck!([x, y, z], length(x) == z or map_size(y) == z, {x, y}) ==
742742
dynamic(tuple([list(term()), term()]))
743743

744744
# and propagates
745-
assert typecheck!([x, y], length(x) == 3 and map_size(y) == 1, {x, y}) ==
745+
assert typecheck!([x, y, z], length(x) == z and map_size(y) == z, {x, y}) ==
746746
dynamic(tuple([list(term()), open_map()]))
747747

748748
# not or does propagate
749-
assert typecheck!([x, y], not (length(x) == 3 or map_size(y) == 1), {x, y}) ==
749+
assert typecheck!([x, y, z], not (length(x) == z or map_size(y) == z), {x, y}) ==
750750
dynamic(tuple([list(term()), open_map()]))
751751

752752
# not and does not propagate
753-
assert typecheck!([x, y], not (length(x) == 3 and map_size(y) == 1), {x, y}) ==
753+
assert typecheck!([x, y, z], not (length(x) == z and map_size(y) == z), {x, y}) ==
754754
dynamic(tuple([list(term()), term()]))
755755
end
756756

@@ -772,4 +772,53 @@ defmodule Module.Types.PatternTest do
772772
"""
773773
end
774774
end
775+
776+
describe "equality in guards" do
777+
test "length" do
778+
assert typecheck!([x], length(x) != 0, x) == dynamic(non_empty_list(term()))
779+
assert typecheck!([x], not (length(x) != 0), x) == dynamic(empty_list())
780+
781+
assert typecheck!([x], 0 != length(x), x) == dynamic(non_empty_list(term()))
782+
assert typecheck!([x], not (0 != length(x)), x) == dynamic(empty_list())
783+
end
784+
785+
@non_empty_map difference(open_map(), empty_map())
786+
787+
test "map_size" do
788+
assert typecheck!([x], map_size(x) == 0, x) == dynamic(empty_map())
789+
assert typecheck!([x], map_size(x) != 0, x) == dynamic(@non_empty_map)
790+
assert typecheck!([x], not (map_size(x) == 0), x) == dynamic(@non_empty_map)
791+
assert typecheck!([x], not (map_size(x) != 0), x) == dynamic(empty_map())
792+
793+
assert typecheck!([x], 0 == map_size(x), x) == dynamic(empty_map())
794+
assert typecheck!([x], 0 != map_size(x), x) == dynamic(@non_empty_map)
795+
assert typecheck!([x], not (0 == map_size(x)), x) == dynamic(@non_empty_map)
796+
assert typecheck!([x], not (0 != map_size(x)), x) == dynamic(empty_map())
797+
end
798+
799+
@non_empty_tuple difference(open_tuple([]), tuple([]))
800+
@non_empty_binary_tuple difference(open_tuple([]), tuple([term(), term()]))
801+
802+
test "tuple_size" do
803+
assert typecheck!([x], tuple_size(x) == 0, x) == dynamic(tuple([]))
804+
assert typecheck!([x], tuple_size(x) != 0, x) == dynamic(@non_empty_tuple)
805+
assert typecheck!([x], not (tuple_size(x) == 0), x) == dynamic(@non_empty_tuple)
806+
assert typecheck!([x], not (tuple_size(x) != 0), x) == dynamic(tuple([]))
807+
808+
assert typecheck!([x], 0 == tuple_size(x), x) == dynamic(tuple([]))
809+
assert typecheck!([x], 0 != tuple_size(x), x) == dynamic(@non_empty_tuple)
810+
assert typecheck!([x], not (0 == tuple_size(x)), x) == dynamic(@non_empty_tuple)
811+
assert typecheck!([x], not (0 != tuple_size(x)), x) == dynamic(tuple([]))
812+
813+
assert typecheck!([x], tuple_size(x) == 2, x) == dynamic(tuple([term(), term()]))
814+
assert typecheck!([x], tuple_size(x) != 2, x) == dynamic(@non_empty_binary_tuple)
815+
assert typecheck!([x], not (tuple_size(x) == 2), x) == dynamic(@non_empty_binary_tuple)
816+
assert typecheck!([x], not (tuple_size(x) != 2), x) == dynamic(tuple([term(), term()]))
817+
818+
assert typecheck!([x], 2 == tuple_size(x), x) == dynamic(tuple([term(), term()]))
819+
assert typecheck!([x], 2 != tuple_size(x), x) == dynamic(@non_empty_binary_tuple)
820+
assert typecheck!([x], not (2 == tuple_size(x)), x) == dynamic(@non_empty_binary_tuple)
821+
assert typecheck!([x], not (2 != tuple_size(x)), x) == dynamic(tuple([term(), term()]))
822+
end
823+
end
775824
end

0 commit comments

Comments
 (0)