Skip to content

Commit 09d060f

Browse files
[ty] Memoize binary operator return types (#24700)
## Summary Especially for cases like astral-sh/ty#3039, we were running binary operator inference over and over, and throwing away everything except the return type. This PR adds a cached query for _just_ the return type, which is more lightweight than storing the entire `Bindings` but seemingly still very effective. For: ```python import pandas as pd df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) df["d"] = df["a"] + df["b"] + df["c"] + 1 + (df["a"] ** 2 + df["b"] ** 2 + df["c"] ** 2) ``` Codex reports a 3.32x speedup. Repeating that expression 20 times, Codex reports a 50.79x speedup (from 52.471s down to 1.033s). Closes astral-sh/ty#3039.
1 parent a494065 commit 09d060f

2 files changed

Lines changed: 48 additions & 41 deletions

File tree

crates/ty_python_semantic/src/types/call.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,29 @@ pub(super) use arguments::{Argument, CallArguments};
1111
pub(super) use bind::{Binding, Bindings, CallableBinding, MatchedArgument};
1212

1313
impl<'db> Type<'db> {
14+
/// Memoize the pure return-type part of binary dunder resolution so repeated identical
15+
/// expressions don't re-run overload selection at every call site.
16+
pub(crate) fn try_call_bin_op_return_type(
17+
db: &'db dyn Db,
18+
left_ty: Type<'db>,
19+
op: ast::Operator,
20+
right_ty: Type<'db>,
21+
) -> Option<Type<'db>> {
22+
#[salsa::tracked]
23+
fn try_call_bin_op_return_type_impl<'db>(
24+
db: &'db dyn Db,
25+
left_ty: Type<'db>,
26+
op: ast::Operator,
27+
right_ty: Type<'db>,
28+
) -> Option<Type<'db>> {
29+
Type::try_call_bin_op(db, left_ty, op, right_ty)
30+
.ok()
31+
.map(|bindings| bindings.return_type(db))
32+
}
33+
34+
try_call_bin_op_return_type_impl(db, left_ty, op, right_ty)
35+
}
36+
1437
pub(crate) fn try_call_bin_op(
1538
db: &'db dyn Db,
1639
left_ty: Type<'db>,

crates/ty_python_semantic/src/types/infer/builder/binary_expressions.rs

Lines changed: 25 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
450450
)
451451
}
452452
// For bounded TypeVars or unconstrained TypeVars, fall through to the default handling.
453-
_ => Type::try_call_bin_op(db, left_ty, op, right_ty)
454-
.map(|outcome| outcome.return_type(db))
455-
.ok(),
453+
_ => Type::try_call_bin_op_return_type(db, left_ty, op, right_ty),
456454
}
457455
}
458456

@@ -483,9 +481,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
483481
)
484482
}
485483
// For bounded TypeVars or unconstrained TypeVars, fall through to the default handling.
486-
_ => Type::try_call_bin_op(db, left_ty, op, right_ty)
487-
.map(|outcome| outcome.return_type(db))
488-
.ok(),
484+
_ => Type::try_call_bin_op_return_type(db, left_ty, op, right_ty),
489485
}
490486
}
491487

@@ -511,9 +507,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
511507
)
512508
}
513509
// For bounded TypeVars or unconstrained TypeVars, fall through to the default handling.
514-
_ => Type::try_call_bin_op(db, left_ty, op, right_ty)
515-
.map(|outcome| outcome.return_type(db))
516-
.ok(),
510+
_ => Type::try_call_bin_op_return_type(db, left_ty, op, right_ty),
517511
}
518512
}
519513

@@ -524,34 +518,28 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
524518
// positional arguments get. In those cases we need to explicitly delegate to the base
525519
// type, so that it hits the `Type::Union` branches above.
526520
(Type::NewTypeInstance(newtype), rhs, _) => {
527-
Type::try_call_bin_op(db, left_ty, op, right_ty)
528-
.map(|outcome| outcome.return_type(db))
529-
.ok()
530-
.or_else(|| {
531-
self.infer_binary_expression_type_impl(
532-
node,
533-
emitted_division_by_zero_diagnostic,
534-
newtype.concrete_base_type(db),
535-
rhs,
536-
op,
537-
visitor,
538-
)
539-
})
521+
Type::try_call_bin_op_return_type(db, left_ty, op, right_ty).or_else(|| {
522+
self.infer_binary_expression_type_impl(
523+
node,
524+
emitted_division_by_zero_diagnostic,
525+
newtype.concrete_base_type(db),
526+
rhs,
527+
op,
528+
visitor,
529+
)
530+
})
540531
}
541532
(lhs, Type::NewTypeInstance(newtype), _) => {
542-
Type::try_call_bin_op(db, left_ty, op, right_ty)
543-
.map(|outcome| outcome.return_type(db))
544-
.ok()
545-
.or_else(|| {
546-
self.infer_binary_expression_type_impl(
547-
node,
548-
emitted_division_by_zero_diagnostic,
549-
lhs,
550-
newtype.concrete_base_type(db),
551-
op,
552-
visitor,
553-
)
554-
})
533+
Type::try_call_bin_op_return_type(db, left_ty, op, right_ty).or_else(|| {
534+
self.infer_binary_expression_type_impl(
535+
node,
536+
emitted_division_by_zero_diagnostic,
537+
lhs,
538+
newtype.concrete_base_type(db),
539+
op,
540+
visitor,
541+
)
542+
})
555543
}
556544

557545
(
@@ -854,9 +842,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
854842
Some(result)
855843
}
856844

857-
_ => Type::try_call_bin_op(db, left_ty, op, right_ty)
858-
.map(|outcome| outcome.return_type(db))
859-
.ok(),
845+
_ => Type::try_call_bin_op_return_type(db, left_ty, op, right_ty),
860846
}
861847
}
862848

@@ -1039,9 +1025,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
10391025
| Type::TypeGuard(_)
10401026
| Type::TypedDict(_),
10411027
op,
1042-
) => Type::try_call_bin_op(db, left_ty, op, right_ty)
1043-
.map(|outcome| outcome.return_type(db))
1044-
.ok(),
1028+
) => Type::try_call_bin_op_return_type(db, left_ty, op, right_ty),
10451029
}
10461030
}
10471031

0 commit comments

Comments
 (0)