Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
4bb1d1d
Use tomlkit to dump updated dependencies
CasperWA Nov 11, 2023
c3b8df7
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Nov 17, 2023
15c7610
Merge remote-tracking branch 'origin/main' into cwa/close-203-seriali…
CasperWA Nov 22, 2023
7c01a3a
Merge remote-tracking branch 'origin/main' into cwa/close-203-seriali…
CasperWA Dec 7, 2023
af29519
Update test to take the new writing method into account
CasperWA Dec 7, 2023
1331d91
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Dec 7, 2023
7f6a971
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Dec 7, 2023
038d97b
Merge remote-tracking branch 'origin/main' into cwa/close-203-seriali…
CasperWA Dec 14, 2023
287637c
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Feb 26, 2024
c71f38f
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA May 27, 2024
c653b52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
6671e47
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA May 29, 2024
519b6d8
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Jul 29, 2024
880fc43
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Nov 20, 2024
212187e
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Apr 9, 2025
90278c9
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Apr 29, 2025
f2077db
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Jul 29, 2025
98b7168
Merge branch 'main' into cwa/close-203-serialize-via-tomlkit
CasperWA Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 47 additions & 23 deletions ci_cd/tasks/update_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
parse_ignore_entries,
parse_ignore_rules,
regenerate_requirement,
update_file,
update_specifier_set,
warning_msg,
)
Expand All @@ -44,7 +43,6 @@
# Get logger
LOGGER = logging.getLogger(__name__)


VALID_PACKAGE_NAME_PATTERN = r"^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$"
"""
Pattern to validate package names.
Expand All @@ -57,8 +55,44 @@
"""


def _update_pyproject(
original_dependency: str, updated_dependency: str, pyproject: tomlkit.TOMLDocument
) -> None:
"""Update dependency in pyproject data structure.

First, check and update the dependency if it is in the "dependencies" group
Then, check and update if it is in any of the "optional-dependencies" groups.

Essentially, we allow for the original dependency to be in multiple groups.
"""
LOGGER.debug(
"Updating pyproject data structure for %r to %r",
original_dependency,
updated_dependency,
)

if original_dependency in pyproject["project"].get("dependencies", []):
index = pyproject["project"]["dependencies"].index(original_dependency)
pyproject["project"]["dependencies"][index] = updated_dependency.replace(
'"', "'"
)

for extra_name, extra_dependencies in (
pyproject["project"].get("optional-dependencies", {}).items()
):
if original_dependency in extra_dependencies:
index = pyproject["project"]["optional-dependencies"][extra_name].index(
original_dependency
)
pyproject["project"]["optional-dependencies"][extra_name][index] = (
updated_dependency.replace('"', "'")
)


