Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 44 additions & 30 deletions hooks/post_gen_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,11 @@
import pprint
import subprocess
import sys

# Used indirectly in the below Jinja2 block
from collections import OrderedDict # pylint: disable=unused-import
from collections import OrderedDict
from logging import basicConfig, getLogger
from pathlib import Path

import git
import yaml
from cookiecutter.repository import expand_abbreviations

LOG_FORMAT = json.dumps(
{
Expand All @@ -35,38 +31,31 @@

def get_context() -> dict:
"""Return the context as a dict"""
import git
from cookiecutter.repository import expand_abbreviations

cookiecutter = None
timestamp = datetime.datetime.now(datetime.UTC).isoformat(timespec="seconds")

##############
# This section leverages cookiecutter's jinja interpolation
# pylint: disable-next=unhashable-member
cookiecutter_context_ordered: OrderedDict[str, str] = {{cookiecutter | pprint}} # type: ignore
cookiecutter_context: dict[str, str] = dict(cookiecutter_context_ordered)

project_name = cookiecutter_context["project_slug"] # pylint: disable=unsubscriptable-object
project_description = cookiecutter_context["project_short_description"] # pylint: disable=unsubscriptable-object
template = cookiecutter_context["_template"] # pylint: disable=unsubscriptable-object
output = cookiecutter_context["_output_dir"] # pylint: disable=unsubscriptable-object
##############

try:
if Path(template).is_absolute():
template_path: Path = Path(template).resolve()
else:
output_path: Path = Path(output).resolve()
template_path: Path = output_path.joinpath(template)

# IMPORTANT: If the specified template is remote (http/git/ssh) this SHOULD raise an exception. The remote logic is in the except block
repo: git.Repo = git.Repo(template_path)
project_name = cookiecutter_context["project_slug"]
project_description = cookiecutter_context["project_short_description"]
template = cookiecutter_context["_template"]
output = cookiecutter_context["_output_dir"]
# Get the branch specified via --checkout, but fall back to main
branch = cookiecutter_context.get("_checkout") or "main"

# Expect this is a local template
branch: str = str(repo.active_branch)
dirty: bool = repo.is_dirty(untracked_files=True)
template_commit_hash = git.cmd.Git().ls_remote(template_path, "HEAD")[:40]
except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
# This exception handling occurs every time the template repo is remote
# Check if template is a remote URL or abbreviation
is_remote_template = any(
template.startswith(prefix) for prefix in ["http://", "https://", "git@", "gh:", "gl:", "bb:"]
)

if is_remote_template:
# From https://github.com/cookiecutter/cookiecutter/blob/b4451231809fb9e4fc2a1e95d433cb030e4b9e06/cookiecutter/config.py#L22
abbreviations: dict[str, str] = {
"gh": "https://github.com/{0}.git",
Expand All @@ -75,11 +64,36 @@ def get_context() -> dict:
}
template_repo: str = expand_abbreviations(template, abbreviations)

# This currently assumes main until https://github.com/cookiecutter/cookiecutter/issues/1759 is resolved
branch: str = "main"
dirty: bool = False

# For remote templates, get the commit hash from the remote
template_commit_hash = git.cmd.Git().ls_remote(template_repo, branch)[:40]
# Store the expanded URL as the template location
template_location = template_repo
else:
# This is a local template path
if Path(template).is_absolute():
template_path: Path = Path(template).resolve()
else:
output_path: Path = Path(output).resolve()
template_path: Path = output_path.joinpath(template).resolve()

try:
repo: git.Repo = git.Repo(template_path)

# Get info from the local repository
branch: str = str(repo.active_branch)
dirty: bool = repo.is_dirty(untracked_files=True)
# Get the actual commit hash from the local repository
template_commit_hash = repo.head.commit.hexsha
Comment thread
JonZeolla marked this conversation as resolved.
# Store the fully qualified template path for local templates
template_location = str(template_path)
except (git.exc.InvalidGitRepositoryError, git.exc.NoSuchPathError):
# Not a git repository, fall back to unknown values
branch = "unknown"
dirty = False
template_commit_hash = "unknown"
template_location = str(template_path)

context: dict[str, str | dict[str, str | bool | dict[str, str | bool | dict[str, str]]]] = {}
context["name"] = project_name
Expand All @@ -91,12 +105,12 @@ def get_context() -> dict:
context["origin"]["template"]["branch"] = branch
context["origin"]["template"]["commit hash"] = template_commit_hash
context["origin"]["template"]["dirty"] = dirty
context["origin"]["template"]["location"] = template
context["origin"]["template"]["location"] = template_location
context["origin"]["template"]["cookiecutter"] = {}
context["origin"]["template"]["cookiecutter"] = cookiecutter_context

# Filter out unwanted cookiecutter context
del cookiecutter_context["_output_dir"] # pylint: disable=unsubscriptable-object
del cookiecutter_context["_output_dir"]

return context

Expand Down
Loading