Skip to content

Commit fb13b3a

Browse files
authored
locker: refactor dependency walk logic
Resolves: #5141
1 parent eb27f81 commit fb13b3a

4 files changed

Lines changed: 586 additions & 193 deletions

File tree

src/poetry/packages/locker.py

Lines changed: 78 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333

3434
if TYPE_CHECKING:
35+
from poetry.core.semver.version_constraint import VersionConstraint
36+
from poetry.core.version.markers import BaseMarker
3537
from tomlkit.items import InlineTable
3638
from tomlkit.toml_document import TOMLDocument
3739

@@ -203,152 +205,130 @@ def locked_repository(self, with_dev_reqs: bool = False) -> Repository:
203205

204206
@staticmethod
205207
def __get_locked_package(
206-
_dependency: Dependency, packages_by_name: dict[str, list[Package]]
208+
dependency: Dependency,
209+
packages_by_name: dict[str, list[Package]],
210+
decided: dict[Package, Dependency] | None = None,
207211
) -> Package | None:
208212
"""
209213
Internal helper to identify corresponding locked package using dependency
210214
version constraints.
211215
"""
212-
for _package in packages_by_name.get(_dependency.name, []):
213-
if _dependency.constraint.allows(_package.version):
214-
return _package
215-
return None
216+
decided = decided or {}
217+
218+
# Get the packages that are consistent with this dependency.
219+
packages = [
220+
package
221+
for package in packages_by_name.get(dependency.name, [])
222+
if package.python_constraint.allows_all(dependency.python_constraint)
223+
and dependency.constraint.allows(package.version)
224+
]
225+
226+
# If we've previously made a choice that is compatible with the current
227+
# requirement, stick with it.
228+
for package in packages:
229+
old_decision = decided.get(package)
230+
if (
231+
old_decision is not None
232+
and not old_decision.marker.intersect(dependency.marker).is_empty()
233+
):
234+
return package
235+
236+
return next(iter(packages), None)
216237

217238
@classmethod
218-
def __walk_dependency_level(
239+
def __walk_dependencies(
219240
cls,
220241
dependencies: list[Dependency],
221-
level: int,
222-
pinned_versions: bool,
223242
packages_by_name: dict[str, list[Package]],
224-
project_level_dependencies: set[str],
225-
nested_dependencies: dict[tuple[str, str], Dependency],
226-
) -> dict[tuple[str, str], Dependency]:
227-
if not dependencies:
228-
return nested_dependencies
229-
230-
next_level_dependencies = []
243+
) -> dict[Package, Dependency]:
244+
nested_dependencies: dict[Package, Dependency] = {}
231245

232-
for requirement in dependencies:
233-
key = (requirement.name, requirement.pretty_constraint)
234-
locked_package = cls.__get_locked_package(requirement, packages_by_name)
235-
236-
if locked_package:
237-
# create dependency from locked package to retain dependency metadata
238-
# if this is not done, we can end-up with incorrect nested dependencies
239-
constraint = requirement.constraint
240-
pretty_constraint = requirement.pretty_constraint
241-
marker = requirement.marker
242-
requirement = locked_package.to_dependency()
243-
requirement.marker = requirement.marker.intersect(marker)
244-
245-
key = (requirement.name, pretty_constraint)
246+
visited: set[tuple[Dependency, BaseMarker]] = set()
247+
while dependencies:
248+
requirement = dependencies.pop(0)
249+
if (requirement, requirement.marker) in visited:
250+
continue
251+
visited.add((requirement, requirement.marker))
246252

247-
if not pinned_versions:
248-
requirement.set_constraint(constraint)
253+
locked_package = cls.__get_locked_package(
254+
requirement, packages_by_name, nested_dependencies
255+
)
249256

250-
for require in locked_package.requires:
251-
if require.marker.is_empty():
252-
require.marker = requirement.marker
253-
else:
254-
require.marker = require.marker.intersect(requirement.marker)
257+
if not locked_package:
258+
raise RuntimeError(f"Dependency walk failed at {requirement}")
255259

256-
require.marker = require.marker.intersect(locked_package.marker)
260+
# create dependency from locked package to retain dependency metadata
261+
# if this is not done, we can end-up with incorrect nested dependencies
262+
constraint = requirement.constraint
263+
marker = requirement.marker
264+
extras = requirement.extras
265+
requirement = locked_package.to_dependency()
266+
requirement.marker = requirement.marker.intersect(marker)
257267

