Skip to content

Commit 386c5d9

Browse files
committed
Infer types from length, map_size, tuple_size comparison checks
1 parent ad7559f commit 386c5d9

2 files changed

Lines changed: 211 additions & 31 deletions

File tree

lib/elixir/lib/module/types/apply.ex

Lines changed: 98 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -374,35 +374,43 @@ defmodule Module.Types.Apply do
374374
end
375375
end
376376

377-
defp do_remote(:erlang, name, [left, right], _expected, expr, stack, context, of_fun)
377+
defp do_remote(:erlang, name, [left, right], expected, expr, stack, context, of_fun)
378378
when name in [:>=, :"=<", :>, :<, :min, :max] do
379-
{left_type, context} = of_fun.(left, term(), expr, stack, context)
380-
{right_type, context} = of_fun.(right, term(), expr, stack, context)
379+
case sized_order(name, left, right, expected) do
380+
{arg, expected, return} ->
381+
{actual, context} = of_fun.(arg, expected, expr, stack, context)
382+
result = if compatible?(actual, expected), do: return, else: boolean()
383+
{result, context}
381384

382-
result =
383-
if name in [:min, :max] do
384-
union(left_type, right_type)
385-
else
386-
return(boolean(), [left_type, right_type], stack)
387-
end
385+
:none ->
386+
{left_type, context} = of_fun.(left, term(), expr, stack, context)
387+
{right_type, context} = of_fun.(right, term(), expr, stack, context)
388388

389-
if is_warning(stack) do
390-
common = intersection(left_type, right_type)
389+
result =
390+
if name in [:min, :max] do
391+
union(left_type, right_type)
392+
else
393+
return(boolean(), [left_type, right_type], stack)
394+
end
391395

392-
cond do
393-
empty?(common) and not (number_type?(left_type) and number_type?(right_type)) ->
394-
error = {:mismatched_comparison, left_type, right_type}
395-
remote_error(error, :erlang, name, 2, expr, stack, context)
396+
if is_warning(stack) do
397+
common = intersection(left_type, right_type)
396398

397-
match?({false, _}, map_fetch_key(dynamic(common), :__struct__)) ->
398-
error = {:struct_comparison, left_type, right_type}
399-
remote_error(error, :erlang, name, 2, expr, stack, context)
399+
cond do
400+
empty?(common) and not (number_type?(left_type) and number_type?(right_type)) ->
401+
error = {:mismatched_comparison, left_type, right_type}
402+
remote_error(error, :erlang, name, 2, expr, stack, context)
400403

401-
true ->
404+
match?({false, _}, map_fetch_key(dynamic(common), :__struct__)) ->
405+
error = {:struct_comparison, left_type, right_type}
406+
remote_error(error, :erlang, name, 2, expr, stack, context)
407+
408+
true ->
409+
{result, context}
410+
end
411+
else
402412
{result, context}
403-
end
404-
else
405-
{result, context}
413+
end
406414
end
407415
end
408416

@@ -540,6 +548,74 @@ defmodule Module.Types.Apply do
540548
end
541549
end
542550

