Skip to content

Commit 0c86469

Browse files
rchen152facebook-github-bot
authored andcommitted
Take extra_items into account when synthesizing get()
Summary: For #946. A few asserted types changed because we don't simplify unions containing `object`. Reviewed By: samwgoldman Differential Revision: D80870295 fbshipit-source-id: c52e3163f6b16410b8fbc3e2f7d49c3abde6abc7
1 parent ed45ff3 commit 0c86469

2 files changed

Lines changed: 19 additions & 10 deletions

File tree

pyrefly/lib/alt/class/typed_dict.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
436436
));
437437
}
438438
}
439-
let signatures = Vec1::from_vec_push(
439+
let value_ty = self.get_typed_dict_value_type_from_fields(cls, fields);
440+
let mut signatures = Vec1::from_vec_push(
440441
literal_signatures,
441442
OverloadType::Function(Function {
442443
signature: Callable::list(
@@ -447,17 +448,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
447448
self.stdlib.str().clone().to_type(),
448449
Required::Required,
449450
),
450-
Param::PosOnly(
451-
Some(DEFAULT_PARAM.clone()),
452-
object_ty.clone(),
453-
Required::Optional(None),
454-
),
455451
]),
456-
object_ty.clone(),
452+
Type::optional(value_ty.clone()),
457453
),
458454
metadata: metadata.clone(),
459455
}),
460456
);
457+
signatures.push(self.get_overload_with_default(&metadata, &self_param, None, value_ty));
461458
ClassSynthesizedField::new(Type::Overload(Overload {
462459
signatures,
463460
metadata: Box::new(metadata),

pyrefly/lib/test/typed_dict.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,8 @@ class C(TypedDict):
784784
y: str
785785
def f(c: C, k1: str, k2: int):
786786
assert_type(c.get("x"), int)
787-
assert_type(c.get(k1), object)
788-
assert_type(c.get(k1, 0), object)
787+
assert_type(c.get(k1), object | None)
788+
assert_type(c.get(k1, 0), int | object)
789789
c.get(k2) # E: No matching overload
790790
"#,
791791
);
@@ -1511,7 +1511,7 @@ def f(x: X):
15111511
);
15121512

15131513
testcase!(
1514-
test_get_extra_item,
1514+
test_getitem_extra_items,
15151515
r#"
15161516
from typing import assert_type, TypedDict
15171517
class A(TypedDict, extra_items=bool):
@@ -1681,3 +1681,15 @@ def f(a: A, k: str):
16811681
assert_type(a.setdefault(k, 0), int | str)
16821682
"#,
16831683
);
1684+
1685+
testcase!(
1686+
test_get_extra_items,
1687+
r#"
1688+
from typing import assert_type, TypedDict
1689+
class A(TypedDict, extra_items=int):
1690+
x: str
1691+
def f(a: A, k: str):
1692+
assert_type(a.get(k), str | int | None)
1693+
assert_type(a.get(k, b'hello world'), str | int | bytes)
1694+
"#,
1695+
);

0 commit comments

Comments
 (0)