Skip to content

Commit 368d2fc

Browse files
[refactor] Tighten SetComparisonFunction to Iterator[str] (#14587)
Addresses review feedback on PR #14523: * drop the redundant ``: Iterator[str]`` annotation on ``source`` — every branch already produces an ``Iterator[str]``. * return ``Iterator[str]`` from ``SetComparisonFunction`` instead of ``Iterable[str]`` so the call site no longer needs ``iter(...)``; the ``!=`` branch is promoted from a list-returning lambda to a named generator so the new contract holds. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ff77cd8 commit 368d2fc

2 files changed

Lines changed: 14 additions & 7 deletions

File tree

src/_pytest/assertion/_compare_set.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from collections.abc import Iterable
54
from collections.abc import Iterator
65
from collections.abc import Set as AbstractSet
76
from typing import TypeAlias
@@ -77,14 +76,24 @@ def _compare_lt_set(
7776

7877
SetComparisonFunction: TypeAlias = Callable[
7978
[AbstractSet[object], AbstractSet[object], _HighlightFunc, int],
80-
Iterable[str],
79+
Iterator[str],
8180
]
8281

82+
83+
def _both_sets_are_equal(
84+
left: AbstractSet[object],
85+
right: AbstractSet[object],
86+
highlighter: _HighlightFunc,
87+
verbose: int = 0,
88+
) -> Iterator[str]:
89+
yield "Both sets are equal"
90+
91+
8392
SET_COMPARISON_FUNCTIONS: dict[str, SetComparisonFunction] = {
8493
# == can't be done here without a prior refactor because there's an additional
8594
# explanation for iterable in _compare_eq_any
8695
# "==": _compare_eq_set,
87-
"!=": lambda *a, **kw: ["Both sets are equal"],
96+
"!=": _both_sets_are_equal,
8897
">=": _compare_gte_set,
8998
"<=": _compare_lte_set,
9099
">": _compare_gt_set,

src/_pytest/assertion/util.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def assertrepr_compare(
177177

178178
try:
179179
if op == "==":
180-
source: Iterator[str] = _compare_eq_any(
180+
source = _compare_eq_any(
181181
left,
182182
right,
183183
highlighter,
@@ -187,9 +187,7 @@ def assertrepr_compare(
187187
elif op == "not in" and istext(left) and istext(right):
188188
source = _notin_text(left, right, verbose)
189189
elif op in {"!=", ">=", "<=", ">", "<"} and isset(left) and isset(right):
190-
source = iter(
191-
SET_COMPARISON_FUNCTIONS[op](left, right, highlighter, verbose)
192-
)
190+
source = SET_COMPARISON_FUNCTIONS[op](left, right, highlighter, verbose)
193191
else:
194192
source = iter(())
195193

0 commit comments

Comments
 (0)