From d103d3af496bffb09032c020b4ac517f08bebcf1 Mon Sep 17 00:00:00 2001 From: Gujiassh Date: Mon, 6 Apr 2026 03:39:51 +0900 Subject: [PATCH] fix: recover from corrupted cache JSON files Reset account/product cache files to empty defaults when JSON decoding fails so corrupted cache files no longer crash the CLI on startup. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus --- src/cache.py | 53 ++++++++++++++++----------------- tests/test_cache.py | 71 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 26 deletions(-) create mode 100644 tests/test_cache.py diff --git a/src/cache.py b/src/cache.py index 74973eb26..640afc8b5 100644 --- a/src/cache.py +++ b/src/cache.py @@ -4,6 +4,23 @@ from typing import List from config import ROOT_DIR + +def _load_cache_payload(cache_path: str, default_payload: dict) -> dict: + if not os.path.exists(cache_path): + with open(cache_path, 'w', encoding='utf-8') as file: + json.dump(default_payload, file, indent=4) + return default_payload + + try: + with open(cache_path, 'r', encoding='utf-8') as file: + parsed = json.load(file) + except json.JSONDecodeError: + with open(cache_path, 'w', encoding='utf-8') as file: + json.dump(default_payload, file, indent=4) + return default_payload + + return parsed if isinstance(parsed, dict) else default_payload + def get_cache_path() -> str: """ Gets the path to the cache file. @@ -72,24 +89,16 @@ def get_accounts(provider: str) -> List[dict]: """ cache_path = get_provider_cache_path(provider) - if not os.path.exists(cache_path): - # Create the cache file - with open(cache_path, 'w') as file: - json.dump({ - "accounts": [] - }, file, indent=4) + parsed = _load_cache_payload(cache_path, {'accounts': []}) - with open(cache_path, 'r') as file: - parsed = json.load(file) + if parsed is None: + return [] - if parsed is None: - return [] - - if 'accounts' not in parsed: - return [] + if 'accounts' not in parsed: + return [] - # Get accounts dictionary - return parsed['accounts'] + # Get accounts dictionary + return parsed['accounts'] def add_account(provider: str, account: dict) -> None: """ @@ -148,18 +157,10 @@ def get_products() -> List[dict]: Returns: products (List[dict]): The products """ - if not os.path.exists(get_afm_cache_path()): - # Create the cache file - with open(get_afm_cache_path(), 'w') as file: - json.dump({ - "products": [] - }, file, indent=4) - - with open(get_afm_cache_path(), 'r') as file: - parsed = json.load(file) + parsed = _load_cache_payload(get_afm_cache_path(), {'products': []}) - # Get the products - return parsed["products"] + # Get the products + return parsed['products'] def add_product(product: dict) -> None: """ diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 000000000..ae6d83c1c --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,71 @@ +import importlib.util +import json +import sys +import tempfile +import types +import unittest +from pathlib import Path + + +ROOT_DIR = Path(__file__).resolve().parents[1] +SRC_DIR = ROOT_DIR / "src" + + +def load_cache_module(root_dir: str): + spec = importlib.util.spec_from_file_location( + "cache_under_test", SRC_DIR / "cache.py" + ) + if spec is None or spec.loader is None: + raise RuntimeError("Unable to load cache module") + + fake_config = types.ModuleType("config") + fake_config.ROOT_DIR = root_dir + + previous_config = sys.modules.get("config") + sys.modules["config"] = fake_config + + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + return module + finally: + if previous_config is None: + sys.modules.pop("config", None) + else: + sys.modules["config"] = previous_config + + +class CacheRecoveryTests(unittest.TestCase): + def test_get_accounts_recovers_from_corrupted_json(self): + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) / ".mp" + cache_dir.mkdir() + cache_path = cache_dir / "twitter.json" + cache_path.write_text('{"accounts": [', encoding="utf-8") + + cache = load_cache_module(temp_dir) + + self.assertEqual(cache.get_accounts("twitter"), []) + self.assertEqual( + json.loads(cache_path.read_text(encoding="utf-8")), + {"accounts": []}, + ) + + def test_get_products_recovers_from_corrupted_json(self): + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) / ".mp" + cache_dir.mkdir() + cache_path = cache_dir / "afm.json" + cache_path.write_text('{"products": [', encoding="utf-8") + + cache = load_cache_module(temp_dir) + + self.assertEqual(cache.get_products(), []) + self.assertEqual( + json.loads(cache_path.read_text(encoding="utf-8")), + {"products": []}, + ) + + +if __name__ == "__main__": + unittest.main()