def _format_and_update_dependency(
requirement: Requirement, raw_dependency_line: str, pyproject_path: Path
requirement: Requirement,
raw_dependency_line: str,
pyproject: tomlkit.TOMLDocument = None,
) -> None:
"""Regenerate dependency without changing anything but the formatting.

Expand All @@ -72,12 +106,8 @@ def _format_and_update_dependency(
)
LOGGER.debug("Regenerated dependency: %r", updated_dependency)
if updated_dependency != raw_dependency_line:
# Update pyproject.toml since the dependency formatting has changed
LOGGER.debug("Updating pyproject.toml for %r", requirement.name)
update_file(
pyproject_path,
(re.escape(raw_dependency_line), updated_dependency.replace('"', "'")),
)
# Update pyproject data structure since the dependency formatting has changed
_update_pyproject(raw_dependency_line, updated_dependency, pyproject)


@task(
Expand Down Expand Up @@ -192,7 +222,8 @@ def update_deps(
)

# Build the list of dependencies listed in pyproject.toml
dependencies: list[str] = pyproject.get("project", {}).get("dependencies", [])
dependencies: list[str] = []
dependencies.extend(pyproject.get("project", {}).get("dependencies", []))
for optional_deps in (
pyproject.get("project", {}).get("optional-dependencies", {}).values()
):
Expand Down Expand Up @@ -259,9 +290,7 @@ def update_deps(
LOGGER.info(msg)
print(info_msg(msg), flush=True)

_format_and_update_dependency(
parsed_requirement, dependency, pyproject_path
)
_format_and_update_dependency(parsed_requirement, dependency, pyproject)
already_handled_packages.add(parsed_requirement)
continue

Expand All @@ -278,9 +307,7 @@ def update_deps(
LOGGER.warning(msg)
print(warning_msg(msg), flush=True)

_format_and_update_dependency(
parsed_requirement, dependency, pyproject_path
)
_format_and_update_dependency(parsed_requirement, dependency, pyproject)
already_handled_packages.add(parsed_requirement)
continue

Expand Down Expand Up @@ -469,14 +496,8 @@ def update_deps(
)
LOGGER.debug("Updated dependency: %r", updated_dependency)

pattern_sub_line = re.escape(dependency)
replacement_sub_line = updated_dependency.replace('"', "'")
_update_pyproject(dependency, updated_dependency, pyproject)

LOGGER.debug("pattern_sub_line: %s", pattern_sub_line)
LOGGER.debug("replacement_sub_line: %s", replacement_sub_line)

# Update pyproject.toml
update_file(pyproject_path, (pattern_sub_line, replacement_sub_line))
already_handled_packages.add(parsed_requirement)
updated_packages[parsed_requirement.name] = ",".join(
str(_)
Expand All @@ -492,6 +513,9 @@ def update_deps(
f"{Emoji.CROSS_MARK.value} Errors occurred! See printed statements above."
)

# Update pyproject.toml
pyproject_path.write_text(tomlkit.dumps(pyproject), encoding="utf-8")

if updated_packages:
print(
f"{Emoji.PARTY_POPPER.value} Successfully updated the following "
Expand Down
40 changes: 25 additions & 15 deletions tests/tasks/test_update_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,22 @@ def test_skip_unnormalized_python_package_names(
r"statements above\.$"
)

# Due to an atomistic approach, if an error occurs, the pyproject.toml file will
# not be updated.
successful_expected_pyproject_file_data = """[project]
name = "{{ cookiecutter.project_slug }}"
requires-python = ">=3.8"

dependencies = []

[project.optional-dependencies]
dev = ["pytest~=7.4"]
all = ["{{ cookiecutter.project_slug }}[dev]"]
"""
erroneous_expected_pyproject_file_data = pyproject_file_data

if skip_unnormalized_python_package_names:
# This should end in success
update_deps(
context,
root_repo_path=str(tmp_path),
Expand All @@ -1134,7 +1149,13 @@ def test_skip_unnormalized_python_package_names(
terminal_error_msg.search(stdouterr.err) is None
), f"{terminal_error_msg!r} unexpectedly found in {stdouterr.err}"

assert (
pyproject_file.read_text(encoding="utf8")
== successful_expected_pyproject_file_data
)

else:
# This should end in failure
with pytest.raises(SystemExit, match=raise_msg):
update_deps(
context,
Expand Down Expand Up @@ -1167,18 +1188,7 @@ def test_skip_unnormalized_python_package_names(
terminal_error_msg.search(stdouterr.err) is not None
), f"{terminal_error_msg!r} not found in {stdouterr.err}"

# In both cases, the pyproject.toml file should be updated for pytest.
# When/if a more atomistic approach is taken, then this should *NOT* be the case
# for runs where an error occurs.
expected_pyproject_file_data = """[project]
name = "{{ cookiecutter.project_slug }}"
requires-python = ">=3.8"

dependencies = []

[project.optional-dependencies]
dev = ["pytest~=7.4"]
all = ["{{ cookiecutter.project_slug }}[dev]"]
"""

assert pyproject_file.read_text(encoding="utf8") == expected_pyproject_file_data
assert (
pyproject_file.read_text(encoding="utf8")
== erroneous_expected_pyproject_file_data
)
Loading