diff --git a/pyrefly/lib/solver/subset.rs b/pyrefly/lib/solver/subset.rs index 20f62ecb5b..e899bd979c 100644 --- a/pyrefly/lib/solver/subset.rs +++ b/pyrefly/lib/solver/subset.rs @@ -149,6 +149,26 @@ impl SubsetWithSnapshotResult { } impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { + fn is_subset_overload_candidate( + &mut self, + got: &Type, + want: &Type, + ) -> Result<(), SubsetError> { + let vars = got + .collect_maybe_placeholder_vars() + .into_iter() + .chain(want.collect_maybe_placeholder_vars()) + .collect::>(); + match self.with_snapshot(&vars, |me| me.is_subset_eq(got, want)) { + SubsetWithSnapshotResult::Ok => Ok(()), + SubsetWithSnapshotResult::InstantiationErrors(snapshot) => { + self.solver.restore_vars(snapshot); + Ok(()) + } + SubsetWithSnapshotResult::Err(e) => Err(e), + } + } + /// Can a function with l_args be called as a function with u_args? fn is_subset_param_list( &mut self, @@ -1136,7 +1156,7 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { fn is_subset_overload(&mut self, overload: &Overload, want: &Type) -> Result<(), SubsetError> { if any(overload.signatures.iter(), |l| { - self.is_subset_eq(&l.as_type(), want) + self.is_subset_overload_candidate(&l.as_type(), want) }) .is_ok() { @@ -1188,7 +1208,7 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { ret.clone(), ))); any(overload.signatures.iter(), |l| { - self.is_subset_eq(&l.as_type(), &callable) + self.is_subset_overload_candidate(&l.as_type(), &callable) }) }) .is_ok() @@ -1640,8 +1660,15 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { metadata: _, }), ) => { - self.is_subset_params(&l.params, &u.params)?; - self.is_subset_eq(&l.ret, &u.ret) + let want_has_vars = + Type::Callable(Box::new(u.clone())).may_contain_placeholder_var(); + if want_has_vars { + self.is_subset_eq(&l.ret, &u.ret)?; + self.is_subset_params(&l.params, &u.params) + } else { + self.is_subset_params(&l.params, &u.params)?; + self.is_subset_eq(&l.ret, &u.ret) + } } (Type::TypedDict(TypedDict::Anonymous(got)), Type::TypedDict(want)) => { self.is_subset_anonymous_typed_dict(got, want) diff --git a/pyrefly/lib/test/overload.rs b/pyrefly/lib/test/overload.rs index b44b6a8381..8bdd4e46e6 100644 --- a/pyrefly/lib/test/overload.rs +++ b/pyrefly/lib/test/overload.rs @@ -441,6 +441,23 @@ bar: Callable[[int, bytes | str], int] = foo "#, ); +testcase!( + test_bound_method_overload_assignable_to_callable, + r#" +from typing import Callable, assert_type +import itertools + +_FORMAT_PATH = "{}/{}".format +_DIRECT: Callable[[str, str], str] = _FORMAT_PATH + +def get_paths(path: str) -> list[str]: + parts = path.split("/") + result = list(itertools.accumulate(parts, _FORMAT_PATH)) + assert_type(result, list[str]) + return result + "#, +); + testcase!( test_overload_assignable_to_callable_return_supertype, r#"