Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 27 additions & 26 deletions src/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,23 @@
from typing import List
from config import ROOT_DIR


def _load_cache_payload(cache_path: str, default_payload: dict) -> dict:
if not os.path.exists(cache_path):
with open(cache_path, 'w', encoding='utf-8') as file:
json.dump(default_payload, file, indent=4)
return default_payload

try:
with open(cache_path, 'r', encoding='utf-8') as file:
parsed = json.load(file)
except json.JSONDecodeError:
with open(cache_path, 'w', encoding='utf-8') as file:
json.dump(default_payload, file, indent=4)
return default_payload

return parsed if isinstance(parsed, dict) else default_payload

def get_cache_path() -> str:
"""
Gets the path to the cache file.
Expand Down Expand Up @@ -72,24 +89,16 @@ def get_accounts(provider: str) -> List[dict]:
"""
cache_path = get_provider_cache_path(provider)

if not os.path.exists(cache_path):
# Create the cache file
with open(cache_path, 'w') as file:
json.dump({
"accounts": []
}, file, indent=4)
parsed = _load_cache_payload(cache_path, {'accounts': []})

with open(cache_path, 'r') as file:
parsed = json.load(file)
if parsed is None:
return []

if parsed is None:
return []

if 'accounts' not in parsed:
return []
if 'accounts' not in parsed:
return []

# Get accounts dictionary
return parsed['accounts']
# Get accounts dictionary
return parsed['accounts']

def add_account(provider: str, account: dict) -> None:
"""
Expand Down Expand Up @@ -148,18 +157,10 @@ def get_products() -> List[dict]:
Returns:
products (List[dict]): The products
"""
if not os.path.exists(get_afm_cache_path()):
# Create the cache file
with open(get_afm_cache_path(), 'w') as file:
json.dump({
"products": []
}, file, indent=4)

with open(get_afm_cache_path(), 'r') as file:
parsed = json.load(file)
parsed = _load_cache_payload(get_afm_cache_path(), {'products': []})

# Get the products
return parsed["products"]
# Get the products
return parsed['products']

def add_product(product: dict) -> None:
"""
Expand Down
71 changes: 71 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import importlib.util
import json
import sys
import tempfile
import types
import unittest
from pathlib import Path


ROOT_DIR = Path(__file__).resolve().parents[1]
SRC_DIR = ROOT_DIR / "src"


def load_cache_module(root_dir: str):
spec = importlib.util.spec_from_file_location(
"cache_under_test", SRC_DIR / "cache.py"
)
if spec is None or spec.loader is None:
raise RuntimeError("Unable to load cache module")

fake_config = types.ModuleType("config")
fake_config.ROOT_DIR = root_dir

previous_config = sys.modules.get("config")
sys.modules["config"] = fake_config

module = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
return module
finally:
if previous_config is None:
sys.modules.pop("config", None)
else:
sys.modules["config"] = previous_config


class CacheRecoveryTests(unittest.TestCase):
def test_get_accounts_recovers_from_corrupted_json(self):
with tempfile.TemporaryDirectory() as temp_dir:
cache_dir = Path(temp_dir) / ".mp"
cache_dir.mkdir()
cache_path = cache_dir / "twitter.json"
cache_path.write_text('{"accounts": [', encoding="utf-8")

cache = load_cache_module(temp_dir)

self.assertEqual(cache.get_accounts("twitter"), [])
self.assertEqual(
json.loads(cache_path.read_text(encoding="utf-8")),
{"accounts": []},
)

def test_get_products_recovers_from_corrupted_json(self):
with tempfile.TemporaryDirectory() as temp_dir:
cache_dir = Path(temp_dir) / ".mp"
cache_dir.mkdir()
cache_path = cache_dir / "afm.json"
cache_path.write_text('{"products": [', encoding="utf-8")

cache = load_cache_module(temp_dir)

self.assertEqual(cache.get_products(), [])
self.assertEqual(
json.loads(cache_path.read_text(encoding="utf-8")),
{"products": []},
)


if __name__ == "__main__":
unittest.main()