Skip to content

Commit 118492e

Browse files
committed
Add type inference for literal equality in guards
1 parent bc4f917 commit 118492e

2 files changed

Lines changed: 94 additions & 2 deletions

File tree

lib/elixir/lib/module/types/apply.ex

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ defmodule Module.Types.Apply do
477477
remote_domain(mod, fun, args, expected, elem(expr, 1), stack, context)
478478
end
479479

480+
@number union(integer(), float())
480481
@empty_list empty_list()
481482
@non_empty_list non_empty_list(term())
482483
@empty_map empty_map()
@@ -526,8 +527,41 @@ defmodule Module.Types.Apply do
526527
end
527528
end
528529

529-
defp custom_compare(name, left, right, _expected, expr, stack, context, of_fun) do
530-
compare(name, left, right, false, expr, stack, context, of_fun)
530+
defp custom_compare(name, arg, literal, expected, expr, stack, context, of_fun) do
531+
{literal_type, context} = of_fun.(literal, term(), expr, stack, context)
532+
533+
case booleaness(expected) do
534+
booleaness when booleaness in [:maybe_both, :none] ->
535+
compare(name, arg, literal, false, expr, stack, context, of_fun)
536+
537+
booleaness ->
538+
{polarity, return} =
539+
case booleaness do
540+
:maybe_true -> {name in [:==, :"=:="], @atom_true}
541+
:maybe_false -> {name in [:"/=", :"=/="], @atom_false}
542+
end
543+
544+
# If it is a singleton, we can always be precise
545+
if singleton?(literal_type) do
546+
expected = if polarity, do: literal_type, else: negation(literal_type)
547+
{actual, context} = of_fun.(arg, expected, expr, stack, context)
548+
result = if compatible?(actual, expected), do: return, else: boolean()
549+
{result, context}
550+
else
551+
expected =
552+
cond do
553+
# We are checking for `not x == 1` or similar, we can't say anything about x
554+
polarity == false -> term()
555+
# We are checking for `x == 1`, make sure x is integer or float
556+
number_type?(literal_type) and name in [:==, :"/="] -> union(literal_type, @number)
557+
# Otherwise we have the literal type as is
558+
true -> literal_type
559+
end
560+
561+
{_, context} = of_fun.(arg, expected, expr, stack, context)
562+
{boolean(), context}
563+
end
564+
end
531565
end
532566

533567
defp compare(name, left, right, literal?, expr, stack, context, of_fun) do

lib/elixir/test/elixir/module/types/pattern_test.exs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,64 @@ defmodule Module.Types.PatternTest do
818818
end
819819
end
820820

821+
describe "equality in guards" do
822+
test "with non-singleton literals" do
823+
assert typecheck!([x], x == "foo", x) == dynamic(binary())
824+
assert typecheck!([x], x === "foo", x) == dynamic(binary())
825+
assert typecheck!([x], not (x == "foo"), x) == dynamic()
826+
assert typecheck!([x], not (x === "foo"), x) == dynamic()
827+
828+
assert typecheck!([x], x != "foo", x) == dynamic()
829+
assert typecheck!([x], x !== "foo", x) == dynamic()
830+
assert typecheck!([x], not (x != "foo"), x) == dynamic(binary())
831+
assert typecheck!([x], not (x !== "foo"), x) == dynamic(binary())
832+
end
833+
834+
test "with number literals" do
835+
assert typecheck!([x], x == 1, x) == dynamic(union(integer(), float()))
836+
assert typecheck!([x], x === 1, x) == dynamic(integer())
837+
assert typecheck!([x], not (x == 1), x) == dynamic()
838+
assert typecheck!([x], not (x === 1), x) == dynamic()
839+
840+
assert typecheck!([x], x != 1, x) == dynamic()
841+
assert typecheck!([x], x !== 1, x) == dynamic()
842+
assert typecheck!([x], not (x != 1), x) == dynamic(union(integer(), float()))
843+
assert typecheck!([x], not (x !== 1), x) == dynamic(integer())
844+
845+
assert typecheck!([x], x == 1.0, x) == dynamic(union(integer(), float()))
846+
assert typecheck!([x], x === 1.0, x) == dynamic(float())
847+
assert typecheck!([x], not (x == 1.0), x) == dynamic()
848+
assert typecheck!([x], not (x === 1.0), x) == dynamic()
849+
850+
assert typecheck!([x], x != 1.0, x) == dynamic()
851+
assert typecheck!([x], x !== 1.0, x) == dynamic()
852+
assert typecheck!([x], not (x != 1.0), x) == dynamic(union(integer(), float()))
853+
assert typecheck!([x], not (x !== 1.0), x) == dynamic(float())
854+
end
855+
856+
test "with singleton literals" do
857+
assert typecheck!([x], x == :foo, x) == dynamic(atom([:foo]))
858+
assert typecheck!([x], x === :foo, x) == dynamic(atom([:foo]))
859+
assert typecheck!([x], not (x == :foo), x) == dynamic(negation(atom([:foo])))
860+
assert typecheck!([x], not (x === :foo), x) == dynamic(negation(atom([:foo])))
861+
862+
assert typecheck!([x], x != :foo, x) == dynamic(negation(atom([:foo])))
863+
assert typecheck!([x], x !== :foo, x) == dynamic(negation(atom([:foo])))
864+
assert typecheck!([x], not (x != :foo), x) == dynamic(atom([:foo]))
865+
assert typecheck!([x], not (x !== :foo), x) == dynamic(atom([:foo]))
866+
867+
assert typecheck!([x], x == [], x) == dynamic(empty_list())
868+
assert typecheck!([x], x === [], x) == dynamic(empty_list())
869+
assert typecheck!([x], not (x == []), x) == dynamic(negation(empty_list()))
870+
assert typecheck!([x], not (x === []), x) == dynamic(negation(empty_list()))
871+
872+
assert typecheck!([x], x != [], x) == dynamic(negation(empty_list()))
873+
assert typecheck!([x], x !== [], x) == dynamic(negation(empty_list()))
874+
assert typecheck!([x], not (x != []), x) == dynamic(empty_list())
875+
assert typecheck!([x], not (x !== []), x) == dynamic(empty_list())
876+
end
877+
end
878+
821879
describe "comparison in guards" do
822880
test "length equality" do
823881
assert typecheck!([x], length(x) != 0, x) == dynamic(non_empty_list(term()))

0 commit comments

Comments
 (0)