diff --git a/ci_cd/tasks/update_deps.py b/ci_cd/tasks/update_deps.py index 64c2e0fb..3a2d7a17 100644 --- a/ci_cd/tasks/update_deps.py +++ b/ci_cd/tasks/update_deps.py @@ -30,7 +30,6 @@ parse_ignore_entries, parse_ignore_rules, regenerate_requirement, - update_file, update_specifier_set, warning_msg, ) @@ -44,7 +43,6 @@ # Get logger LOGGER = logging.getLogger(__name__) - VALID_PACKAGE_NAME_PATTERN = r"^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$" """ Pattern to validate package names. @@ -57,8 +55,44 @@ """ +def _update_pyproject( + original_dependency: str, updated_dependency: str, pyproject: tomlkit.TOMLDocument +) -> None: + """Update dependency in pyproject data structure. + + First, check and update the dependency if it is in the "dependencies" group + Then, check and update if it is in any of the "optional-dependencies" groups. + + Essentially, we allow for the original dependency to be in multiple groups. + """ + LOGGER.debug( + "Updating pyproject data structure for %r to %r", + original_dependency, + updated_dependency, + ) + + if original_dependency in pyproject["project"].get("dependencies", []): + index = pyproject["project"]["dependencies"].index(original_dependency) + pyproject["project"]["dependencies"][index] = updated_dependency.replace( + '"', "'" + ) + + for extra_name, extra_dependencies in ( + pyproject["project"].get("optional-dependencies", {}).items() + ): + if original_dependency in extra_dependencies: + index = pyproject["project"]["optional-dependencies"][extra_name].index( + original_dependency + ) + pyproject["project"]["optional-dependencies"][extra_name][index] = ( + updated_dependency.replace('"', "'") + ) + + def _format_and_update_dependency( - requirement: Requirement, raw_dependency_line: str, pyproject_path: Path + requirement: Requirement, + raw_dependency_line: str, + pyproject: tomlkit.TOMLDocument = None, ) -> None: """Regenerate dependency without changing anything but the formatting. @@ -72,12 +106,8 @@ def _format_and_update_dependency( ) LOGGER.debug("Regenerated dependency: %r", updated_dependency) if updated_dependency != raw_dependency_line: - # Update pyproject.toml since the dependency formatting has changed - LOGGER.debug("Updating pyproject.toml for %r", requirement.name) - update_file( - pyproject_path, - (re.escape(raw_dependency_line), updated_dependency.replace('"', "'")), - ) + # Update pyproject data structure since the dependency formatting has changed + _update_pyproject(raw_dependency_line, updated_dependency, pyproject) @task( @@ -192,7 +222,8 @@ def update_deps( ) # Build the list of dependencies listed in pyproject.toml - dependencies: list[str] = pyproject.get("project", {}).get("dependencies", []) + dependencies: list[str] = [] + dependencies.extend(pyproject.get("project", {}).get("dependencies", [])) for optional_deps in ( pyproject.get("project", {}).get("optional-dependencies", {}).values() ): @@ -259,9 +290,7 @@ def update_deps( LOGGER.info(msg) print(info_msg(msg), flush=True) - _format_and_update_dependency( - parsed_requirement, dependency, pyproject_path - ) + _format_and_update_dependency(parsed_requirement, dependency, pyproject) already_handled_packages.add(parsed_requirement) continue @@ -278,9 +307,7 @@ def update_deps( LOGGER.warning(msg) print(warning_msg(msg), flush=True) - _format_and_update_dependency( - parsed_requirement, dependency, pyproject_path - ) + _format_and_update_dependency(parsed_requirement, dependency, pyproject) already_handled_packages.add(parsed_requirement) continue @@ -469,14 +496,8 @@ def update_deps( ) LOGGER.debug("Updated dependency: %r", updated_dependency) - pattern_sub_line = re.escape(dependency) - replacement_sub_line = updated_dependency.replace('"', "'") + _update_pyproject(dependency, updated_dependency, pyproject) - LOGGER.debug("pattern_sub_line: %s", pattern_sub_line) - LOGGER.debug("replacement_sub_line: %s", replacement_sub_line) - - # Update pyproject.toml - update_file(pyproject_path, (pattern_sub_line, replacement_sub_line)) already_handled_packages.add(parsed_requirement) updated_packages[parsed_requirement.name] = ",".join( str(_) @@ -492,6 +513,9 @@ def update_deps( f"{Emoji.CROSS_MARK.value} Errors occurred! See printed statements above." ) + # Update pyproject.toml + pyproject_path.write_text(tomlkit.dumps(pyproject), encoding="utf-8") + if updated_packages: print( f"{Emoji.PARTY_POPPER.value} Successfully updated the following " diff --git a/tests/tasks/test_update_deps.py b/tests/tasks/test_update_deps.py index 92ed4403..49190c76 100644 --- a/tests/tasks/test_update_deps.py +++ b/tests/tasks/test_update_deps.py @@ -1110,7 +1110,22 @@ def test_skip_unnormalized_python_package_names( r"statements above\.$" ) + # Due to an atomistic approach, if an error occurs, the pyproject.toml file will + # not be updated. + successful_expected_pyproject_file_data = """[project] +name = "{{ cookiecutter.project_slug }}" +requires-python = ">=3.8" + +dependencies = [] + +[project.optional-dependencies] +dev = ["pytest~=7.4"] +all = ["{{ cookiecutter.project_slug }}[dev]"] +""" + erroneous_expected_pyproject_file_data = pyproject_file_data + if skip_unnormalized_python_package_names: + # This should end in success update_deps( context, root_repo_path=str(tmp_path), @@ -1134,7 +1149,13 @@ def test_skip_unnormalized_python_package_names( terminal_error_msg.search(stdouterr.err) is None ), f"{terminal_error_msg!r} unexpectedly found in {stdouterr.err}" + assert ( + pyproject_file.read_text(encoding="utf8") + == successful_expected_pyproject_file_data + ) + else: + # This should end in failure with pytest.raises(SystemExit, match=raise_msg): update_deps( context, @@ -1167,18 +1188,7 @@ def test_skip_unnormalized_python_package_names( terminal_error_msg.search(stdouterr.err) is not None ), f"{terminal_error_msg!r} not found in {stdouterr.err}" - # In both cases, the pyproject.toml file should be updated for pytest. - # When/if a more atomistic approach is taken, then this should *NOT* be the case - # for runs where an error occurs. - expected_pyproject_file_data = """[project] -name = "{{ cookiecutter.project_slug }}" -requires-python = ">=3.8" - -dependencies = [] - -[project.optional-dependencies] -dev = ["pytest~=7.4"] -all = ["{{ cookiecutter.project_slug }}[dev]"] -""" - - assert pyproject_file.read_text(encoding="utf8") == expected_pyproject_file_data + assert ( + pyproject_file.read_text(encoding="utf8") + == erroneous_expected_pyproject_file_data + )