Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions check_dist/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def load_hatch_config(pyproject_path: str | Path = "pyproject.toml") -> dict:
return config.get("tool", {}).get("hatch", {}).get("build", {})


# ── Copier template defaults ─────────────────────────────────────────

# Per-extension type defaults for sdist/wheel present/absent patterns.
# Keys follow the ``add_extension`` value in ``.copier-answers.yaml``.
_EXTENSION_DEFAULTS: dict[str, dict] = {
Expand Down Expand Up @@ -217,6 +215,39 @@ def _module_name_from_project(project_name: str) -> str:
return re.sub(r"[\s-]+", "_", project_name).strip("_")


def _normalize_name(name: str) -> str:
"""Normalize a name by stripping underscores, hyphens, and periods."""
return re.sub(r"[-_.]+", "", name).lower()


def _resolve_module_from_hatch(module: str, hatch_config: dict) -> str:
"""Resolve the module name from hatch packages configuration.

If any package in the hatch sdist or wheel ``packages`` (or
``only-include``) is equivalent to *module* after normalizing away
underscores, hyphens, and periods, return that package name instead.
This handles projects where the distribution name differs from the
importable package name (e.g. ``jupyter-fs`` vs ``jupyterfs``).
"""
norm = _normalize_name(module)
candidates: list[str] = []
for target in ("sdist", "wheel"):
target_cfg = hatch_config.get("targets", {}).get(target, {})
for key in ("only-include", "packages"):
vals = target_cfg.get(key)
if vals:
candidates.extend(vals)
# Also check top-level packages / only-include
for key in ("only-include", "packages"):
vals = hatch_config.get(key)
if vals:
candidates.extend(vals)
for candidate in candidates:
if _normalize_name(candidate) == norm:
return candidate
return module


def copier_defaults(copier_config: dict, hatch_config: dict | None = None) -> dict | None:
"""Derive default check-dist config from copier answers.

Expand All @@ -240,6 +271,8 @@ def copier_defaults(copier_config: dict, hatch_config: dict | None = None) -> di
return None

module = _module_name_from_project(project_name)
if hatch_config:
module = _resolve_module_from_hatch(module, hatch_config)

sdist_present_extra = list(ext_defaults.get("sdist_present_extra", []))

Expand Down Expand Up @@ -301,9 +334,6 @@ def _filter_extras_by_hatch(extras: list[str], hatch_config: dict) -> list[str]:
return extras


# ── Building ──────────────────────────────────────────────────────────


def build_dists(source_dir: str, output_dir: str, *, no_isolation: bool = False) -> list[str]:
"""Build sdist and wheel into *output_dir*.

Expand Down Expand Up @@ -378,9 +408,6 @@ def _find_pre_built(source_dir: str) -> str | None:
return None


# ── Listing files ─────────────────────────────────────────────────────


def list_sdist_files(sdist_path: str) -> list[str]:
"""List files inside an sdist, stripping the top-level directory."""
files: list[str] = []
Expand All @@ -407,9 +434,6 @@ def list_wheel_files(wheel_path: str) -> list[str]:
return sorted(name for name in zf.namelist() if not name.endswith("/"))


# ── VCS integration ───────────────────────────────────────────────────


def get_vcs_files(source_dir: str) -> list[str]:
"""Return files tracked by git in *source_dir*."""
try:
Expand All @@ -426,9 +450,6 @@ def get_vcs_files(source_dir: str) -> list[str]:
return sorted(f for f in result.stdout.split("\0") if f)


# ── Pattern matching ──────────────────────────────────────────────────


def matches_pattern(filepath: str, pattern: str) -> bool:
"""Check whether *filepath* matches *pattern*.

Expand Down Expand Up @@ -479,9 +500,6 @@ def _matches_hatch_pattern(filepath: str, pattern: str) -> bool:
return False


# ── Checking helpers ──────────────────────────────────────────────────


def check_present(files: list[str], patterns: list[str], dist_type: str) -> list[str]:
"""Return error strings for any *patterns* not found in *files*."""
errors: list[str] = []
Expand Down Expand Up @@ -664,9 +682,6 @@ def check_sdist_vs_vcs(
return errors


# ── Main entry point ──────────────────────────────────────────────────


def check_dist(
source_dir: str = ".",
*,
Expand Down Expand Up @@ -742,7 +757,6 @@ def check_dist(
if not wheel_path:
errors.append("No wheel found in pre-built directory")

# ── sdist checks ─────────────────────────────────────────
if sdist_path:
sdist_files = list_sdist_files(sdist_path)
messages.append(f"\nsdist ({os.path.basename(sdist_path)}) – {len(sdist_files)} file(s):")
Expand All @@ -760,7 +774,6 @@ def check_dist(
errors.extend(check_absent(sdist_files, config["sdist"]["absent"], "sdist", present_patterns=config["sdist"]["present"]))
errors.extend(check_wrong_platform_extensions(sdist_files, "sdist"))

# ── wheel checks ─────────────────────────────────────────
if wheel_path:
wheel_files = list_wheel_files(wheel_path)
messages.append(f"\nwheel ({os.path.basename(wheel_path)}) – {len(wheel_files)} file(s):")
Expand Down
Loading
Loading