Skip to content

Commit 69307a3

Browse files
authored
fix(specifiers): === uses original string, not normalized (#1124)
1 parent e938536 commit 69307a3

2 files changed

Lines changed: 147 additions & 4 deletions

File tree

src/packaging/specifiers.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -727,13 +727,21 @@ def filter(
727727
# Filter versions
728728
for version in iterable:
729729
parsed_version = _coerce_version(version if key is None else key(version))
730+
match = False
730731
if parsed_version is None:
731732
# === operator can match arbitrary (non-version) strings
732733
if self.operator == "===" and self._compare_arbitrary(
733734
version, self.version
734735
):
735736
yield version
736-
elif operator_callable(parsed_version, self.version):
737+
elif self.operator == "===":
738+
match = self._compare_arbitrary(
739+
version if key is None else key(version), self.version
740+
)
741+
else:
742+
match = operator_callable(parsed_version, self.version)
743+
744+
if match and parsed_version is not None:
737745
# If it's not a prerelease or prereleases are allowed, yield it directly
738746
if not parsed_version.is_prerelease or include_prereleases:
739747
found_non_prereleases = True
@@ -898,7 +906,13 @@ class SpecifierSet(BaseSpecifier):
898906
specifiers (``>=3.0,!=3.1``), or no specifier at all.
899907
"""
900908

901-
__slots__ = ("_canonicalized", "_prereleases", "_resolved_ops", "_specs")
909+
__slots__ = (
910+
"_canonicalized",
911+
"_has_arbitrary",
912+
"_prereleases",
913+
"_resolved_ops",
914+
"_specs",
915+
)
902916

903917
def __init__(
904918
self,
@@ -928,8 +942,13 @@ def __init__(
928942
split_specifiers = [s.strip() for s in specifiers.split(",") if s.strip()]
929943

930944
self._specs: tuple[Specifier, ...] = tuple(map(Specifier, split_specifiers))
945+
# Fast substring check; avoids iterating parsed specs.
946+
self._has_arbitrary = "===" in specifiers
931947
else:
932948
self._specs = tuple(specifiers)
949+
# Substring check works for both Specifier objects and plain
950+
# strings (setuptools passes lists of strings).
951+
self._has_arbitrary = any("===" in str(s) for s in self._specs)
933952

934953
self._canonicalized = len(self._specs) <= 1
935954
self._resolved_ops: list[tuple[CallableOperator, str, str]] | None = None
@@ -1025,6 +1044,7 @@ def __and__(self, other: SpecifierSet | str) -> SpecifierSet:
10251044
specifier = SpecifierSet()
10261045
specifier._specs = self._specs + other._specs
10271046
specifier._canonicalized = len(specifier._specs) <= 1
1047+
specifier._has_arbitrary = self._has_arbitrary or other._has_arbitrary
10281048
specifier._resolved_ops = None
10291049

10301050
# Combine prerelease settings: use common or non-None value
@@ -1137,7 +1157,12 @@ def contains(
11371157
if version is not None and installed and version.is_prerelease:
11381158
prereleases = True
11391159

1140-
check_item = item if version is None else version
1160+
# When item is a string and === is involved, keep it as-is
1161+
# so the comparison isn't done against the normalized form.
1162+
if version is None or (self._has_arbitrary and not isinstance(item, Version)):
1163+
check_item = item
1164+
else:
1165+
check_item = version
11411166
return bool(list(self.filter([check_item], prereleases=prereleases)))
11421167

11431168
@typing.overload
@@ -1289,6 +1314,11 @@ def _filter_versions(
12891314
yield item
12901315
elif exclude_prereleases and parsed.is_prerelease:
12911316
pass
1292-
elif all(op_fn(parsed, ver) for op_fn, ver, _ in ops):
1317+
elif all(
1318+
str(item if key is None else key(item)).lower() == ver.lower()
1319+
if op == "==="
1320+
else op_fn(parsed, ver)
1321+
for op_fn, ver, op in ops
1322+
):
12931323
# Short-circuits on the first failing operator.
12941324
yield item

tests/test_specifiers.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,61 @@ def test_arbitrary_equality(
691691
spec = Specifier(spec_str)
692692
assert spec.contains(version) == expected
693693

694+
@pytest.mark.parametrize(
695+
("spec_str", "version", "expected"),
696+
[
697+
# Zero padding: unnormalized spec vs string/Version
698+
# Strings preserve their original form, so "1.01" != "1.1"
699+
("===1.1", "1.01", False),
700+
("===1.01", "1.1", False),
701+
("===1.01", "1.01", True),
702+
("===1.1", "1.1", True),
703+
# Version objects are normalized, so Version("1.01") -> "1.1"
704+
("===1.1", Version("1.01"), True),
705+
("===1.1", Version("1.1"), True),
706+
("===1.01", Version("1.01"), False),
707+
("===1.01", Version("1.1"), False),
708+
# Prerelease separator normalization (issue #766)
709+
# "1.a1" is valid PEP 440, normalizes to "1a1"
710+
("===1.a1", "1.a1", True),
711+
("===1a1", "1.a1", False),
712+
("===1.a1", "1a1", False),
713+
("===1a1", "1a1", True),
714+
("===1.a1", Version("1.a1"), False),
715+
("===1a1", Version("1.a1"), True),
716+
# Epoch normalization: "0!1.0" normalizes to "1.0"
717+
("===0!1.0", "0!1.0", True),
718+
("===0!1.0", "1.0", False),
719+
("===1.0", "0!1.0", False),
720+
("===0!1.0", Version("1.0"), False),
721+
("===1.0", Version("0!1.0"), True),
722+
# Leading zeros in release segments
723+
("===01.0", "01.0", True),
724+
("===01.0", "1.0", False),
725+
("===1.0", "01.0", False),
726+
("===01.0", Version("1.0"), False),
727+
("===1.0", Version("01.0"), True),
728+
# Post-release normalization: "post" vs "-" separator
729+
("===1.0.post1", "1.0.post1", True),
730+
("===1.0-1", "1.0-1", True),
731+
("===1.0-1", "1.0.post1", False),
732+
("===1.0.post1", "1.0-1", False),
733+
("===1.0-1", Version("1.0.post1"), False),
734+
("===1.0.post1", Version("1.0-1"), True),
735+
# Dev normalization
736+
("===1.0.dev01", "1.0.dev01", True),
737+
("===1.0.dev01", "1.0.dev1", False),
738+
("===1.0.dev1", "1.0.dev01", False),
739+
("===1.0.dev01", Version("1.0.dev1"), False),
740+
("===1.0.dev1", Version("1.0.dev01"), True),
741+
],
742+
)
743+
def test_arbitrary_equality_normalization(
744+
self, spec_str: str, version: str | Version, expected: bool
745+
) -> None:
746+
spec = Specifier(spec_str, prereleases=True)
747+
assert spec.contains(version) == expected
748+
694749
@pytest.mark.parametrize(
695750
("specifier", "expected"),
696751
[
@@ -830,6 +885,33 @@ def test_specifier_filter(
830885

831886
assert result == expected
832887

888+
@pytest.mark.parametrize(
889+
("specifier", "input", "expected"),
890+
[
891+
# Strings preserve original form
892+
("===1.01", ["1.01", "1.1", "1.0.1"], ["1.01"]),
893+
("===1.1", ["1.01", "1.1"], ["1.1"]),
894+
# Version objects use normalized form
895+
(
896+
"===1.1",
897+
[Version("1.01"), Version("1.1")],
898+
[Version("1.01"), Version("1.1")],
899+
),
900+
("===1.01", [Version("1.01"), Version("1.1")], []),
901+
# Mixed strings and Version objects
902+
("===1.1", ["1.01", "1.1", Version("1.01")], ["1.1", Version("1.01")]),
903+
("===1.01", ["1.01", "1.1", Version("1.01")], ["1.01"]),
904+
# Prerelease separator
905+
("===1.a1", ["1.a1", "1a1"], ["1.a1"]),
906+
("===1a1", ["1.a1", "1a1", Version("1.a1")], ["1a1", Version("1.a1")]),
907+
],
908+
)
909+
def test_specifier_filter_arbitrary_equality_normalization(
910+
self, specifier: str, input: list[str | Version], expected: list[str | Version]
911+
) -> None:
912+
spec = Specifier(specifier, prereleases=True)
913+
assert list(spec.filter(input)) == expected
914+
833915
@pytest.mark.parametrize(
834916
("prereleases", "expected_indexes"),
835917
[
@@ -1917,6 +1999,37 @@ def test_contains_arbitrary_equality_contains(
19171999
spec = SpecifierSet(specifier)
19182000
assert spec.contains(version) == expected
19192001

2002+
@pytest.mark.parametrize(
2003+
("spec_str", "version", "expected"),
2004+
[
2005+
# Zero padding: string preserves original, Version normalizes
2006+
("===1.1", "1.01", False),
2007+
("===1.01", "1.1", False),
2008+
("===1.01", "1.01", True),
2009+
("===1.1", "1.1", True),
2010+
("===1.1", Version("1.01"), True),
2011+
("===1.01", Version("1.01"), False),
2012+
# Prerelease separator normalization
2013+
("===1.a1", "1.a1", True),
2014+
("===1a1", "1.a1", False),
2015+
("===1.a1", "1a1", False),
2016+
("===1a1", "1a1", True),
2017+
("===1.a1", Version("1.a1"), False),
2018+
("===1a1", Version("1.a1"), True),
2019+
# Combined with other operators
2020+
(">=1.0,===1.01", "1.01", True),
2021+
(">=1.0,===1.1", "1.1", True),
2022+
(">=1.0,===1.1", "1.01", False),
2023+
(">=1.0,===1.1", Version("1.01"), True),
2024+
(">=1.0,===1.01", Version("1.01"), False),
2025+
],
2026+
)
2027+
def test_contains_arbitrary_equality_normalization(
2028+
self, spec_str: str, version: str | Version, expected: bool
2029+
) -> None:
2030+
spec = SpecifierSet(spec_str, prereleases=True)
2031+
assert spec.contains(version) == expected
2032+
19202033
@pytest.mark.parametrize(
19212034
("specifier", "expected"),
19222035
[

0 commit comments

Comments
 (0)