Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 202 additions & 38 deletions src/specify_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2576,7 +2576,7 @@ def preset_list():
@preset_app.command("add")
def preset_add(
preset_id: str = typer.Argument(None, help="Preset ID to install from catalog"),
from_url: str = typer.Option(None, "--from", help="Install from a URL (ZIP file)"),
from_url: str = typer.Option(None, "--from", help="Install from a URL (ZIP or .tar.gz/.tgz archive)"),
dev: str = typer.Option(None, "--dev", help="Install from local directory (development mode)"),
priority: int = typer.Option(10, "--priority", help="Resolution priority (lower = higher precedence, default 10)"),
):
Expand Down Expand Up @@ -2629,17 +2629,24 @@ def preset_add(
import urllib.request
import urllib.error
import tempfile
from .extensions import _detect_archive_format as _det_fmt

with tempfile.TemporaryDirectory() as tmpdir:
zip_path = Path(tmpdir) / "preset.zip"
archive_fmt = _det_fmt(from_url)
try:
with urllib.request.urlopen(from_url, timeout=60) as response:
zip_path.write_bytes(response.read())
if not archive_fmt:
content_type = response.headers.get("Content-Type", "")
archive_fmt = _det_fmt(from_url, content_type)
archive_data = response.read()
except urllib.error.URLError as e:
console.print(f"[red]Error:[/red] Failed to download: {e}")
raise typer.Exit(1)

manifest = manager.install_from_zip(zip_path, speckit_version, priority)
suffix = ".tar.gz" if archive_fmt == "tar.gz" else ".zip"
archive_path = Path(tmpdir) / f"preset{suffix}"
archive_path.write_bytes(archive_data)
manifest = manager.install_from_zip(archive_path, speckit_version, priority)

console.print(f"[green]✓[/green] Preset '{manifest.name}' v{manifest.version} installed (priority {priority})")

Expand Down Expand Up @@ -3573,7 +3580,7 @@ def catalog_remove(
def extension_add(
extension: str = typer.Argument(help="Extension name or path"),
dev: bool = typer.Option(False, "--dev", help="Install from local directory"),
from_url: Optional[str] = typer.Option(None, "--from", help="Install from custom URL"),
from_url: Optional[str] = typer.Option(None, "--from", help="Install from custom URL (ZIP or .tar.gz/.tgz archive)"),
priority: int = typer.Option(10, "--priority", help="Resolution priority (lower = higher precedence, default 10)"),
):
"""Install an extension."""
Expand Down Expand Up @@ -3612,10 +3619,11 @@ def extension_add(
manifest = manager.install_from_directory(source_path, speckit_version, priority=priority)

elif from_url:
# Install from URL (ZIP file)
# Install from URL (ZIP or tar.gz archive)
import urllib.request
import urllib.error
from urllib.parse import urlparse
from .extensions import _detect_archive_format

# Validate URL
parsed = urlparse(from_url)
Expand All @@ -3631,25 +3639,32 @@ def extension_add(
console.print("Only install extensions from sources you trust.\n")
console.print(f"Downloading from {from_url}...")

# Download ZIP to temp location
# Download archive to temp location; detect format from URL or Content-Type.
download_dir = project_root / ".specify" / "extensions" / ".cache" / "downloads"
download_dir.mkdir(parents=True, exist_ok=True)
zip_path = download_dir / f"{extension}-url-download.zip"
archive_fmt = _detect_archive_format(from_url)
archive_path = None

try:
with urllib.request.urlopen(from_url, timeout=60) as response:
zip_data = response.read()
zip_path.write_bytes(zip_data)
if not archive_fmt:
content_type = response.headers.get("Content-Type", "")
archive_fmt = _detect_archive_format(from_url, content_type)
archive_data = response.read()

# Install from downloaded ZIP
manifest = manager.install_from_zip(zip_path, speckit_version, priority=priority)
suffix = ".tar.gz" if archive_fmt == "tar.gz" else ".zip"
archive_path = download_dir / f"{extension}-url-download{suffix}"
archive_path.write_bytes(archive_data)

# Install from downloaded archive
manifest = manager.install_from_zip(archive_path, speckit_version, priority=priority)
except urllib.error.URLError as e:
console.print(f"[red]Error:[/red] Failed to download from {from_url}: {e}")
raise typer.Exit(1)
finally:
# Clean up downloaded ZIP
if zip_path.exists():
zip_path.unlink()
# Clean up the downloaded archive
if archive_path is not None and archive_path.exists():
archive_path.unlink()

else:
# Try bundled extensions first (shipped with spec-kit)
Expand Down Expand Up @@ -4303,27 +4318,47 @@ def extension_update(
# 5. Download new version
zip_path = catalog.download_extension(extension_id)
try:
# 6. Validate extension ID from ZIP BEFORE modifying installation
# Handle both root-level and nested extension.yml (GitHub auto-generated ZIPs)
with zipfile.ZipFile(zip_path, "r") as zf:
import yaml
manifest_data = None
namelist = zf.namelist()

# First try root-level extension.yml
if "extension.yml" in namelist:
with zf.open("extension.yml") as f:
manifest_data = yaml.safe_load(f) or {}
else:
# Look for extension.yml in a single top-level subdirectory
# (e.g., "repo-name-branch/extension.yml")
manifest_paths = [n for n in namelist if n.endswith("/extension.yml") and n.count("/") == 1]
if len(manifest_paths) == 1:
with zf.open(manifest_paths[0]) as f:
# 6. Validate extension ID from archive BEFORE modifying installation
# Handle both root-level and nested extension.yml (GitHub auto-generated archives)
from .extensions import _detect_archive_format
import tarfile
archive_fmt = _detect_archive_format(str(zip_path))
import yaml
manifest_data = None

if archive_fmt == "tar.gz":
with tarfile.open(zip_path, "r:gz") as tf:
# First try root-level extension.yml
try:
m = tf.getmember("extension.yml")
f = tf.extractfile(m)
if f is not None:
manifest_data = yaml.safe_load(f.read()) or {}
except KeyError:
# Look for extension.yml in a single top-level subdirectory
members = [m for m in tf.getmembers() if m.name.endswith("/extension.yml") and m.name.count("/") == 1]
if len(members) == 1:
f = tf.extractfile(members[0])
if f is not None:
manifest_data = yaml.safe_load(f.read()) or {}
else:
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()

# First try root-level extension.yml
if "extension.yml" in namelist:
with zf.open("extension.yml") as f:
manifest_data = yaml.safe_load(f) or {}
else:
# Look for extension.yml in a single top-level subdirectory
# (e.g., "repo-name-branch/extension.yml")
manifest_paths = [n for n in namelist if n.endswith("/extension.yml") and n.count("/") == 1]
if len(manifest_paths) == 1:
with zf.open(manifest_paths[0]) as f:
manifest_data = yaml.safe_load(f) or {}

if manifest_data is None:
raise ValueError("Downloaded extension archive is missing 'extension.yml'")
if manifest_data is None:
raise ValueError("Downloaded extension archive is missing 'extension.yml'")

zip_extension_id = manifest_data.get("extension", {}).get("id")
if zip_extension_id != extension_id:
Expand Down Expand Up @@ -4875,6 +4910,57 @@ def workflow_list():
console.print()


def _extract_workflow_yml(archive_path: Path, archive_fmt: str) -> bytes:
"""Extract ``workflow.yml`` from a ZIP or ``.tar.gz`` archive.

Searches the archive root and a single nested top-level subdirectory
(e.g., ``repo-name-1.0/workflow.yml``).

Args:
archive_path: Path to the downloaded archive.
archive_fmt: ``"zip"`` or ``"tar.gz"``.

Returns:
Raw bytes of the ``workflow.yml`` file.

Raises:
ValueError: If no ``workflow.yml`` is found in the archive.
"""
import tarfile

if archive_fmt == "tar.gz":
with tarfile.open(archive_path, "r:gz") as tf:
# Try root-level first.
try:
f = tf.extractfile(tf.getmember("workflow.yml"))
if f is not None:
return f.read()
except KeyError:
pass
# Look in a single top-level subdirectory.
candidates = [
m for m in tf.getmembers()
if m.name.endswith("/workflow.yml") and m.name.count("/") == 1
]
if len(candidates) == 1:
f = tf.extractfile(candidates[0])
if f is not None:
return f.read()
else:
with zipfile.ZipFile(archive_path, "r") as zf:
namelist = zf.namelist()
if "workflow.yml" in namelist:
return zf.read("workflow.yml")
candidates = [
n for n in namelist
if n.endswith("/workflow.yml") and n.count("/") == 1
]
if len(candidates) == 1:
return zf.read(candidates[0])

raise ValueError("No workflow.yml found in the downloaded archive")


@workflow_app.command("add")
def workflow_add(
source: str = typer.Argument(..., help="Workflow ID, URL, or local path"),
Expand Down Expand Up @@ -4928,6 +5014,7 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None:
from ipaddress import ip_address
from urllib.parse import urlparse
from urllib.request import urlopen # noqa: S310
from .extensions import _detect_archive_format

parsed_src = urlparse(source)
src_host = parsed_src.hostname or ""
Expand Down Expand Up @@ -4958,18 +5045,51 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None:
if final_parsed.scheme != "https" and not (final_parsed.scheme == "http" and final_lb):
console.print(f"[red]Error:[/red] URL redirected to non-HTTPS: {final_url}")
raise typer.Exit(1)

# Detect archive format from the final URL or Content-Type header.
archive_fmt = _detect_archive_format(final_url)
if not archive_fmt:
content_type = resp.headers.get("Content-Type", "")
archive_fmt = _detect_archive_format(final_url, content_type)

raw_data = resp.read()
except typer.Exit:
raise
except Exception as exc:
console.print(f"[red]Error:[/red] Failed to download workflow: {exc}")
raise typer.Exit(1)

tmp_path = None
try:
if archive_fmt in ("tar.gz", "zip"):
# Extract workflow.yml from the archive.
suffix = ".tar.gz" if archive_fmt == "tar.gz" else ".zip"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as arc_tmp:
arc_tmp.write(raw_data)
arc_tmp_path = Path(arc_tmp.name)
try:
wf_yaml = _extract_workflow_yml(arc_tmp_path, archive_fmt)
with tempfile.NamedTemporaryFile(suffix=".yml", delete=False) as tmp:
tmp.write(wf_yaml)
tmp_path = Path(tmp.name)
finally:
arc_tmp_path.unlink(missing_ok=True)
else:
# Treat as a plain YAML file (existing behaviour).
with tempfile.NamedTemporaryFile(suffix=".yml", delete=False) as tmp:
tmp.write(resp.read())
tmp.write(raw_data)
tmp_path = Path(tmp.name)
except typer.Exit:
raise
except Exception as exc:
console.print(f"[red]Error:[/red] Failed to download workflow: {exc}")
console.print(f"[red]Error:[/red] Failed to process downloaded workflow: {exc}")
raise typer.Exit(1)

try:
_validate_and_install_local(tmp_path, source)
finally:
tmp_path.unlink(missing_ok=True)
if tmp_path is not None:
tmp_path.unlink(missing_ok=True)
return

# Try as a local file/directory
Expand All @@ -4978,6 +5098,26 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None:
if source_path.is_file() and source_path.suffix in (".yml", ".yaml"):
_validate_and_install_local(source_path, str(source_path))
return
elif source_path.is_file() and (
source.endswith(".tar.gz") or source.endswith(".tgz") or source.endswith(".zip")
):
# Local archive file containing workflow.yml
from .extensions import _detect_archive_format
local_fmt = _detect_archive_format(source)
try:
wf_yaml = _extract_workflow_yml(source_path, local_fmt)
except (ValueError, Exception) as exc:
console.print(f"[red]Error:[/red] Failed to extract workflow from archive: {exc}")
raise typer.Exit(1)
import tempfile
with tempfile.NamedTemporaryFile(suffix=".yml", delete=False) as tmp:
tmp.write(wf_yaml)
tmp_local = Path(tmp.name)
try:
_validate_and_install_local(tmp_local, str(source_path))
finally:
tmp_local.unlink(missing_ok=True)
return
elif source_path.is_dir():
wf_file = source_path / "workflow.yml"
if not wf_file.exists():
Expand Down Expand Up @@ -5041,6 +5181,7 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None:

try:
from urllib.request import urlopen # noqa: S310 — URL comes from catalog
from .extensions import _detect_archive_format

workflow_dir.mkdir(parents=True, exist_ok=True)
with urlopen(workflow_url, timeout=30) as response: # noqa: S310
Expand All @@ -5063,7 +5204,30 @@ def _validate_and_install_local(yaml_path: Path, source_label: str) -> None:
f"[red]Error:[/red] Workflow '{source}' redirected to non-HTTPS URL: {final_url}"
)
raise typer.Exit(1)
workflow_file.write_bytes(response.read())

# Detect archive format from the final URL or Content-Type header.
cat_archive_fmt = _detect_archive_format(final_url)
if not cat_archive_fmt:
cat_ct = response.headers.get("Content-Type", "")
cat_archive_fmt = _detect_archive_format(final_url, cat_ct)

raw_response = response.read()

if cat_archive_fmt in ("tar.gz", "zip"):
# Download URL points to an archive — extract workflow.yml from it.
suffix = ".tar.gz" if cat_archive_fmt == "tar.gz" else ".zip"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as arc_f:
arc_f.write(raw_response)
arc_tmp = Path(arc_f.name)
try:
wf_yaml_bytes = _extract_workflow_yml(arc_tmp, cat_archive_fmt)
finally:
arc_tmp.unlink(missing_ok=True)
workflow_file.write_bytes(wf_yaml_bytes)
else:
workflow_file.write_bytes(raw_response)
except typer.Exit:
raise
except Exception as exc:
if workflow_dir.exists():
import shutil
Expand Down
Loading