Skip to content

Commit 46bd39f

Browse files
fix
1 parent b1e5a2f commit 46bd39f

3 files changed

Lines changed: 149 additions & 39 deletions

File tree

pyrefly/lib/alt/function.rs

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
368368
} else {
369369
let metadata = self
370370
.merge_overload_metadata_with_implementation(&defs, def.metadata().clone());
371-
let sigs = self.extract_signatures(
372-
metadata.kind.function_name().as_ref(),
373-
defs,
374-
errors,
371+
let sigs = self.normalize_async_generator_overloads(
372+
self.extract_signatures(
373+
metadata.kind.function_name().as_ref(),
374+
defs,
375+
errors,
376+
),
377+
def.ty.callable_signatures()[0],
375378
);
376379
self.check_signature_consistency(&sigs, &def, errors);
377380
Type::Overload(Overload {
@@ -1658,6 +1661,49 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
16581661
metadata
16591662
}
16601663

1664+
fn normalize_async_generator_overload_signature(&self, sig: &Callable) -> Callable {
1665+
if let Some((_, _, ret)) = self.unwrap_coroutine(&sig.ret)
1666+
&& self.decompose_async_generator(&ret).is_some()
1667+
{
1668+
let mut sig = sig.clone();
1669+
sig.ret = ret;
1670+
sig
1671+
} else {
1672+
sig.clone()
1673+
}
1674+
}
1675+
1676+
fn normalize_async_generator_overloads(
1677+
&self,
1678+
overloads: Vec1<(TextRange, OverloadType)>,
1679+
impl_sig: &Callable,
1680+
) -> Vec1<(TextRange, OverloadType)> {
1681+
if self.decompose_async_generator(&impl_sig.ret).is_none() {
1682+
return overloads;
1683+
}
1684+
overloads.mapped(|(range, overload)| {
1685+
(
1686+
range,
1687+
match overload {
1688+
OverloadType::Function(func) => OverloadType::Function(Function {
1689+
signature: self
1690+
.normalize_async_generator_overload_signature(&func.signature),
1691+
metadata: func.metadata,
1692+
}),
1693+
OverloadType::Forall(forall) => OverloadType::Forall(Forall {
1694+
tparams: forall.tparams,
1695+
body: Function {
1696+
signature: self.normalize_async_generator_overload_signature(
1697+
&forall.body.signature,
1698+
),
1699+
metadata: forall.body.metadata,
1700+
},
1701+
}),
1702+
},
1703+
)
1704+
})
1705+
}
1706+
16611707
/// Substitute each type parameter with its upper bound, for overload consistency checking.
16621708
fn subst_function(&self, tparams: &TParams, func: Function) -> Function {
16631709
let mp = tparams

pyrefly/lib/alt/unwrap.rs

Lines changed: 71 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -374,45 +374,81 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
374374
}
375375

376376
pub fn decompose_generator(&self, ty: &Type) -> Option<(Type, Type, Type)> {
377-
let yield_ty = self.fresh_var();
378-
let send_ty = self.fresh_var();
379-
let return_ty = self.fresh_var();
380-
let generator_ty = self.heap.mk_class_type(self.stdlib.generator(
381-
yield_ty.to_type(self.heap),
382-
send_ty.to_type(self.heap),
383-
return_ty.to_type(self.heap),
384-
));
385-
if self.is_subset_eq(&generator_ty, ty) {
386-
let yield_ty: Type = self.resolve_var_opt(ty, yield_ty)?;
387-
let send_ty = self
388-
.resolve_var_opt(ty, send_ty)
389-
.unwrap_or_else(|| self.heap.mk_none());
390-
let return_ty = self
391-
.resolve_var_opt(ty, return_ty)
392-
.unwrap_or_else(|| self.heap.mk_none());
393-
Some((yield_ty, send_ty, return_ty))
394-
} else {
395-
None
377+
match ty {
378+
Type::Union(box Union { members, .. }) => {
379+
let results: Option<Vec<_>> = members
380+
.iter()
381+
.map(|member| self.decompose_generator(member))
382+
.collect();
383+
let (yield_tys, send_tys, return_tys) = results?.into_iter().fold(
384+
(Vec::new(), Vec::new(), Vec::new()),
385+
|mut acc, (yield_ty, send_ty, return_ty)| {
386+
acc.0.push(yield_ty);
387+
acc.1.push(send_ty);
388+
acc.2.push(return_ty);
389+
acc
390+
},
391+
);
392+
Some((
393+
self.unions(yield_tys),
394+
self.unions(send_tys),
395+
self.unions(return_tys),
396+
))
397+
}
398+
_ => {
399+
let yield_ty = self.fresh_var();
400+
let send_ty = self.fresh_var();
401+
let return_ty = self.fresh_var();
402+
let generator_ty = self.heap.mk_class_type(self.stdlib.generator(
403+
yield_ty.to_type(self.heap),
404+
send_ty.to_type(self.heap),
405+
return_ty.to_type(self.heap),
406+
));
407+
if self.is_subset_eq(&generator_ty, ty) {
408+
let yield_ty: Type = self.resolve_var_opt(ty, yield_ty)?;
409+
let send_ty = self
410+
.resolve_var_opt(ty, send_ty)
411+
.unwrap_or_else(|| self.heap.mk_none());
412+
let return_ty = self
413+
.resolve_var_opt(ty, return_ty)
414+
.unwrap_or_else(|| self.heap.mk_none());
415+
Some((yield_ty, send_ty, return_ty))
416+
} else {
417+
None
418+
}
419+
}
396420
}
397421
}
398422

399423
pub fn decompose_async_generator(&self, ty: &Type) -> Option<(Type, Type)> {
400-
let yield_ty = self.fresh_var();
401-
let send_ty = self.fresh_var();
402-
let async_generator_ty = self.heap.mk_class_type(
403-
self.stdlib
404-
.async_generator(yield_ty.to_type(self.heap), send_ty.to_type(self.heap)),
405-
);
406-
if self.is_subset_eq(&async_generator_ty, ty) {
407-
let yield_ty: Type = self.resolve_var_opt(ty, yield_ty)?;
408-
let send_ty = self
409-
.resolve_var_opt(ty, send_ty)
410-
.unwrap_or_else(|| self.heap.mk_none());
411-
Some((yield_ty, send_ty))
412-
} else if ty.is_any() {
413-
Some((self.heap.mk_any_explicit(), self.heap.mk_any_explicit()))
414-
} else {
415-
None
424+
match ty {
425+
Type::Union(box Union { members, .. }) => {
426+
let results: Option<Vec<_>> = members
427+
.iter()
428+
.map(|member| self.decompose_async_generator(member))
429+
.collect();
430+
let (yield_tys, send_tys) = results?.into_iter().unzip();
431+
Some((self.unions(yield_tys), self.unions(send_tys)))
432+
}
433+
_ => {
434+
let yield_ty = self.fresh_var();
435+
let send_ty = self.fresh_var();
436+
let async_generator_ty = self.heap.mk_class_type(
437+
self.stdlib
438+
.async_generator(yield_ty.to_type(self.heap), send_ty.to_type(self.heap)),
439+
);
440+
if self.is_subset_eq(&async_generator_ty, ty) {
441+
let yield_ty: Type = self.resolve_var_opt(ty, yield_ty)?;
442+
let send_ty = self
443+
.resolve_var_opt(ty, send_ty)
444+
.unwrap_or_else(|| self.heap.mk_none());
445+
Some((yield_ty, send_ty))
446+
} else if ty.is_any() {
447+
Some((self.heap.mk_any_explicit(), self.heap.mk_any_explicit()))
448+
} else {
449+
None
450+
}
451+
}
416452
}
417453
}
418454

pyrefly/lib/test/overload.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,34 @@ def f[T](x: T) -> T:
834834
"#,
835835
);
836836

837+
testcase!(
838+
test_overloaded_async_iterator_impl,
839+
r#"
840+
from collections.abc import AsyncIterator
841+
from typing import Literal, assert_type, overload
842+
843+
class Patch:
844+
pass
845+
846+
class FullState(Patch):
847+
pass
848+
849+
@overload
850+
async def astream(*, diff: Literal[True] = ...) -> AsyncIterator[Patch]: ...
851+
@overload
852+
async def astream(*, diff: Literal[False]) -> AsyncIterator[FullState]: ...
853+
async def astream(*, diff: bool = True) -> AsyncIterator[Patch] | AsyncIterator[FullState]:
854+
if diff:
855+
yield Patch()
856+
else:
857+
yield FullState()
858+
859+
assert_type(astream(), AsyncIterator[Patch])
860+
assert_type(astream(diff=True), AsyncIterator[Patch])
861+
assert_type(astream(diff=False), AsyncIterator[FullState])
862+
"#,
863+
);
864+
837865
testcase!(
838866
test_param_consistency,
839867
r#"

0 commit comments

Comments
 (0)