Skip to content

Commit 9f63a9c

Browse files
iamaeroplaneclaude
andcommitted
fix(extensions): address fourth round of PR review comments
Rollback fixes: - Preserve installed_at timestamp after successful update (was reset by install_from_zip calling registry.add) - Fix rollback to only delete extension_dir if backup exists (avoids destroying valid installation when failure happens before modification) - Fix rollback to remove NEW command files created by failed install (files that weren't in original backup are now cleaned up) - Fix rollback to delete hooks key entirely when backup_hooks is None (original config had no hooks key, so restore should remove it) Cross-command consistency fix: - Add display name resolution to `extension add` command using _resolve_catalog_extension() helper (was only in `extension info`) - Use resolved extension ID for download_extension() call, not original argument which may be a display name Security fix (fail-closed): - Malformed catalog config (empty/missing URLs) now raises ValidationError instead of silently falling back to built-in catalogs Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 50936c9 commit 9f63a9c

File tree

3 files changed

+183
-50
lines changed

3 files changed

+183
-50
lines changed

src/specify_cli/__init__.py

Lines changed: 82 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2219,8 +2219,11 @@ def extension_add(
22192219
# Install from catalog
22202220
catalog = ExtensionCatalog(project_root)
22212221

2222-
# Check if extension exists in catalog
2223-
ext_info = catalog.get_extension_info(extension)
2222+
# Check if extension exists in catalog (supports both ID and display name)
2223+
ext_info, catalog_error = _resolve_catalog_extension(extension, catalog, "add")
2224+
if catalog_error:
2225+
console.print(f"[red]Error:[/red] Could not query extension catalog: {catalog_error}")
2226+
raise typer.Exit(1)
22242227
if not ext_info:
22252228
console.print(f"[red]Error:[/red] Extension '{extension}' not found in catalog")
22262229
console.print("\nSearch available extensions:")
@@ -2240,9 +2243,10 @@ def extension_add(
22402243
)
22412244
raise typer.Exit(1)
22422245

2243-
# Download extension ZIP
2246+
# Download extension ZIP (use resolved ID, not original argument which may be display name)
2247+
extension_id = ext_info['id']
22442248
console.print(f"Downloading {ext_info['name']} v{ext_info.get('version', 'unknown')}...")
2245-
zip_path = catalog.download_extension(extension)
2249+
zip_path = catalog.download_extension(extension_id)
22462250

22472251
try:
22482252
# Install from downloaded ZIP
@@ -2797,21 +2801,29 @@ def extension_update(
27972801
# 8. Install new version
27982802
_ = manager.install_from_zip(zip_path, speckit_version)
27992803

2800-
# 9. Restore enabled state from backup
2801-
# If extension was disabled before update, disable it again
2802-
if backup_registry_entry and not backup_registry_entry.get("enabled", True):
2804+
# 9. Restore metadata from backup (installed_at, enabled state)
2805+
if backup_registry_entry:
28032806
new_metadata = manager.registry.get(extension_id)
2804-
new_metadata["enabled"] = False
2807+
2808+
# Preserve the original installation timestamp
2809+
if "installed_at" in backup_registry_entry:
2810+
new_metadata["installed_at"] = backup_registry_entry["installed_at"]
2811+
2812+
# If extension was disabled before update, disable it again
2813+
if not backup_registry_entry.get("enabled", True):
2814+
new_metadata["enabled"] = False
2815+
28052816
manager.registry.update(extension_id, new_metadata)
28062817

2807-
# Also disable hooks in extensions.yml to match
2808-
config = hook_executor.get_project_config()
2809-
if "hooks" in config:
2810-
for hook_name in config["hooks"]:
2811-
for hook in config["hooks"][hook_name]:
2812-
if hook.get("extension") == extension_id:
2813-
hook["enabled"] = False
2814-
hook_executor.save_project_config(config)
2818+
# Also disable hooks in extensions.yml if extension was disabled
2819+
if not backup_registry_entry.get("enabled", True):
2820+
config = hook_executor.get_project_config()
2821+
if "hooks" in config:
2822+
for hook_name in config["hooks"]:
2823+
for hook in config["hooks"][hook_name]:
2824+
if hook.get("extension") == extension_id:
2825+
hook["enabled"] = False
2826+
hook_executor.save_project_config(config)
28152827
finally:
28162828
# Clean up downloaded ZIP
28172829
if zip_path.exists():
@@ -2835,16 +2847,41 @@ def extension_update(
28352847

28362848
try:
28372849
# Restore extension directory
2850+
# Only perform destructive rollback if backup exists (meaning we
2851+
# actually modified the extension). This avoids deleting a valid
2852+
# installation when failure happened before changes were made.
28382853
extension_dir = manager.extensions_dir / extension_id
2839-
# Always remove any existing directory (from failed update)
2840-
if extension_dir.exists():
2841-
shutil.rmtree(extension_dir)
2842-
# Restore from backup if it exists; otherwise leave absent
2843-
# (matching pre-update state when there was no original dir)
28442854
if backup_ext_dir.exists():
2855+
if extension_dir.exists():
2856+
shutil.rmtree(extension_dir)
28452857
shutil.copytree(backup_ext_dir, extension_dir)
28462858

2847-
# Restore command files
2859+
# Remove any NEW command files created by failed install
2860+
# (files that weren't in the original backup)
2861+
try:
2862+
new_registry_entry = manager.registry.get(extension_id)
2863+
new_registered_commands = new_registry_entry.get("registered_commands", {})
2864+
for agent_name, cmd_names in new_registered_commands.items():
2865+
if agent_name not in registrar.AGENT_CONFIGS:
2866+
continue
2867+
agent_config = registrar.AGENT_CONFIGS[agent_name]
2868+
commands_dir = project_root / agent_config["dir"]
2869+
2870+
for cmd_name in cmd_names:
2871+
cmd_file = commands_dir / f"{cmd_name}{agent_config['extension']}"
2872+
# Delete if it exists and wasn't in our backup
2873+
if cmd_file.exists() and str(cmd_file) not in backed_up_command_files:
2874+
cmd_file.unlink()
2875+
2876+
# Also handle copilot prompt files
2877+
if agent_name == "copilot":
2878+
prompt_file = project_root / ".github" / "prompts" / f"{cmd_name}.prompt.md"
2879+
if prompt_file.exists() and str(prompt_file) not in backed_up_command_files:
2880+
prompt_file.unlink()
2881+
except KeyError:
2882+
pass # No new registry entry exists, nothing to clean up
2883+
2884+
# Restore backed up command files
28482885
for original_path, backup_path in backed_up_command_files.items():
28492886
backup_file = Path(backup_path)
28502887
if backup_file.exists():
@@ -2857,24 +2894,30 @@ def extension_update(
28572894
# - backup_hooks={} or {...} means config had hooks key
28582895
config = hook_executor.get_project_config()
28592896
if "hooks" in config:
2860-
# Remove any hooks for this extension added by failed install
28612897
modified = False
2862-
for hook_name, hooks_list in config["hooks"].items():
2863-
original_len = len(hooks_list)
2864-
config["hooks"][hook_name] = [
2865-
h for h in hooks_list
2866-
if h.get("extension") != extension_id
2867-
]
2868-
if len(config["hooks"][hook_name]) != original_len:
2869-
modified = True
2870-
2871-
# Add back the backed up hooks if any
2872-
if backup_hooks:
2873-
for hook_name, hooks in backup_hooks.items():
2874-
if hook_name not in config["hooks"]:
2875-
config["hooks"][hook_name] = []
2876-
config["hooks"][hook_name].extend(hooks)
2877-
modified = True
2898+
2899+
if backup_hooks is None:
2900+
# Original config had no "hooks" key; remove it entirely
2901+
del config["hooks"]
2902+
modified = True
2903+
else:
2904+
# Remove any hooks for this extension added by failed install
2905+
for hook_name, hooks_list in config["hooks"].items():
2906+
original_len = len(hooks_list)
2907+
config["hooks"][hook_name] = [
2908+
h for h in hooks_list
2909+
if h.get("extension") != extension_id
2910+
]
2911+
if len(config["hooks"][hook_name]) != original_len:
2912+
modified = True
2913+
2914+
# Add back the backed up hooks if any
2915+
if backup_hooks:
2916+
for hook_name, hooks in backup_hooks.items():
2917+
if hook_name not in config["hooks"]:
2918+
config["hooks"][hook_name] = []
2919+
config["hooks"][hook_name].extend(hooks)
2920+
modified = True
28782921

28792922
if modified:
28802923
hook_executor.save_project_config(config)

src/specify_cli/extensions.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,12 +1149,13 @@ def _load_catalog_config(self, config_path: Path) -> Optional[List[CatalogEntry]
11491149
config_path: Path to extension-catalogs.yml
11501150
11511151
Returns:
1152-
Ordered list of CatalogEntry objects, or None if file doesn't exist
1153-
or contains no valid catalog entries.
1152+
Ordered list of CatalogEntry objects, or None if file doesn't exist.
11541153
11551154
Raises:
11561155
ValidationError: If any catalog entry has an invalid URL,
1157-
the file cannot be parsed, or a priority value is invalid.
1156+
the file cannot be parsed, a priority value is invalid,
1157+
or the file exists but contains no valid catalog entries
1158+
(fail-closed for security).
11581159
"""
11591160
if not config_path.exists():
11601161
return None
@@ -1166,19 +1167,25 @@ def _load_catalog_config(self, config_path: Path) -> Optional[List[CatalogEntry]
11661167
)
11671168
catalogs_data = data.get("catalogs", [])
11681169
if not catalogs_data:
1169-
return None
1170+
# File exists but has no catalogs key or empty list - fail closed
1171+
raise ValidationError(
1172+
f"Catalog config {config_path} exists but contains no 'catalogs' entries. "
1173+
f"Remove the file to use built-in defaults, or add valid catalog entries."
1174+
)
11701175
if not isinstance(catalogs_data, list):
11711176
raise ValidationError(
11721177
f"Invalid catalog config: 'catalogs' must be a list, got {type(catalogs_data).__name__}"
11731178
)
11741179
entries: List[CatalogEntry] = []
1180+
skipped_entries: List[int] = []
11751181
for idx, item in enumerate(catalogs_data):
11761182
if not isinstance(item, dict):
11771183
raise ValidationError(
11781184
f"Invalid catalog entry at index {idx}: expected a mapping, got {type(item).__name__}"
11791185
)
11801186
url = str(item.get("url", "")).strip()
11811187
if not url:
1188+
skipped_entries.append(idx)
11821189
continue
11831190
self._validate_catalog_url(url)
11841191
try:
@@ -1201,7 +1208,14 @@ def _load_catalog_config(self, config_path: Path) -> Optional[List[CatalogEntry]
12011208
description=str(item.get("description", "")),
12021209
))
12031210
entries.sort(key=lambda e: e.priority)
1204-
return entries if entries else None
1211+
if not entries:
1212+
# All entries were invalid (missing URLs) - fail closed for security
1213+
raise ValidationError(
1214+
f"Catalog config {config_path} contains {len(catalogs_data)} entries but none have valid URLs "
1215+
f"(entries at indices {skipped_entries} were skipped). "
1216+
f"Each catalog entry must have a 'url' field."
1217+
)
1218+
return entries
12051219

12061220
def get_active_catalogs(self) -> List[CatalogEntry]:
12071221
"""Get the ordered list of active catalogs.

tests/test_extensions.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,8 +1493,8 @@ def test_project_config_invalid_url_raises(self, temp_dir):
14931493
with pytest.raises(ValidationError, match="HTTPS"):
14941494
catalog.get_active_catalogs()
14951495

1496-
def test_empty_project_config_falls_back_to_defaults(self, temp_dir):
1497-
"""Empty catalogs list in config falls back to default stack."""
1496+
def test_empty_project_config_raises_error(self, temp_dir):
1497+
"""Empty catalogs list in config raises ValidationError (fail-closed for security)."""
14981498
import yaml as yaml_module
14991499

15001500
project_dir = self._make_project(temp_dir)
@@ -1503,11 +1503,32 @@ def test_empty_project_config_falls_back_to_defaults(self, temp_dir):
15031503
yaml_module.dump({"catalogs": []}, f)
15041504

15051505
catalog = ExtensionCatalog(project_dir)
1506-
entries = catalog.get_active_catalogs()
15071506

1508-
# Falls back to default stack
1509-
assert len(entries) == 2
1510-
assert entries[0].url == ExtensionCatalog.DEFAULT_CATALOG_URL
1507+
# Fail-closed: empty config should raise, not fall back to defaults
1508+
with pytest.raises(ValidationError) as exc_info:
1509+
catalog.get_active_catalogs()
1510+
assert "contains no 'catalogs' entries" in str(exc_info.value)
1511+
1512+
def test_catalog_entries_without_urls_raises_error(self, temp_dir):
1513+
"""Catalog entries without URLs raise ValidationError (fail-closed for security)."""
1514+
import yaml as yaml_module
1515+
1516+
project_dir = self._make_project(temp_dir)
1517+
config_path = project_dir / ".specify" / "extension-catalogs.yml"
1518+
with open(config_path, "w") as f:
1519+
yaml_module.dump({
1520+
"catalogs": [
1521+
{"name": "no-url-catalog", "priority": 1},
1522+
{"name": "another-no-url", "description": "Also missing URL"},
1523+
]
1524+
}, f)
1525+
1526+
catalog = ExtensionCatalog(project_dir)
1527+
1528+
# Fail-closed: entries without URLs should raise, not fall back to defaults
1529+
with pytest.raises(ValidationError) as exc_info:
1530+
catalog.get_active_catalogs()
1531+
assert "none have valid URLs" in str(exc_info.value)
15111532

15121533
# --- _load_catalog_config ---
15131534

@@ -2034,3 +2055,58 @@ def test_extensionignore_negation_pattern(self, temp_dir, valid_manifest_data):
20342055
assert not (dest / "docs" / "guide.md").exists()
20352056
assert not (dest / "docs" / "internal.md").exists()
20362057
assert (dest / "docs" / "api.md").exists()
2058+
2059+
2060+
class TestExtensionAddCLI:
2061+
"""CLI integration tests for extension add command."""
2062+
2063+
def test_add_by_display_name_uses_resolved_id_for_download(self, tmp_path):
2064+
"""extension add by display name should use resolved ID for download_extension()."""
2065+
from typer.testing import CliRunner
2066+
from unittest.mock import patch, MagicMock
2067+
from specify_cli import app
2068+
2069+
runner = CliRunner()
2070+
2071+
# Create project structure
2072+
project_dir = tmp_path / "test-project"
2073+
project_dir.mkdir()
2074+
(project_dir / ".specify").mkdir()
2075+
(project_dir / ".specify" / "extensions").mkdir(parents=True)
2076+
2077+
# Mock catalog that returns extension by display name
2078+
mock_catalog = MagicMock()
2079+
mock_catalog.get_extension_info.return_value = None # ID lookup fails
2080+
mock_catalog.search.return_value = [
2081+
{
2082+
"id": "acme-jira-integration",
2083+
"name": "Jira Integration",
2084+
"version": "1.0.0",
2085+
"description": "Jira integration extension",
2086+
"_install_allowed": True,
2087+
}
2088+
]
2089+
2090+
# Track what ID was passed to download_extension
2091+
download_called_with = []
2092+
def mock_download(extension_id):
2093+
download_called_with.append(extension_id)
2094+
# Return a path that will fail install (we just want to verify the ID)
2095+
raise ExtensionError("Mock download - checking ID was resolved")
2096+
2097+
mock_catalog.download_extension.side_effect = mock_download
2098+
2099+
with patch("specify_cli.extensions.ExtensionCatalog", return_value=mock_catalog), \
2100+
patch.object(Path, "cwd", return_value=project_dir):
2101+
result = runner.invoke(
2102+
app,
2103+
["extension", "add", "Jira Integration"],
2104+
catch_exceptions=True,
2105+
)
2106+
2107+
# Verify download_extension was called with the resolved ID, not the display name
2108+
assert len(download_called_with) == 1
2109+
assert download_called_with[0] == "acme-jira-integration", (
2110+
f"Expected download_extension to be called with resolved ID 'acme-jira-integration', "
2111+
f"but was called with '{download_called_with[0]}'"
2112+
)

0 commit comments

Comments
 (0)