diff --git a/src/vfbquery/solr_result_cache.py b/src/vfbquery/solr_result_cache.py index ff66b21..47b901d 100644 --- a/src/vfbquery/solr_result_cache.py +++ b/src/vfbquery/solr_result_cache.py @@ -16,6 +16,8 @@ import re import requests import hashlib +import gzip +import base64 import time import threading from datetime import datetime, timedelta @@ -36,6 +38,45 @@ logger = logging.getLogger(__name__) +# --- Compressed cache payloads --------------------------------------------- +# Large query results (e.g. AllAlignedImages for whole-brain templates) serialise +# to hundreds of MB of JSON. Stored raw, they blew past the cache size cap, were +# never cached (so recomputed every call), and bloated the Solr index. We gzip the +# serialised envelope and base64-encode it so it still fits the text cache_data +# field. The marker prefix lets readers transparently handle both legacy +# plain-JSON entries and compressed ones, and shrinks the on-the-wire payload +# ~10-15x on every read. +_CACHE_GZIP_PREFIX = "gz:" + + +def _encode_cache_field(envelope_json: str) -> str: + """Compress a serialised cache envelope for storage in the cache_data field.""" + packed = gzip.compress(envelope_json.encode("utf-8"), 6) + return _CACHE_GZIP_PREFIX + base64.b64encode(packed).decode("ascii") + + +def _decode_cache_field(cached_field) -> str: + """Return the serialised cache envelope from a stored cache_data value. + + Handles the Solr list-or-string shape and both payload formats: legacy plain + JSON (returned unchanged) and ``gz:``-prefixed base64 gzip blobs. + """ + if isinstance(cached_field, list): + cached_field = cached_field[0] if cached_field else "" + if isinstance(cached_field, str) and cached_field.startswith(_CACHE_GZIP_PREFIX): + try: + blob = base64.b64decode(cached_field[len(_CACHE_GZIP_PREFIX):]) + return gzip.decompress(blob).decode("utf-8") + except Exception: + # Corrupt/truncated gz payload: return the raw string rather than + # raising an uncaught error that would abort cleanup/stats runs. The + # caller's json.loads then fails, so the entry is treated as invalid + # JSON (get_cached_result purges it; other callers skip it). + logger.warning("Failed to decode compressed cache payload; treating as invalid", exc_info=True) + return cached_field + return cached_field + + @dataclass class CacheMetadata: """Metadata for cached results""" @@ -67,7 +108,7 @@ class SolrResultCache: def __init__(self, cache_url: str = None, ttl_hours: int = 2160, # 3 months like VFB_connect - max_result_size_mb: int = 10): + max_result_size_mb: int = None): """ Initialize SOLR result cache @@ -76,13 +117,23 @@ def __init__(self, VFBQUERY_SOLR_URL env var if set, otherwise the dedicated query-cache Solr (DEFAULT_CACHE_URL). ttl_hours: Time-to-live for cache entries in hours - max_result_size_mb: Maximum result size to cache in MB + max_result_size_mb: Maximum *compressed* result size to cache in MB + (default 100; override with the VFBQUERY_MAX_RESULT_MB env var) """ if cache_url is None: cache_url = os.getenv('VFBQUERY_SOLR_URL', self.DEFAULT_CACHE_URL) self.cache_url = cache_url self.ttl_hours = ttl_hours + if max_result_size_mb is None: + raw = os.getenv("VFBQUERY_MAX_RESULT_MB", "100") + try: + max_result_size_mb = int(raw) + except ValueError: + logger.warning("Invalid VFBQUERY_MAX_RESULT_MB=%r; falling back to 100", raw) + max_result_size_mb = 100 self.max_result_size_mb = max_result_size_mb + # The cap is enforced on the COMPRESSED (gzip+base64) payload that is + # actually stored, so 100 MB here corresponds to ~1-1.5 GB of raw JSON. self.max_result_size_bytes = max_result_size_mb * 1024 * 1024 # When Solr is unreachable, disable caching for a period (backoff). @@ -108,10 +159,8 @@ def _create_cache_metadata(self, result: Any, **params) -> Optional[Dict[str, An serialized_result = json.dumps(result, cls=NumpyEncoder) result_size = len(serialized_result.encode('utf-8')) - # Don't cache if result is too large - if result_size > self.max_result_size_bytes: - logger.warning(f"Result too large to cache: {result_size/1024/1024:.2f}MB > {self.max_result_size_mb}MB") - return None + # The size cap is enforced on the compressed payload in cache_result(); + # result_size here is the raw size, kept for metadata and monitoring only. now = datetime.now().astimezone() expires_at = now + timedelta(hours=self.ttl_hours) # 2160 hours = 90 days = 3 months @@ -266,8 +315,15 @@ def get_cached_result(self, query_type: str, term_id: str, **params) -> Optional if isinstance(cached_field, list): cached_field = cached_field[0] - # Parse the cached metadata and result - cached_data = json.loads(cached_field) + # Parse the cached metadata and result. A corrupt/undecodable entry + # (e.g. a truncated gz: payload) raises here; purge it so the next + # call repopulates, rather than leaving a permanent cache miss. + try: + cached_data = json.loads(_decode_cache_field(cached_field)) + except (ValueError, TypeError): + logger.warning(f"Corrupt cache entry for {query_type}({term_id}); clearing it") + self._clear_expired_cache_document(cache_doc_id) + return None # Check package version before anything else so stale cache is rejected early. # Only invalidate when the cached entry is OLDER than the current code @@ -415,11 +471,24 @@ def cache_result(self, query_type: str, term_id: str, result: Any, **params) -> # This ensures different query types for the same term have separate cache entries cache_doc_id = f"vfb_query_{query_type}_{term_id}" + # Serialise then gzip the envelope. The size cap is enforced here, on + # the compressed payload that is actually stored. + cache_field = _encode_cache_field(json.dumps(cached_data, cls=NumpyEncoder)) + stored_size = len(cache_field.encode('utf-8')) + if stored_size > self.max_result_size_bytes: + logger.warning( + f"Compressed result too large to cache: " + f"{stored_size/1024/1024:.2f}MB > {self.max_result_size_mb}MB " + f"(raw {cached_data.get('result_size', 0)/1024/1024:.1f}MB) " + f"for {query_type}({term_id})" + ) + return False + cache_doc = { "id": cache_doc_id, "original_term_id": term_id, "query_type": query_type, - "cache_data": json.dumps(cached_data, cls=NumpyEncoder), + "cache_data": cache_field, "cached_at": cached_data["cached_at"], "expires_at": cached_data["expires_at"] } @@ -430,11 +499,11 @@ def cache_result(self, query_type: str, term_id: str, result: Any, **params) -> data=json.dumps([cache_doc]), headers={"Content-Type": "application/json"}, params={"commit": "true"}, # Immediate commit for availability - timeout=int(os.getenv('VFBQUERY_SOLR_WRITE_TIMEOUT', '30')) + timeout=int(os.getenv('VFBQUERY_SOLR_WRITE_TIMEOUT', '60')) ) if response.status_code == 200: - logger.info(f"Cached {query_type} for {term_id} as {cache_doc_id}, size: {cached_data['result_size']/1024:.1f}KB") + logger.info(f"Cached {query_type} for {term_id} as {cache_doc_id}, raw {cached_data['result_size']/1024:.1f}KB, stored {stored_size/1024:.1f}KB") return True else: logger.error(f"Failed to cache result: HTTP {response.status_code} - {response.text}") @@ -564,7 +633,7 @@ def get_cache_age(self, query_type: str, term_id: str, **params) -> Optional[Dic if isinstance(cached_field, list): cached_field = cached_field[0] - cached_data = json.loads(cached_field) + cached_data = json.loads(_decode_cache_field(cached_field)) cached_at = datetime.fromisoformat(cached_data["cached_at"].replace('Z', '+00:00')) expires_at = datetime.fromisoformat(cached_data["expires_at"].replace('Z', '+00:00')) @@ -635,7 +704,7 @@ def cleanup_expired_entries(self) -> int: cached_field = doc["cache_data"] if isinstance(cached_field, list): cached_field = cached_field[0] - cached_data = json.loads(cached_field) + cached_data = json.loads(_decode_cache_field(cached_field)) expires_at = datetime.fromisoformat(cached_data["expires_at"].replace('Z', '+00:00')) if expires_at and now > expires_at: @@ -716,7 +785,7 @@ def get_cache_stats(self) -> Dict[str, Any]: if isinstance(cached_field, list): cached_field = cached_field[0] - cached_data = json.loads(cached_field) + cached_data = json.loads(_decode_cache_field(cached_field)) total_size += len(cached_field) # Get timestamps from document fields or cache_data diff --git a/tests/test_gzip_cache.py b/tests/test_gzip_cache.py new file mode 100644 index 0000000..288e396 --- /dev/null +++ b/tests/test_gzip_cache.py @@ -0,0 +1,55 @@ +"""Unit tests for gzip-compressed Solr cache payloads (no network).""" +import json +from vfbquery.solr_result_cache import ( + _encode_cache_field, _decode_cache_field, _CACHE_GZIP_PREFIX, SolrResultCache, +) + + +def test_roundtrip_compresses_and_restores(): + env = json.dumps({"result": {"rows": list(range(5000))}, "cached_at": "x"}) + enc = _encode_cache_field(env) + assert enc.startswith(_CACHE_GZIP_PREFIX) + assert len(enc) < len(env) + assert _decode_cache_field(enc) == env + + +def test_decode_returns_raw_string_on_corrupt_gz_payload(): + # A gz:-prefixed but undecodable value must not raise; it returns the raw + # string so the caller's json.loads fails and the entry is treated as invalid. + for bad in (_CACHE_GZIP_PREFIX + "!!!not-base64!!!", _CACHE_GZIP_PREFIX + "AAAA"): + assert _decode_cache_field(bad) == bad + + +def test_decode_handles_legacy_plain_json_and_list_shape(): + legacy = json.dumps({"result": 1}) + assert _decode_cache_field(legacy) == legacy + assert _decode_cache_field([legacy]) == legacy + enc = _encode_cache_field(legacy) + assert _decode_cache_field([enc]) == legacy + + +def test_cap_is_enforced_on_compressed_not_raw_size(): + # Small cap + a highly compressible payload: the RAW JSON must exceed the cap + # while the gzip+base64 form stays under it, proving the cap is on the stored + # (compressed) size, not the raw size. Kept fast/memory-light via repetition. + cap_mb = 1 + c = SolrResultCache(max_result_size_mb=cap_mb) + cap = cap_mb * 1024 * 1024 + assert c.max_result_size_bytes == cap + payload = json.dumps({"result": {"rows": ["x" * 100] * 50000}}) # ~5 MB raw, compresses hard + raw = len(payload.encode("utf-8")) + compressed = len(_encode_cache_field(payload).encode("utf-8")) + assert raw > cap, f"raw {raw} should exceed cap {cap}" + assert compressed < cap, f"compressed {compressed} should be under cap {cap}" + + +def test_env_override(monkeypatch): + monkeypatch.setenv("VFBQUERY_MAX_RESULT_MB", "250") + assert SolrResultCache().max_result_size_mb == 250 + + +def test_create_metadata_no_longer_rejects_large_raw(): + c = SolrResultCache(max_result_size_mb=1) + meta = c._create_cache_metadata({"rows": list(range(200000))}) + assert meta is not None + assert meta["result_size"] > 1024 * 1024