Skip to content

Commit 85f0a3f

Browse files
fix
1 parent 947a3ee commit 85f0a3f

3 files changed

Lines changed: 118 additions & 26 deletions

File tree

pyrefly/lib/alt/function.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
377377
defs,
378378
errors,
379379
);
380+
let sigs = if let [impl_sig] = def.ty.callable_signatures().as_slice() {
381+
self.normalize_async_generator_overloads(sigs, impl_sig)
382+
} else {
383+
sigs
384+
};
380385
self.check_signature_consistency(&sigs, &def, errors);
381386
Type::Overload(Overload {
382387
signatures: sigs.mapped(|(_, sig)| sig),
@@ -1764,6 +1769,49 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
17641769
metadata
17651770
}
17661771

1772+
fn normalize_async_generator_overload_signature(&self, sig: &Callable) -> Callable {
1773+
if let Some((_, _, ret)) = self.unwrap_coroutine(&sig.ret)
1774+
&& self.decompose_async_generator(&ret).is_some()
1775+
{
1776+
let mut sig = sig.clone();
1777+
sig.ret = ret;
1778+
sig
1779+
} else {
1780+
sig.clone()
1781+
}
1782+
}
1783+
1784+
fn normalize_async_generator_overloads(
1785+
&self,
1786+
overloads: Vec1<(TextRange, OverloadType)>,
1787+
impl_sig: &Callable,
1788+
) -> Vec1<(TextRange, OverloadType)> {
1789+
if self.decompose_async_generator(&impl_sig.ret).is_none() {
1790+
return overloads;
1791+
}
1792+
overloads.mapped(|(range, overload)| {
1793+
(
1794+
range,
1795+
match overload {
1796+
OverloadType::Function(func) => OverloadType::Function(Function {
1797+
signature: self
1798+
.normalize_async_generator_overload_signature(&func.signature),
1799+
metadata: func.metadata,
1800+
}),
1801+
OverloadType::Forall(forall) => OverloadType::Forall(Forall {
1802+
tparams: forall.tparams,
1803+
body: Function {
1804+
signature: self.normalize_async_generator_overload_signature(
1805+
&forall.body.signature,
1806+
),
1807+
metadata: forall.body.metadata,
1808+
},
1809+
}),
1810+
},
1811+
)
1812+
})
1813+
}
1814+
17671815
/// Substitute each type parameter with its upper bound, for overload consistency checking.
17681816
fn subst_function(&self, tparams: &TParams, func: Function) -> Function {
17691817
let mp = tparams

pyrefly/lib/alt/unwrap.rs

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -360,16 +360,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
360360

361361
pub fn decompose_generator(&self, ty: &Type) -> Option<(Type, Type, Type)> {
362362
match ty {
363-
Type::Union(u) => {
364-
let mut yield_tys = Vec::new();
365-
let mut send_tys = Vec::new();
366-
let mut return_tys = Vec::new();
367-
for member in &u.members {
368-
let (y, s, r) = self.decompose_generator(member)?;
369-
yield_tys.push(y);
370-
send_tys.push(s);
371-
return_tys.push(r);
372-
}
363+
Type::Union(box Union { members, .. }) => {
364+
let results: Option<Vec<_>> = members
365+
.iter()
366+
.map(|member| self.decompose_generator(member))
367+
.collect();
368+
let (yield_tys, send_tys, return_tys) = results?.into_iter().fold(
369+
(Vec::new(), Vec::new(), Vec::new()),
370+
|mut acc, (yield_ty, send_ty, return_ty)| {
371+
acc.0.push(yield_ty);
372+
acc.1.push(send_ty);
373+
acc.2.push(return_ty);
374+
acc
375+
},
376+
);
373377
Some((
374378
self.unions(yield_tys),
375379
self.unions(send_tys),
@@ -402,22 +406,34 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
402406
}
403407

404408
pub fn decompose_async_generator(&self, ty: &Type) -> Option<(Type, Type)> {
405-
let yield_ty = self.fresh_var();
406-
let send_ty = self.fresh_var();
407-
let async_generator_ty = self.heap.mk_class_type(
408-
self.stdlib
409-
.async_generator(yield_ty.to_type(self.heap), send_ty.to_type(self.heap)),
410-
);
411-
if self.is_subset_eq(&async_generator_ty, ty) {
412-
let yield_ty: Type = self.resolve_var_opt(ty, yield_ty)?;
413-
let send_ty = self
414-
.resolve_var_opt(ty, send_ty)
415-
.unwrap_or_else(|| self.heap.mk_none());
416-
Some((yield_ty, send_ty))
417-
} else if ty.is_any() {
418-
Some((self.heap.mk_any_explicit(), self.heap.mk_any_explicit()))
419-
} else {
420-
None
409+
match ty {
410+
Type::Union(box Union { members, .. }) => {
411+
let results: Option<Vec<_>> = members
412+
.iter()
413+
.map(|member| self.decompose_async_generator(member))
414+
.collect();
415+
let (yield_tys, send_tys) = results?.into_iter().unzip();
416+
Some((self.unions(yield_tys), self.unions(send_tys)))
417+
}
418+
_ => {
419+
let yield_ty = self.fresh_var();
420+
let send_ty = self.fresh_var();
421+
let async_generator_ty = self.heap.mk_class_type(
422+
self.stdlib
423+
.async_generator(yield_ty.to_type(self.heap), send_ty.to_type(self.heap)),
424+
);
425+
if self.is_subset_eq(&async_generator_ty, ty) {
426+
let yield_ty: Type = self.resolve_var_opt(ty, yield_ty)?;
427+
let send_ty = self
428+
.resolve_var_opt(ty, send_ty)
429+
.unwrap_or_else(|| self.heap.mk_none());
430+
Some((yield_ty, send_ty))
431+
} else if ty.is_any() {
432+
Some((self.heap.mk_any_explicit(), self.heap.mk_any_explicit()))
433+
} else {
434+
None
435+
}
436+
}
421437
}
422438
}
423439

pyrefly/lib/test/overload.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,34 @@ def f[T](x: T) -> T:
834834
"#,
835835
);
836836

837+
testcase!(
838+
test_overloaded_async_iterator_impl,
839+
r#"
840+
from collections.abc import AsyncIterator
841+
from typing import Literal, assert_type, overload
842+
843+
class Patch:
844+
pass
845+
846+
class FullState(Patch):
847+
pass
848+
849+
@overload
850+
async def astream(*, diff: Literal[True] = ...) -> AsyncIterator[Patch]: ...
851+
@overload
852+
async def astream(*, diff: Literal[False]) -> AsyncIterator[FullState]: ...
853+
async def astream(*, diff: bool = True) -> AsyncIterator[Patch] | AsyncIterator[FullState]:
854+
if diff:
855+
yield Patch()
856+
else:
857+
yield FullState()
858+
859+
assert_type(astream(), AsyncIterator[Patch])
860+
assert_type(astream(diff=True), AsyncIterator[Patch])
861+
assert_type(astream(diff=False), AsyncIterator[FullState])
862+
"#,
863+
);
864+
837865
testcase!(
838866
test_param_consistency,
839867
r#"

0 commit comments

Comments
 (0)