Skip to content

Commit 2a98da9

Browse files
iamaeroplaneclaude
andcommitted
fix(extensions): address reviewer feedback on update/rollback logic
- Hook rollback: handle empty backup_hooks by checking `is not None` instead of truthiness (falsy empty dict would skip hook cleanup) - extension_info: use resolved_installed_id for catalog lookup when extension was found by display name (prevents wrong catalog match) - Rollback: always remove extension dir first, then restore if backup exists (handles case when no original dir existed before update) - Validate extension ID from ZIP before installing, not after (avoids side effects of installing wrong extension before rollback) - Preserve enabled state during updates: re-apply disabled state and hook enabled flags after successful update - Optimize _resolve_catalog_extension: pass query to catalog.search() instead of fetching all extensions - update() now merges metadata with existing entry instead of replacing (preserves fields like registered_commands when only updating enabled) - Add tests for ExtensionRegistry.update() and restore() methods: - test_update_preserves_installed_at - test_update_merges_with_existing - test_update_raises_for_missing_extension - test_restore_overwrites_completely - test_restore_can_recreate_removed_entry Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 7bd6fd9 commit 2a98da9

File tree

3 files changed

+166
-36
lines changed

3 files changed

+166
-36
lines changed

src/specify_cli/__init__.py

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,9 +1881,9 @@ def _resolve_catalog_extension(
18811881
if ext_info:
18821882
return (ext_info, None)
18831883

1884-
# Try by display name - search all extensions
1885-
all_extensions = catalog.search()
1886-
name_matches = [ext for ext in all_extensions if ext["name"].lower() == argument.lower()]
1884+
# Try by display name - search using argument as query, then filter for exact match
1885+
search_results = catalog.search(query=argument)
1886+
name_matches = [ext for ext in search_results if ext["name"].lower() == argument.lower()]
18871887

18881888
if len(name_matches) == 1:
18891889
return (name_matches[0], None)
@@ -2468,7 +2468,10 @@ def extension_info(
24682468
raise typer.Exit(1)
24692469

24702470
# Try catalog lookup (with error handling)
2471-
ext_info, catalog_error = _resolve_catalog_extension(extension, catalog, "info")
2471+
# If we resolved an installed extension by display name, use its ID for catalog lookup
2472+
# to ensure we get the correct catalog entry (not a different extension with same name)
2473+
lookup_key = resolved_installed_id if resolved_installed_id else extension
2474+
ext_info, catalog_error = _resolve_catalog_extension(lookup_key, catalog, "info")
24722475

24732476
# Case 1: Found in catalog - show full catalog info
24742477
if ext_info:
@@ -2729,7 +2732,7 @@ def extension_update(
27292732

27302733
# Store backup state
27312734
backup_registry_entry = None
2732-
backup_hooks = None
2735+
backup_hooks = {} # Initialize to empty dict, not None
27332736
backed_up_command_files = {}
27342737

27352738
try:
@@ -2778,30 +2781,51 @@ def extension_update(
27782781
if ext_hooks:
27792782
backup_hooks[hook_name] = ext_hooks
27802783

2781-
# 5. Remove old extension (handles command file cleanup and registry removal)
2782-
manager.remove(extension_id, keep_config=True)
2783-
2784-
# 6. Download and install new version
2784+
# 5. Download new version
27852785
zip_path = catalog.download_extension(extension_id)
27862786
try:
2787-
installed_manifest = manager.install_from_zip(zip_path, speckit_version)
2788-
2789-
# 7. Verify extension ID matches
2790-
if installed_manifest.id != extension_id:
2791-
# Remove the wrongly installed extension before raising
2787+
# 6. Validate extension ID from ZIP BEFORE modifying installation
2788+
with zipfile.ZipFile(zip_path, "r") as zf:
27922789
try:
2793-
manager.remove(installed_manifest.id)
2794-
except Exception:
2795-
pass # Best effort cleanup
2790+
with zf.open("extension.yml") as f:
2791+
import yaml
2792+
manifest_data = yaml.safe_load(f) or {}
2793+
except KeyError:
2794+
raise ValueError("Downloaded extension archive is missing 'extension.yml'")
2795+
2796+
zip_extension_id = manifest_data.get("extension", {}).get("id")
2797+
if zip_extension_id != extension_id:
27962798
raise ValueError(
2797-
f"Extension ID mismatch: expected '{extension_id}', got '{installed_manifest.id}'"
2799+
f"Extension ID mismatch: expected '{extension_id}', got '{zip_extension_id}'"
27982800
)
2801+
2802+
# 7. Remove old extension (handles command file cleanup and registry removal)
2803+
manager.remove(extension_id, keep_config=True)
2804+
2805+
# 8. Install new version
2806+
installed_manifest = manager.install_from_zip(zip_path, speckit_version)
2807+
2808+
# 9. Restore enabled state from backup
2809+
# If extension was disabled before update, disable it again
2810+
if backup_registry_entry and not backup_registry_entry.get("enabled", True):
2811+
new_metadata = manager.registry.get(extension_id)
2812+
new_metadata["enabled"] = False
2813+
manager.registry.update(extension_id, new_metadata)
2814+
2815+
# Also disable hooks in extensions.yml to match
2816+
config = hook_executor.get_project_config()
2817+
if "hooks" in config:
2818+
for hook_name in config["hooks"]:
2819+
for hook in config["hooks"][hook_name]:
2820+
if hook.get("extension") == extension_id:
2821+
hook["enabled"] = False
2822+
hook_executor.save_project_config(config)
27992823
finally:
28002824
# Clean up downloaded ZIP
28012825
if zip_path.exists():
28022826
zip_path.unlink()
28032827

2804-
# 8. Clean up backup on success
2828+
# 10. Clean up backup on success
28052829
if backup_base.exists():
28062830
shutil.rmtree(backup_base)
28072831

@@ -2819,10 +2843,13 @@ def extension_update(
28192843

28202844
try:
28212845
# Restore extension directory
2846+
extension_dir = manager.extensions_dir / extension_id
2847+
# Always remove any existing directory (from failed update)
2848+
if extension_dir.exists():
2849+
shutil.rmtree(extension_dir)
2850+
# Restore from backup if it exists; otherwise leave absent
2851+
# (matching pre-update state when there was no original dir)
28222852
if backup_ext_dir.exists():
2823-
extension_dir = manager.extensions_dir / extension_id
2824-
if extension_dir.exists():
2825-
shutil.rmtree(extension_dir)
28262853
shutil.copytree(backup_ext_dir, extension_dir)
28272854

28282855
# Restore command files
@@ -2834,19 +2861,24 @@ def extension_update(
28342861
shutil.copy2(backup_file, original_file)
28352862

28362863
# Restore hooks in extensions.yml
2837-
if backup_hooks:
2864+
# Always remove any hooks for this extension (from failed install),
2865+
# then restore backed-up hooks (even if empty)
2866+
if backup_hooks is not None:
28382867
config = hook_executor.get_project_config()
28392868
if "hooks" not in config:
28402869
config["hooks"] = {}
2841-
for hook_name, hooks in backup_hooks.items():
2842-
if hook_name not in config["hooks"]:
2843-
config["hooks"][hook_name] = []
2844-
# Remove any existing hooks for this extension first
2870+
2871+
# First remove any existing hooks for this extension across all hook groups
2872+
for hook_name, hooks_list in config["hooks"].items():
28452873
config["hooks"][hook_name] = [
2846-
h for h in config["hooks"][hook_name]
2874+
h for h in hooks_list
28472875
if h.get("extension") != extension_id
28482876
]
2849-
# Add back the backed up hooks
2877+
2878+
# Then add back the backed up hooks (may be empty)
2879+
for hook_name, hooks in backup_hooks.items():
2880+
if hook_name not in config["hooks"]:
2881+
config["hooks"][hook_name] = []
28502882
config["hooks"][hook_name].extend(hooks)
28512883
hook_executor.save_project_config(config)
28522884

src/specify_cli/extensions.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -229,27 +229,34 @@ def add(self, extension_id: str, metadata: dict):
229229
self._save()
230230

231231
def update(self, extension_id: str, metadata: dict):
232-
"""Update extension metadata in registry, preserving installed_at.
232+
"""Update extension metadata in registry, merging with existing entry.
233+
234+
Merges the provided metadata with the existing entry, preserving any
235+
fields not specified in the new metadata. The installed_at timestamp
236+
is always preserved from the original entry.
233237
234238
Use this method instead of add() when updating existing extension
235239
metadata (e.g., enabling/disabling) to preserve the original
236-
installation timestamp.
240+
installation timestamp and other existing fields.
237241
238242
Args:
239243
extension_id: Extension ID
240-
metadata: Extension metadata to update
244+
metadata: Extension metadata fields to update (merged with existing)
241245
242246
Raises:
243247
KeyError: If extension is not installed
244248
"""
245249
if extension_id not in self.data["extensions"]:
246250
raise KeyError(f"Extension '{extension_id}' is not installed")
247-
# Preserve the original installed_at timestamp
251+
# Merge new metadata with existing, preserving original installed_at
248252
existing = self.data["extensions"][extension_id]
249253
original_installed_at = existing.get("installed_at")
250-
self.data["extensions"][extension_id] = metadata
251-
if original_installed_at and "installed_at" not in metadata:
252-
self.data["extensions"][extension_id]["installed_at"] = original_installed_at
254+
# Merge: existing fields preserved, new fields override
255+
merged = {**existing, **metadata}
256+
# Always preserve original installed_at
257+
if original_installed_at:
258+
merged["installed_at"] = original_installed_at
259+
self.data["extensions"][extension_id] = merged
253260
self._save()
254261

255262
def restore(self, extension_id: str, metadata: dict):

tests/test_extensions.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,97 @@ def test_registry_persistence(self, temp_dir):
277277
assert registry2.is_installed("test-ext")
278278
assert registry2.get("test-ext")["version"] == "1.0.0"
279279

280+
def test_update_preserves_installed_at(self, temp_dir):
281+
"""Test that update() preserves the original installed_at timestamp."""
282+
extensions_dir = temp_dir / "extensions"
283+
extensions_dir.mkdir()
284+
285+
registry = ExtensionRegistry(extensions_dir)
286+
registry.add("test-ext", {"version": "1.0.0", "enabled": True})
287+
288+
# Get original installed_at
289+
original_data = registry.get("test-ext")
290+
original_installed_at = original_data["installed_at"]
291+
292+
# Update with new metadata
293+
registry.update("test-ext", {"version": "2.0.0", "enabled": False})
294+
295+
# Verify installed_at is preserved
296+
updated_data = registry.get("test-ext")
297+
assert updated_data["installed_at"] == original_installed_at
298+
assert updated_data["version"] == "2.0.0"
299+
assert updated_data["enabled"] is False
300+
301+
def test_update_merges_with_existing(self, temp_dir):
302+
"""Test that update() merges new metadata with existing fields."""
303+
extensions_dir = temp_dir / "extensions"
304+
extensions_dir.mkdir()
305+
306+
registry = ExtensionRegistry(extensions_dir)
307+
registry.add("test-ext", {
308+
"version": "1.0.0",
309+
"enabled": True,
310+
"registered_commands": {"claude": ["cmd1", "cmd2"]},
311+
})
312+
313+
# Update with partial metadata (only enabled field)
314+
registry.update("test-ext", {"enabled": False})
315+
316+
# Verify existing fields are preserved
317+
updated_data = registry.get("test-ext")
318+
assert updated_data["enabled"] is False
319+
assert updated_data["version"] == "1.0.0" # Preserved
320+
assert updated_data["registered_commands"] == {"claude": ["cmd1", "cmd2"]} # Preserved
321+
322+
def test_update_raises_for_missing_extension(self, temp_dir):
323+
"""Test that update() raises KeyError for non-installed extension."""
324+
extensions_dir = temp_dir / "extensions"
325+
extensions_dir.mkdir()
326+
327+
registry = ExtensionRegistry(extensions_dir)
328+
329+
with pytest.raises(KeyError, match="not installed"):
330+
registry.update("nonexistent-ext", {"enabled": False})
331+
332+
def test_restore_overwrites_completely(self, temp_dir):
333+
"""Test that restore() overwrites the registry entry completely."""
334+
extensions_dir = temp_dir / "extensions"
335+
extensions_dir.mkdir()
336+
337+
registry = ExtensionRegistry(extensions_dir)
338+
registry.add("test-ext", {"version": "2.0.0", "enabled": True})
339+
340+
# Restore with complete backup data
341+
backup_data = {
342+
"version": "1.0.0",
343+
"enabled": False,
344+
"installed_at": "2024-01-01T00:00:00+00:00",
345+
"registered_commands": {"claude": ["old-cmd"]},
346+
}
347+
registry.restore("test-ext", backup_data)
348+
349+
# Verify entry is exactly as restored
350+
restored_data = registry.get("test-ext")
351+
assert restored_data == backup_data
352+
353+
def test_restore_can_recreate_removed_entry(self, temp_dir):
354+
"""Test that restore() can recreate an entry after remove()."""
355+
extensions_dir = temp_dir / "extensions"
356+
extensions_dir.mkdir()
357+
358+
registry = ExtensionRegistry(extensions_dir)
359+
registry.add("test-ext", {"version": "1.0.0"})
360+
361+
# Save backup and remove
362+
backup = registry.get("test-ext").copy()
363+
registry.remove("test-ext")
364+
assert not registry.is_installed("test-ext")
365+
366+
# Restore should recreate the entry
367+
registry.restore("test-ext", backup)
368+
assert registry.is_installed("test-ext")
369+
assert registry.get("test-ext")["version"] == "1.0.0"
370+
280371

281372
# ===== ExtensionManager Tests =====
282373

0 commit comments

Comments
 (0)