Skip to content

Commit 981aa2c

Browse files
[refactor to use match] AssertionRewriter.visit_Compare()
1 parent 3074cc3 commit 981aa2c

1 file changed

Lines changed: 15 additions & 13 deletions

File tree

src/_pytest/assertion/rewrite.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,12 +1116,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
11161116
def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11171117
self.push_format_context()
11181118
# We first check if we have overwritten a variable in the previous assert
1119-
if isinstance(
1120-
comp.left, ast.Name
1121-
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
1122-
comp.left = self.variables_overwrite[self.scope][comp.left.id] # type:ignore[assignment]
1123-
if isinstance(comp.left, ast.NamedExpr):
1124-
self.variables_overwrite[self.scope][comp.left.target.id] = comp.left # type:ignore[assignment]
1119+
match comp.left:
1120+
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
1121+
self.scope, {}
1122+
):
1123+
comp.left = self.variables_overwrite[self.scope][name_id]
1124+
case ast.NamedExpr(target=ast.Name(id=target_id)):
1125+
self.variables_overwrite[self.scope][target_id] = comp.left
11251126
left_res, left_expl = self.visit(comp.left)
11261127
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
11271128
left_expl = f"({left_expl})"
@@ -1133,13 +1134,14 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11331134
syms: list[ast.expr] = []
11341135
results = [left_res]
11351136
for i, op, next_operand in it:
1136-
if (
1137-
isinstance(next_operand, ast.NamedExpr)
1138-
and isinstance(left_res, ast.Name)
1139-
and next_operand.target.id == left_res.id
1140-
):
1141-
next_operand.target.id = self.variable()
1142-
self.variables_overwrite[self.scope][left_res.id] = next_operand # type:ignore[assignment]
1137+
match (next_operand, left_res):
1138+
case (
1139+
ast.NamedExpr(target=ast.Name(id=target_id)),
1140+
ast.Name(id=name_id),
1141+
) if target_id == name_id:
1142+
next_operand.target.id = self.variable()
1143+
self.variables_overwrite[self.scope][name_id] = next_operand
1144+
11431145
next_res, next_expl = self.visit(next_operand)
11441146
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
11451147
next_expl = f"({next_expl})"

0 commit comments

Comments
 (0)