258-
if key not in nested_dependencies:
259-
next_level_dependencies.append(require)
268+
requirement.set_constraint(constraint)
260269

261-
if requirement.name in project_level_dependencies and level == 0:
262-
# project level dependencies take precedence
263-
continue
270+
for require in locked_package.requires:
271+
if require.in_extras and extras.isdisjoint(require.in_extras):
272+
continue
264273

265-
if not locked_package:
266-
# we make a copy to avoid any side-effects
267-
requirement = deepcopy(requirement)
274+
require = deepcopy(require)
275+
require.marker = require.marker.intersect(
276+
requirement.marker.without_extras()
277+
)
278+
if not require.marker.is_empty():
279+
dependencies.append(require)
268280

281+
key = locked_package
269282
if key not in nested_dependencies:
270283
nested_dependencies[key] = requirement
271284
else:
272285
nested_dependencies[key].marker = nested_dependencies[key].marker.union(
273286
requirement.marker
274287
)
275288

276-
return cls.__walk_dependency_level(
277-
dependencies=next_level_dependencies,
278-
level=level + 1,
279-
pinned_versions=pinned_versions,
280-
packages_by_name=packages_by_name,
281-
project_level_dependencies=project_level_dependencies,
282-
nested_dependencies=nested_dependencies,
283-
)
289+
return nested_dependencies
284290

285291
@classmethod
286292
def get_project_dependencies(
287293
cls,
288294
project_requires: list[Dependency],
289295
locked_packages: list[Package],
290-
pinned_versions: bool = False,
291-
with_nested: bool = False,
292-
) -> Iterable[Dependency]:
296+
) -> Iterable[tuple[Package, Dependency]]:
293297
# group packages entries by name, this is required because requirement might use
294-
# different constraints
298+
# different constraints.
295299
packages_by_name: dict[str, list[Package]] = {}
296300
for pkg in locked_packages:
297301
if pkg.name not in packages_by_name:
298302
packages_by_name[pkg.name] = []
299303
packages_by_name[pkg.name].append(pkg)
300304

301-
project_level_dependencies = set()
302-
dependencies = []
303-
304-
for dependency in project_requires:
305-
dependency = deepcopy(dependency)
306-
locked_package = cls.__get_locked_package(dependency, packages_by_name)
307-
if locked_package:
308-
locked_dependency = locked_package.to_dependency()
309-
locked_dependency.marker = dependency.marker.intersect(
310-
locked_package.marker
311-
)
312-
313-
if not pinned_versions:
314-
locked_dependency.set_constraint(dependency.constraint)
315-
316-
dependency = locked_dependency
317-
318-
project_level_dependencies.add(dependency.name)
319-
dependencies.append(dependency)
320-
321-
if not with_nested:
322-
# return only with project level dependencies
323-
return dependencies
305+
# Put higher versions first so that we prefer them.
306+
for packages in packages_by_name.values():
307+
packages.sort(key=lambda package: package.version, reverse=True)
324308

