Skip to content

Commit 8653efd

Browse files
[ty] Expand class bases in per-base lint checks (#24699)
<!-- Thank you for contributing to Ruff/ty! To help us out with reviewing, please consider the following: - Does this pull request include a summary of the change? (See below.) - Does this pull request include a descriptive title? (Please prefix with `[ty]` for ty pull requests.) - Does this pull request include references to any relevant issues? - Does this PR follow our AI policy (https://github.com/astral-sh/.github/blob/main/AI_POLICY.md)? --> ## Summary <!-- What's the purpose of the change? What does it do, and why? --> The per-base checks in `check_static_class_definitions` iterate over `class.explicit_bases(db)` (the expanded bases list) but use the loop index to look up the AST node via `&class_node.bases()[i]`. When a starred base unpacks a fixed-length tuple, the expanded list is longer than the AST bases list and the indexing panics -- e.g. `class X(*(int, bool)): ...` panics in the `SUBCLASS_OF_FINAL_CLASS` branch. Closes astral-sh/ty#3293. A related panic was reported in astral-sh/ty#3290 and fixed in #24695, but #24695 only addressed the `try_mro` arms. This PR applies the same fix to the per-base lint loop, reusing the `expanded_class_base_entries` abstraction introduced there. ## Test Plan <!-- How was it tested? --> Two cases are added in `mdtest/mro.md` covering `subclass-of-final-class` via starred unpacking. --------- Co-authored-by: Charlie Marsh <charlie.r.marsh@gmail.com>
1 parent 2d2337f commit 8653efd

2 files changed

Lines changed: 79 additions & 44 deletions

File tree

crates/ty_python_semantic/resources/mdtest/mro.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,35 @@ reveal_mro(NameDuplicateBases) # revealed: (<class 'NameDuplicateBases'>, Unkno
614614
class StarredInvalidBases(*invalid_bases): ...
615615
```
616616

617+
Per-base lint checks also see the unpacked entries:
618+
619+
```py
620+
from typing import Generic, NamedTuple, Protocol
621+
622+
# error: [inconsistent-mro]
623+
# error: [subclass-of-final-class]
624+
class InheritsFromFinalViaStarred(*(int, bool)): ...
625+
626+
final_bases = (int, bool)
627+
628+
# error: [inconsistent-mro]
629+
# error: [subclass-of-final-class]
630+
class InheritsFromFinalViaNamedStarred(*final_bases): ...
631+
632+
# error: [instance-layout-conflict]
633+
# error: [invalid-named-tuple]
634+
# error: [invalid-named-tuple]
635+
class NamedTupleWithStarredBases(NamedTuple, *(int, str)): ...
636+
637+
# error: [inconsistent-mro]
638+
# error: [invalid-protocol]
639+
# error: [invalid-protocol]
640+
class ProtocolWithStarredBases(Protocol, *(int, str)): ...
641+
642+
# error: [invalid-base]
643+
class BareGenericInStarred(*(int, Generic)): ...
644+
```
645+
617646
## Inline tuple-literal starred bases point diagnostics at unpacked elements
618647

619648
<!-- snapshot-diagnostics -->

crates/ty_python_semantic/src/types/infer/builder/post_inference/static_class.rs

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -210,60 +210,68 @@ pub(crate) fn check_static_class_definitions<'db>(
210210
}
211211

212212
let mut disjoint_bases = IncompatibleBases::default();
213-
let mut protocol_base_with_generic_context = None;
213+
let mut protocol_base_with_generic_context: Option<(&ast::Expr, _)> = None;
214214
let mut direct_typed_dict_bases = vec![];
215215

216+
let class_definition = index.expect_single_definition(class_node);
217+
216218
// Iterate through the class's explicit bases to check for various possible errors:
217219
// - Check for inheritance from plain `Generic`,
218220
// - Check for inheritance from a `@final` classes
219221
// - If the class is a protocol class: check for inheritance from a non-protocol class
220222
// - If the class is a NamedTuple class: check for multiple inheritance that isn't `Generic[]`
221-
for (i, base_class) in class.explicit_bases(db).iter().enumerate() {
223+
let expanded_base_entries =
224+
expanded_class_base_entries(db, class.known(db), class_node, class_definition);
225+
for (i, entry) in expanded_base_entries.iter().enumerate() {
226+
let source_node = entry.source_node();
227+
let base_class = entry.ty();
228+
222229
if class_kind == Some(CodeGeneratorKind::NamedTuple)
223230
&& !matches!(
224231
base_class,
225232
Type::SpecialForm(SpecialFormType::NamedTuple)
226233
| Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(_))
227234
)
235+
&& let Some(node) = source_node
236+
&& let Some(builder) = context.report_lint(&INVALID_NAMED_TUPLE, node)
228237
{
229-
if let Some(builder) = context.report_lint(&INVALID_NAMED_TUPLE, &class_node.bases()[i])
230-
{
231-
builder.into_diagnostic(format_args!(
232-
"NamedTuple class `{}` cannot use multiple inheritance except with `Generic[]`",
233-
class.name(db),
234-
));
235-
}
238+
builder.into_diagnostic(format_args!(
239+
"NamedTuple class `{}` cannot use multiple inheritance except with `Generic[]`",
240+
class.name(db),
241+
));
236242
}
237243

238244
let base_class = match base_class {
239245
Type::SpecialForm(SpecialFormType::Generic) => {
240-
if let Some(builder) = context.report_lint(&INVALID_BASE, &class_node.bases()[i]) {
246+
if let Some(node) = source_node
247+
&& let Some(builder) = context.report_lint(&INVALID_BASE, node)
248+
{
241249
// Unsubscripted `Generic` can appear in the MRO of many classes,
242250
// but it is never valid as an explicit base class in user code.
243251
builder.into_diagnostic("Cannot inherit from plain `Generic`");
244252
}
245253
continue;
246254
}
247255
Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(new_context)) => {
248-
let Some((previous_index, previous_context)) = protocol_base_with_generic_context
256+
let Some((previous_node, previous_context)) = protocol_base_with_generic_context
249257
else {
250258
continue;
251259
};
252-
let prior_node = &class_node.bases()[previous_index];
253-
let Some(builder) = context.report_lint(&INVALID_GENERIC_CLASS, prior_node) else {
260+
let Some(builder) = context.report_lint(&INVALID_GENERIC_CLASS, previous_node)
261+
else {
254262
continue;
255263
};
256264
let mut diagnostic = builder.into_diagnostic(
257265
"Cannot both inherit from subscripted `Protocol` \
258266
and subscripted `Generic`",
259267
);
260-
if let ast::Expr::Subscript(prior_node) = prior_node
268+
if let ast::Expr::Subscript(previous_node) = previous_node
261269
&& new_context == previous_context
262270
{
263271
diagnostic.help("Remove the type parameters from the `Protocol` base");
264272
diagnostic.set_fix(Fix::unsafe_edit(Edit::range_deletion(TextRange::new(
265-
prior_node.value.end(),
266-
prior_node.end(),
273+
previous_node.value.end(),
274+
previous_node.end(),
267275
))));
268276
}
269277
continue;
@@ -273,16 +281,17 @@ pub(crate) fn check_static_class_definitions<'db>(
273281
// but it is semantically invalid.
274282
Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(generic_context)) => {
275283
if let Some(type_params) = class_node.type_params.as_deref() {
276-
let Some(builder) =
277-
context.report_lint(&INVALID_GENERIC_CLASS, &class_node.bases()[i])
278-
else {
284+
let Some(node) = source_node else {
285+
continue;
286+
};
287+
let Some(builder) = context.report_lint(&INVALID_GENERIC_CLASS, node) else {
279288
continue;
280289
};
281290
let mut diagnostic = builder.into_diagnostic(
282291
"Cannot both inherit from subscripted `Protocol` \
283292
and use PEP 695 type variables",
284293
);
285-
if let ast::Expr::Subscript(node) = &class_node.bases()[i] {
294+
if let ast::Expr::Subscript(node) = node {
286295
let source = source_text(db, context.file());
287296
let type_params_range = TextRange::new(
288297
type_params.start().saturating_add(TextSize::new(1)),
@@ -295,13 +304,15 @@ pub(crate) fn check_static_class_definitions<'db>(
295304
)));
296305
}
297306
}
298-
} else if protocol_base_with_generic_context.is_none() {
299-
protocol_base_with_generic_context = Some((i, generic_context));
307+
} else if let Some(node) = source_node
308+
&& protocol_base_with_generic_context.is_none()
309+
{
310+
protocol_base_with_generic_context = Some((node, generic_context));
300311
}
301312
continue;
302313
}
303-
Type::ClassLiteral(class) => ClassType::NonGeneric(*class),
304-
Type::GenericAlias(class) => ClassType::Generic(*class),
314+
Type::ClassLiteral(class) => ClassType::NonGeneric(class),
315+
Type::GenericAlias(class) => ClassType::Generic(class),
305316
_ => continue,
306317
};
307318

@@ -312,8 +323,8 @@ pub(crate) fn check_static_class_definitions<'db>(
312323
if is_protocol {
313324
if !base_class.is_protocol(db)
314325
&& !base_class.is_object(db)
315-
&& let Some(builder) =
316-
context.report_lint(&INVALID_PROTOCOL, &class_node.bases()[i])
326+
&& let Some(node) = source_node
327+
&& let Some(builder) = context.report_lint(&INVALID_PROTOCOL, node)
317328
{
318329
builder.into_diagnostic(format_args!(
319330
"Protocol class `{}` cannot inherit from non-protocol class `{}`",
@@ -323,8 +334,8 @@ pub(crate) fn check_static_class_definitions<'db>(
323334
}
324335
} else if class_kind == Some(CodeGeneratorKind::TypedDict) {
325336
if !base_class.class_literal(db).is_typed_dict(db)
326-
&& let Some(builder) =
327-
context.report_lint(&INVALID_TYPED_DICT_HEADER, &class_node.bases()[i])
337+
&& let Some(node) = source_node
338+
&& let Some(builder) = context.report_lint(&INVALID_TYPED_DICT_HEADER, node)
328339
{
329340
let mut diagnostic = builder.into_diagnostic(format_args!(
330341
"TypedDict class `{}` can only inherit from TypedDict classes",
@@ -344,16 +355,15 @@ pub(crate) fn check_static_class_definitions<'db>(
344355
}
345356
}
346357

347-
if base_class.is_final(db) {
348-
if let Some(builder) =
349-
context.report_lint(&SUBCLASS_OF_FINAL_CLASS, &class_node.bases()[i])
350-
{
351-
builder.into_diagnostic(format_args!(
352-
"Class `{}` cannot inherit from final class `{}`",
353-
class.name(db),
354-
base_class.name(db),
355-
));
356-
}
358+
if base_class.is_final(db)
359+
&& let Some(node) = source_node
360+
&& let Some(builder) = context.report_lint(&SUBCLASS_OF_FINAL_CLASS, node)
361+
{
362+
builder.into_diagnostic(format_args!(
363+
"Class `{}` cannot inherit from final class `{}`",
364+
class.name(db),
365+
base_class.name(db),
366+
));
357367
}
358368

359369
if let Some((base_class_literal, _)) = base_class.static_class_literal(db)
@@ -362,20 +372,20 @@ pub(crate) fn check_static_class_definitions<'db>(
362372
class.is_frozen_dataclass(db),
363373
)
364374
&& base_is_frozen != class_is_frozen
375+
&& let Some(node) = source_node
365376
{
366377
report_bad_frozen_dataclass_inheritance(
367378
context,
368379
class,
369380
class_node,
370381
base_class_literal,
371-
&class_node.bases()[i],
382+
node,
372383
base_is_frozen,
373384
);
374385
}
375386
}
376387

377388
// Check for starred variable-length tuples that cannot be unpacked
378-
let class_definition = index.expect_single_definition(class_node);
379389
for base in class_node.bases() {
380390
if let ast::Expr::Starred(starred) = base
381391
&& let starred_ty = definition_expression_type(db, class_definition, &starred.value)
@@ -390,15 +400,11 @@ pub(crate) fn check_static_class_definitions<'db>(
390400
match class.try_mro(db, None) {
391401
Err(mro_error) => match mro_error.reason() {
392402
StaticMroErrorKind::DuplicateBases(duplicates) => {
393-
let expanded_base_entries =
394-
expanded_class_base_entries(db, class.known(db), class_node, class_definition);
395403
for duplicate in duplicates {
396404
report_duplicate_bases(context, class, duplicate, &expanded_base_entries);
397405
}
398406
}
399407
StaticMroErrorKind::InvalidBases(bases) => {
400-
let expanded_base_entries =
401-
expanded_class_base_entries(db, class.known(db), class_node, class_definition);
402408
for (index, base_ty) in bases {
403409
if let Some(base_node) = expanded_base_entries[*index].source_node() {
404410
report_invalid_or_unsupported_base(context, base_node, *base_ty, class);

0 commit comments

Comments
 (0)