Skip to content

Commit 9c157ae

Browse files
fix
1 parent b430a08 commit 9c157ae

3 files changed

Lines changed: 118 additions & 26 deletions

File tree

pyrefly/lib/alt/function.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
375375
defs,
376376
errors,
377377
);
378+
let sigs = if let [impl_sig] = def.ty.callable_signatures().as_slice() {
379+
self.normalize_async_generator_overloads(sigs, impl_sig)
380+
} else {
381+
sigs
382+
};
378383
self.check_signature_consistency(&sigs, &def, errors);
379384
Type::Overload(Overload {
380385
signatures: sigs.mapped(|(_, sig)| sig),
@@ -1704,6 +1709,49 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
17041709
metadata
17051710
}
17061711

1712+
fn normalize_async_generator_overload_signature(&self, sig: &Callable) -> Callable {
1713+
if let Some((_, _, ret)) = self.unwrap_coroutine(&sig.ret)
1714+
&& self.decompose_async_generator(&ret).is_some()
1715+
{
1716+
let mut sig = sig.clone();
1717+
sig.ret = ret;
1718+
sig
1719+
} else {
1720+
sig.clone()
1721+
}
1722+
}
1723+
1724+
fn normalize_async_generator_overloads(
1725+
&self,
1726+
overloads: Vec1<(TextRange, OverloadType)>,
1727+
impl_sig: &Callable,
1728+
) -> Vec1<(TextRange, OverloadType)> {
1729+
if self.decompose_async_generator(&impl_sig.ret).is_none() {
1730+
return overloads;
1731+
}
1732+
overloads.mapped(|(range, overload)| {
1733+
(
1734+
range,
1735+
match overload {
1736+
OverloadType::Function(func) => OverloadType::Function(Function {
1737+
signature: self
1738+
.normalize_async_generator_overload_signature(&func.signature),
1739+
metadata: func.metadata,
1740+
}),
1741+
OverloadType::Forall(forall) => OverloadType::Forall(Forall {
1742+
tparams: forall.tparams,
1743+
body: Function {
1744+
signature: self.normalize_async_generator_overload_signature(
1745+
&forall.body.signature,
1746+
),
1747+
metadata: forall.body.metadata,
1748+
},
1749+
}),
1750+
},
1751+
)
1752+
})
1753+
}
1754+
17071755
/// Substitute each type parameter with its upper bound, for overload consistency checking.
17081756
fn subst_function(&self, tparams: &TParams, func: Function) -> Function {
17091757
let mp = tparams

pyrefly/lib/alt/unwrap.rs

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -375,16 +375,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
375375

376376
pub fn decompose_generator(&self, ty: &Type) -> Option<(Type, Type, Type)> {
377377
match ty {
378-
Type::Union(u) => {
379-
let mut yield_tys = Vec::new();
380-
let mut send_tys = Vec::new();
381-
let mut return_tys = Vec::new();
382-
for member in &u.members {
383-
let (y, s, r) = self.decompose_generator(member)?;
384-
yield_tys.push(y);
385-
send_tys.push(s);
386-
return_tys.push(r);
387-
}
378+
Type::Union(box Union { members, .. }) => {
379+
let results: Option<Vec<_>> = members
380+
.iter()
381+
.map(|member| self.decompose_generator(member))
382+
.collect();
383+
let (yield_tys, send_tys, return_tys) = results?.into_iter().fold(
384+
(Vec::new(), Vec::new(), Vec::new()),
385+
|mut acc, (yield_ty, send_ty, return_ty)| {
386+
acc.0.push(yield_ty);
387+
acc.1.push(send_ty);
388+
acc.2.push(return_ty);
389+
acc
390+
},
391+
);
388392
Some((
389393
self.unions(yield_tys),
390394
self.unions(send_tys),
@@ -417,22 +421,34 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
417421
}
418422

419423
pub fn decompose_async_generator(&self, ty: &Type) -> Option<(Type, Type)> {
420-
let yield_ty = self.fresh_var();
421-
let send_ty = self.fresh_var();
422-
let async_generator_ty = self.heap.mk_class_type(
423-
self.stdlib
424-
.async_generator(yield_ty.to_type(self.heap), send_ty.to_type(self.heap)),
425-
);
426-
if self.is_subset_eq(&async_generator_ty, ty) {
427-
let yield_ty: Type = self.resolve_var_opt(ty, yield_ty)?;
428-
let send_ty = self
429-
.resolve_var_opt(ty, send_ty)
430-
.unwrap_or_else(|| self.heap.mk_none());
431-
Some((yield_ty, send_ty))
432-
} else if ty.is_any() {
433-
Some((self.heap.mk_any_explicit(), self.heap.mk_any_explicit()))
434-
} else {
435-
None
424+
match ty {
425+
Type::Union(box Union { members, .. }) => {
426+
let results: Option<Vec<_>> = members
427+
.iter()
428+
.map(|member| self.decompose_async_generator(member))
429+
.collect();
430+
let (yield_tys, send_tys) = results?.into_iter().unzip();
431+
Some((self.unions(yield_tys), self.unions(send_tys)))
432+
}
433+
_ => {
434+
let yield_ty = self.fresh_var();
435+
let send_ty = self.fresh_var();
436+
let async_generator_ty = self.heap.mk_class_type(
437+
self.stdlib
438+
.async_generator(yield_ty.to_type(self.heap), send_ty.to_type(self.heap)),
439+
);
440+
if self.is_subset_eq(&async_generator_ty, ty) {
441+
let yield_ty: Type = self.resolve_var_opt(ty, yield_ty)?;
442+
let send_ty = self
443+
.resolve_var_opt(ty, send_ty)
444+
.unwrap_or_else(|| self.heap.mk_none());
445+
Some((yield_ty, send_ty))
446+
} else if ty.is_any() {
447+
Some((self.heap.mk_any_explicit(), self.heap.mk_any_explicit()))
448+
} else {
449+
None
450+
}
451+
}
436452
}
437453
}
438454

pyrefly/lib/test/overload.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,34 @@ def f[T](x: T) -> T:
834834
"#,
835835
);
836836

837+
testcase!(
838+
test_overloaded_async_iterator_impl,
839+
r#"
840+
from collections.abc import AsyncIterator
841+
from typing import Literal, assert_type, overload
842+
843+
class Patch:
844+
pass
845+
846+
class FullState(Patch):
847+
pass
848+
849+
@overload
850+
async def astream(*, diff: Literal[True] = ...) -> AsyncIterator[Patch]: ...
851+
@overload
852+
async def astream(*, diff: Literal[False]) -> AsyncIterator[FullState]: ...
853+
async def astream(*, diff: bool = True) -> AsyncIterator[Patch] | AsyncIterator[FullState]:
854+
if diff:
855+
yield Patch()
856+
else:
857+
yield FullState()
858+
859+
assert_type(astream(), AsyncIterator[Patch])
860+
assert_type(astream(diff=True), AsyncIterator[Patch])
861+
assert_type(astream(diff=False), AsyncIterator[FullState])
862+
"#,
863+
);
864+
837865
testcase!(
838866
test_param_consistency,
839867
r#"

0 commit comments

Comments
 (0)