325-
nested_dependencies = cls.__walk_dependency_level(
326-
dependencies=dependencies,
327-
level=0,
328-
pinned_versions=pinned_versions,
309+
nested_dependencies = cls.__walk_dependencies(
310+
dependencies=project_requires,
329311
packages_by_name=packages_by_name,
330-
project_level_dependencies=project_level_dependencies,
331-
nested_dependencies={},
332312
)
333313

334-
# Merge same dependencies using marker union
335-
for requirement in dependencies:
336-
key = (requirement.name, requirement.pretty_constraint)
337-
if key not in nested_dependencies:
338-
nested_dependencies[key] = requirement
339-
else:
340-
nested_dependencies[key].marker = nested_dependencies[key].marker.union(
341-
requirement.marker
342-
)
343-
344-
return sorted(nested_dependencies.values(), key=lambda x: x.name.lower())
314+
return nested_dependencies.items()
345315

346316
def get_project_dependency_packages(
347317
self,
348318
project_requires: list[Dependency],
319+
project_python_marker: VersionConstraint | None = None,
349320
dev: bool = False,
350321
extras: bool | Sequence[str] | None = None,
351322
) -> Iterator[DependencyPackage]:
323+
# Apply the project python marker to all requirements.
324+
if project_python_marker is not None:
325+
marked_requires: list[Dependency] = []
326+
for require in project_requires:
327+
require = deepcopy(require)
328+
require.marker = require.marker.intersect(project_python_marker)
329+
marked_requires.append(require)
330+
project_requires = marked_requires
331+
352332
repository = self.locked_repository(with_dev_reqs=dev)
353333

354334
# Build a set of all packages required by our selected extras
@@ -379,16 +359,10 @@ def get_project_dependency_packages(
379359

380360
selected.append(dependency)
381361

382-
for dependency in self.get_project_dependencies(
362+
for package, dependency in self.get_project_dependencies(
383363
project_requires=selected,
384364
locked_packages=repository.packages,
385-
with_nested=True,
386365
):
387-
try:
388-
package = repository.find_packages(dependency=dependency)[0]
389-
except IndexError:
390-
continue
391-
392366
for extra in dependency.extras:
393367
package.requires_extras.append(extra)
394368

src/poetry/utils/exporter.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import itertools
43
import urllib.parse
54

65
from typing import TYPE_CHECKING
@@ -70,21 +69,22 @@ def _export_requirements_txt(
7069
content = ""
7170
dependency_lines = set()
7271

73-
for package, groups in itertools.groupby(
74-
self._poetry.locker.get_project_dependency_packages(
75-
project_requires=self._poetry.package.all_requires,
76-
dev=dev,
77-
extras=extras,
78-
),
79-
lambda dependency_package: dependency_package.package,
72+
# Get project dependencies.
73+
root_package = (
74+
self._poetry.package.clone()
75+
if dev
76+
else self._poetry.package.with_dependency_groups(["default"], only=True)
77+
)
78+
79+
for dependency_package in self._poetry.locker.get_project_dependency_packages(
80+
project_requires=root_package.all_requires,
81+
project_python_marker=root_package.python_marker,
82+
dev=dev,
83+
extras=extras,
8084
):
8185
line = ""
82-
dependency_packages = list(groups)
83-
dependency = dependency_packages[0].dependency
84-
marker = dependency.marker
85-
for dep_package in dependency_packages[1:]:
86-
marker = marker.union(dep_package.dependency.marker)
87-
dependency.marker = marker
86+
dependency = dependency_package.dependency
87+
package = dependency_package.package
8888

8989
if package.develop:
9090
line += "-e "

tests/console/commands/test_export.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def _export_requirements(tester: CommandTester, poetry: Poetry) -> None:
8484
assert poetry.locker.lock.exists()
8585

8686
expected = """\
87-
foo==1.0.0
87+
foo==1.0.0 ;\
88+
python_version >= "2.7" and python_version < "2.8" or\
89+
python_version >= "3.4" and python_version < "4.0"
8890
"""
8991

9092
assert content == expected
@@ -113,7 +115,9 @@ def test_export_fails_on_invalid_format(tester: CommandTester, do_lock: None):
113115
def test_export_prints_to_stdout_by_default(tester: CommandTester, do_lock: None):
114116
tester.execute("--format requirements.txt")
115117
expected = """\
116-
foo==1.0.0
118+
foo==1.0.0 ;\
119+
python_version >= "2.7" and python_version < "2.8" or\
120+
python_version >= "3.4" and python_version < "4.0"
117121
"""
118122
assert tester.io.fetch_output() == expected
119123

@@ -123,16 +127,22 @@ def test_export_uses_requirements_txt_format_by_default(
123127
):
124128
tester.execute()
125129
expected = """\
126-
foo==1.0.0
130+
foo==1.0.0 ;\
131+
python_version >= "2.7" and python_version < "2.8" or\
132+
python_version >= "3.4" and python_version < "4.0"
127133
"""
128134
assert tester.io.fetch_output() == expected
129135

130136

131137
def test_export_includes_extras_by_flag(tester: CommandTester, do_lock: None):
132138
tester.execute("--format requirements.txt --extras feature_bar")
133139
expected = """\
134-
bar==1.1.0
135-
foo==1.0.0
140+
bar==1.1.0 ;\
141+
python_version >= "2.7" and python_version < "2.8" or\
142+
python_version >= "3.4" and python_version < "4.0"
143+
foo==1.0.0 ;\
144+
python_version >= "2.7" and python_version < "2.8" or\
145+
python_version >= "3.4" and python_version < "4.0"
136146
"""
137147
assert tester.io.fetch_output() == expected
138148

0 commit comments

Comments
 (0)