551+
defp sized_order(name, left, right, expected) do
552+
if name in [:>=, :"=<", :>, :<] do
553+
case {left, right} do
554+
{{{:., _, [:erlang, fun]}, _, [arg]}, size}
555+
when is_integer(size) and size >= 0 and fun in [:length, :map_size, :tuple_size] ->
556+
case booleaness(expected) do
557+
:always_true -> sized_order(name, fun, size, arg, @atom_true)
558+
:always_false -> sized_order(invert_order(name), fun, size, arg, @atom_false)
559+
:undefined -> :none
560+
end
561+
562+
{size, {{:., _, [:erlang, fun]}, _, [arg]}}
563+
when is_integer(size) and size >= 0 and fun in [:length, :map_size, :tuple_size] ->
564+
case booleaness(expected) do
565+
:always_true -> sized_order(invert_order(name), fun, size, arg, @atom_true)
566+
:always_false -> sized_order(name, fun, size, arg, @atom_false)
567+
:undefined -> :none
568+
end
569+
570+
_ ->
571+
:none
572+
end
573+
else
574+
:none
575+
end
576+
end
577+
578+
defp sized_order(name, fun, size, arg, return) do
579+
case expected_order(fun, name, size) do
580+
:none -> :none
581+
expected -> {arg, expected, return}
582+
end
583+
end
584+
585+
defp expected_order(_, :<, 0), do: :none
586+
587+
defp expected_order(:tuple_size, :<, size),
588+
do: difference(open_tuple([]), open_tuple(List.duplicate(term(), size)))
589+
590+
defp expected_order(:tuple_size, :"=<", 0),
591+
do: tuple([])
592+
593+
defp expected_order(:tuple_size, :"=<", size),
594+
do: difference(open_tuple([]), open_tuple(List.duplicate(term(), size + 1)))
595+
596+
defp expected_order(:tuple_size, :>, size),
597+
do: open_tuple(List.duplicate(term(), size + 1))
598+
599+
defp expected_order(:tuple_size, :>=, size),
600+
do: open_tuple(List.duplicate(term(), size))
601+
602+
defp expected_order(:map_size, :<, 1), do: @empty_map
603+
defp expected_order(:map_size, :"=<", 0), do: @empty_map
604+
defp expected_order(:map_size, :>, _), do: @non_empty_map
605+
defp expected_order(:map_size, :>=, size) when size > 0, do: @non_empty_map
606+
607+
defp expected_order(:length, :<, 1), do: @empty_list
608+
defp expected_order(:length, :"=<", 0), do: @empty_list
609+
defp expected_order(:length, :>, _), do: @non_empty_list
610+
defp expected_order(:length, :>=, size) when size > 0, do: @non_empty_list
611+
612+
defp expected_order(_, _, _), do: :none
613+
614+
defp invert_order(:>=), do: :<
615+
defp invert_order(:"=<"), do: :>
616+
defp invert_order(:>), do: :"=<"
617+
defp invert_order(:<), do: :>=
618+
543619
@doc """
544620
Returns the domain of an unknown module.
545621

lib/elixir/test/elixir/module/types/pattern_test.exs

Lines changed: 113 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -773,18 +773,51 @@ defmodule Module.Types.PatternTest do
773773
end
774774
end
775775

776-
describe "equality in guards" do
777-
test "length" do
776+
describe "comparison in guards" do
777+
test "length equality" do
778778
assert typecheck!([x], length(x) != 0, x) == dynamic(non_empty_list(term()))
779779
assert typecheck!([x], not (length(x) != 0), x) == dynamic(empty_list())
780780

781781
assert typecheck!([x], 0 != length(x), x) == dynamic(non_empty_list(term()))
782782
assert typecheck!([x], not (0 != length(x)), x) == dynamic(empty_list())
783783
end
784784

785+
test "length ordered" do
786+
assert typecheck!([x], length(x) < 0, x) == dynamic(list(term()))
787+
assert typecheck!([x], length(x) >= 0, x) == dynamic(list(term()))
788+
assert typecheck!([x], length(x) <= 0, x) == dynamic(empty_list())
789+
790+
assert typecheck!([x], 0 <= length(x), x) == dynamic(non_empty_list(term()))
791+
assert typecheck!([x], 0 >= length(x), x) == dynamic(list(term()))
792+
assert typecheck!([x], 0 < length(x), x) == dynamic(list(term()))
793+
assert typecheck!([x], 0 > length(x), x) == dynamic(empty_list())
794+
795+
assert typecheck!([x], not (length(x) > 0), x) == dynamic(empty_list())
796+
assert typecheck!([x], not (length(x) < 0), x) == dynamic(list(term()))
797+
assert typecheck!([x], not (length(x) >= 0), x) == dynamic(list(term()))
798+
assert typecheck!([x], not (length(x) <= 0), x) == dynamic(non_empty_list(term()))
799+
800+
assert typecheck!([x], length(x) < 1, x) == dynamic(empty_list())
801+
802+
assert typecheck!([x], length(x) > 2, x) == dynamic(non_empty_list(term()))
803+
assert typecheck!([x], length(x) < 2, x) == dynamic(list(term()))
804+
assert typecheck!([x], length(x) >= 2, x) == dynamic(non_empty_list(term()))
805+
assert typecheck!([x], length(x) <= 2, x) == dynamic(list(term()))
806+
807+
assert typecheck!([x], 2 <= length(x), x) == dynamic(non_empty_list(term()))
808+
assert typecheck!([x], 2 >= length(x), x) == dynamic(list(term()))
809+
assert typecheck!([x], 2 < length(x), x) == dynamic(non_empty_list(term()))
810+
assert typecheck!([x], 2 > length(x), x) == dynamic(list(term()))
811+
812+
assert typecheck!([x], not (length(x) > 2), x) == dynamic(list(term()))
813+
assert typecheck!([x], not (length(x) < 2), x) == dynamic(non_empty_list(term()))
814+
assert typecheck!([x], not (length(x) >= 2), x) == dynamic(list(term()))
815+
assert typecheck!([x], not (length(x) <= 2), x) == dynamic(non_empty_list(term()))
816+
end
817+
785818
@non_empty_map difference(open_map(), empty_map())
786819

787-
test "map_size" do
820+
test "map_size equality" do
788821
assert typecheck!([x], map_size(x) == 0, x) == dynamic(empty_map())
789822
assert typecheck!([x], map_size(x) != 0, x) == dynamic(@non_empty_map)
790823
assert typecheck!([x], not (map_size(x) == 0), x) == dynamic(@non_empty_map)
@@ -796,10 +829,49 @@ defmodule Module.Types.PatternTest do
796829
assert typecheck!([x], not (0 != map_size(x)), x) == dynamic(empty_map())
797830
end
798831

832+
test "map_size ordered" do
833+
assert typecheck!([x], map_size(x) > 0, x) == dynamic(@non_empty_map)
834+
assert typecheck!([x], map_size(x) < 0, x) == dynamic(open_map())
835+
assert typecheck!([x], map_size(x) >= 0, x) == dynamic(open_map())
836+
assert typecheck!([x], map_size(x) <= 0, x) == dynamic(empty_map())
837+
838+
assert typecheck!([x], 0 <= map_size(x), x) == dynamic(@non_empty_map)
839+
assert typecheck!([x], 0 >= map_size(x), x) == dynamic(open_map())
840+
assert typecheck!([x], 0 < map_size(x), x) == dynamic(open_map())
841+
assert typecheck!([x], 0 > map_size(x), x) == dynamic(empty_map())
842+
843+
assert typecheck!([x], not (map_size(x) > 0), x) == dynamic(empty_map())
844+
assert typecheck!([x], not (map_size(x) < 0), x) == dynamic(open_map())
845+
assert typecheck!([x], not (map_size(x) >= 0), x) == dynamic(open_map())
846+
assert typecheck!([x], not (map_size(x) <= 0), x) == dynamic(@non_empty_map)
847+
848+
assert typecheck!([x], map_size(x) < 1, x) == dynamic(empty_map())
849+
850+
assert typecheck!([x], map_size(x) > 2, x) == dynamic(@non_empty_map)
851+
assert typecheck!([x], map_size(x) < 2, x) == dynamic(open_map())
852+
assert typecheck!([x], map_size(x) >= 2, x) == dynamic(@non_empty_map)
853+
assert typecheck!([x], map_size(x) <= 2, x) == dynamic(open_map())
854+
855+
assert typecheck!([x], 2 <= map_size(x), x) == dynamic(@non_empty_map)
856+
assert typecheck!([x], 2 >= map_size(x), x) == dynamic(open_map())
857+
assert typecheck!([x], 2 < map_size(x), x) == dynamic(@non_empty_map)
858+
assert typecheck!([x], 2 > map_size(x), x) == dynamic(open_map())
859+
860+
assert typecheck!([x], not (map_size(x) > 2), x) == dynamic(open_map())
861+
assert typecheck!([x], not (map_size(x) < 2), x) == dynamic(@non_empty_map)
862+
assert typecheck!([x], not (map_size(x) >= 2), x) == dynamic(open_map())
863+
assert typecheck!([x], not (map_size(x) <= 2), x) == dynamic(@non_empty_map)
864+
end
865+
799866
@non_empty_tuple difference(open_tuple([]), tuple([]))
800-
@non_empty_binary_tuple difference(open_tuple([]), tuple([term(), term()]))
867+
@non_binary_tuple difference(open_tuple([]), tuple([term(), term()]))
868+
869+
@open_binary_tuple open_tuple([term(), term()])
870+
@open_ternary_tuple open_tuple([term(), term(), term()])
871+
@non_open_binary_tuple difference(open_tuple([]), open_tuple([term(), term()]))
872+
@non_open_ternary_tuple difference(open_tuple([]), open_tuple([term(), term(), term()]))
801873

802-
test "tuple_size" do
874+
test "tuple_size equality" do
803875
assert typecheck!([x], tuple_size(x) == 0, x) == dynamic(tuple([]))
804876
assert typecheck!([x], tuple_size(x) != 0, x) == dynamic(@non_empty_tuple)
805877
assert typecheck!([x], not (tuple_size(x) == 0), x) == dynamic(@non_empty_tuple)
@@ -811,14 +883,46 @@ defmodule Module.Types.PatternTest do
811883
assert typecheck!([x], not (0 != tuple_size(x)), x) == dynamic(tuple([]))
812884

813885
assert typecheck!([x], tuple_size(x) == 2, x) == dynamic(tuple([term(), term()]))
814-
assert typecheck!([x], tuple_size(x) != 2, x) == dynamic(@non_empty_binary_tuple)
815-
assert typecheck!([x], not (tuple_size(x) == 2), x) == dynamic(@non_empty_binary_tuple)
886+
assert typecheck!([x], tuple_size(x) != 2, x) == dynamic(@non_binary_tuple)
887+
assert typecheck!([x], not (tuple_size(x) == 2), x) == dynamic(@non_binary_tuple)
816888
assert typecheck!([x], not (tuple_size(x) != 2), x) == dynamic(tuple([term(), term()]))
817889

818890
assert typecheck!([x], 2 == tuple_size(x), x) == dynamic(tuple([term(), term()]))
819-
assert typecheck!([x], 2 != tuple_size(x), x) == dynamic(@non_empty_binary_tuple)
820-
assert typecheck!([x], not (2 == tuple_size(x)), x) == dynamic(@non_empty_binary_tuple)
891+
assert typecheck!([x], 2 != tuple_size(x), x) == dynamic(@non_binary_tuple)
892+
assert typecheck!([x], not (2 == tuple_size(x)), x) == dynamic(@non_binary_tuple)
821893
assert typecheck!([x], not (2 != tuple_size(x)), x) == dynamic(tuple([term(), term()]))
822894
end
895+
896+
test "tuple_size ordered" do
897+
assert typecheck!([x], tuple_size(x) > 0, x) == dynamic(open_tuple([term()]))
898+
assert typecheck!([x], tuple_size(x) < 0, x) == dynamic(open_tuple([]))
899+
assert typecheck!([x], tuple_size(x) >= 0, x) == dynamic(open_tuple([]))
900+
assert typecheck!([x], tuple_size(x) <= 0, x) == dynamic(tuple([]))
901+
902+
assert typecheck!([x], 0 <= tuple_size(x), x) == dynamic(open_tuple([term()]))
903+
assert typecheck!([x], 0 >= tuple_size(x), x) == dynamic(open_tuple([]))
904+
assert typecheck!([x], 0 < tuple_size(x), x) == dynamic(open_tuple([]))
905+
assert typecheck!([x], 0 > tuple_size(x), x) == dynamic(tuple([]))
906+
907+
assert typecheck!([x], not (tuple_size(x) > 0), x) == dynamic(tuple([]))
908+
assert typecheck!([x], not (tuple_size(x) < 0), x) == dynamic(open_tuple([]))
909+
assert typecheck!([x], not (tuple_size(x) >= 0), x) == dynamic(open_tuple([]))
910+
assert typecheck!([x], not (tuple_size(x) <= 0), x) == dynamic(open_tuple([term()]))
911+
912+
assert typecheck!([x], tuple_size(x) > 2, x) == dynamic(@open_ternary_tuple)
913+
assert typecheck!([x], tuple_size(x) < 2, x) == dynamic(@non_open_binary_tuple)
914+
assert typecheck!([x], tuple_size(x) >= 2, x) == dynamic(@open_binary_tuple)
915+
assert typecheck!([x], tuple_size(x) <= 2, x) == dynamic(@non_open_ternary_tuple)
916+
917+
assert typecheck!([x], 2 <= tuple_size(x), x) == dynamic(@open_ternary_tuple)
918+
assert typecheck!([x], 2 >= tuple_size(x), x) == dynamic(@non_open_binary_tuple)
919+
assert typecheck!([x], 2 < tuple_size(x), x) == dynamic(@open_binary_tuple)
920+
assert typecheck!([x], 2 > tuple_size(x), x) == dynamic(@non_open_ternary_tuple)
921+
922+
assert typecheck!([x], not (tuple_size(x) > 2), x) == dynamic(@non_open_ternary_tuple)
923+
assert typecheck!([x], not (tuple_size(x) < 2), x) == dynamic(@open_binary_tuple)
924+
assert typecheck!([x], not (tuple_size(x) >= 2), x) == dynamic(@non_open_binary_tuple)
925+
assert typecheck!([x], not (tuple_size(x) <= 2), x) == dynamic(@open_ternary_tuple)
926+
end
823927
end
824928
end

0 commit comments

Comments
 (0)