Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 83 additions & 14 deletions src/vfbquery/solr_result_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import re
import requests
import hashlib
import gzip
import base64
import time
import threading
from datetime import datetime, timedelta
Expand All @@ -36,6 +38,45 @@

logger = logging.getLogger(__name__)

# --- Compressed cache payloads ---------------------------------------------
# Large query results (e.g. AllAlignedImages for whole-brain templates) serialise
# to hundreds of MB of JSON. Stored raw, they blew past the cache size cap, were
# never cached (so recomputed every call), and bloated the Solr index. We gzip the
# serialised envelope and base64-encode it so it still fits the text cache_data
# field. The marker prefix lets readers transparently handle both legacy
# plain-JSON entries and compressed ones, and shrinks the on-the-wire payload
# ~10-15x on every read.
#
# Marker written in front of every compressed cache_data value by
# _encode_cache_field(). _decode_cache_field() treats any value that does not
# start with it as legacy plain JSON and returns that value unchanged.
_CACHE_GZIP_PREFIX = "gz:"


def _encode_cache_field(envelope_json: str) -> str:
    """Gzip and base64-encode a serialised cache envelope for the cache_data field.

    Args:
        envelope_json: The cache envelope, already serialised to a JSON string.

    Returns:
        The ``_CACHE_GZIP_PREFIX`` marker followed by the ASCII base64 text of
        the gzip-compressed UTF-8 bytes of ``envelope_json``.
    """
    raw_bytes = envelope_json.encode("utf-8")
    # Level 6 is passed explicitly; gzip.compress would otherwise use level 9.
    compressed = gzip.compress(raw_bytes, compresslevel=6)
    encoded = base64.b64encode(compressed).decode("ascii")
    return f"{_CACHE_GZIP_PREFIX}{encoded}"


def _decode_cache_field(cached_field) -> str:
    """Recover the serialised cache envelope from a stored cache_data value.

    Accepts the value either as a string or as a list (the first element is
    used; an empty list becomes ``""``). Values starting with the
    ``_CACHE_GZIP_PREFIX`` marker are base64-decoded and gunzipped; anything
    else is handed back untouched as legacy plain JSON.

    Args:
        cached_field: The raw cache_data value read from a cache document.

    Returns:
        The decoded envelope text. If a marked payload cannot be decoded, the
        original marked string is returned instead of raising.
    """
    if isinstance(cached_field, list):
        cached_field = cached_field[0] if cached_field else ""

    has_marker = isinstance(cached_field, str) and cached_field.startswith(_CACHE_GZIP_PREFIX)
    if not has_marker:
        # Legacy plain-JSON entry (or a non-string value): nothing to unpack.
        return cached_field

    encoded = cached_field[len(_CACHE_GZIP_PREFIX):]
    try:
        compressed = base64.b64decode(encoded)
        return gzip.decompress(compressed).decode("utf-8")
    except Exception:
        # Corrupt/truncated gz payload: hand back the raw string rather than
        # raising. It is not valid JSON, so the caller's json.loads fails and
        # the entry is handled as an invalid one.
        logger.warning("Failed to decode compressed cache payload; treating as invalid", exc_info=True)
        return cached_field


@dataclass
class CacheMetadata:
"""Metadata for cached results"""
Expand Down Expand Up @@ -67,7 +108,7 @@ class SolrResultCache:
def __init__(self,
cache_url: str = None,
ttl_hours: int = 2160, # 3 months like VFB_connect
max_result_size_mb: int = 10):
max_result_size_mb: int = None):
"""
Initialize SOLR result cache

Expand All @@ -76,13 +117,23 @@ def __init__(self,
VFBQUERY_SOLR_URL env var if set, otherwise the dedicated
query-cache Solr (DEFAULT_CACHE_URL).
ttl_hours: Time-to-live for cache entries in hours
max_result_size_mb: Maximum result size to cache in MB
max_result_size_mb: Maximum *compressed* result size to cache in MB
(default 100; override with the VFBQUERY_MAX_RESULT_MB env var)
"""
if cache_url is None:
cache_url = os.getenv('VFBQUERY_SOLR_URL', self.DEFAULT_CACHE_URL)
self.cache_url = cache_url
self.ttl_hours = ttl_hours
if max_result_size_mb is None:
raw = os.getenv("VFBQUERY_MAX_RESULT_MB", "100")
try:
max_result_size_mb = int(raw)
except ValueError:
logger.warning("Invalid VFBQUERY_MAX_RESULT_MB=%r; falling back to 100", raw)
max_result_size_mb = 100
self.max_result_size_mb = max_result_size_mb
Comment thread
Copilot marked this conversation as resolved.
# The cap is enforced on the COMPRESSED (gzip+base64) payload that is
# actually stored, so 100 MB here corresponds to ~1-1.5 GB of raw JSON.
self.max_result_size_bytes = max_result_size_mb * 1024 * 1024

# When Solr is unreachable, disable caching for a period (backoff).
Expand All @@ -108,10 +159,8 @@ def _create_cache_metadata(self, result: Any, **params) -> Optional[Dict[str, An
serialized_result = json.dumps(result, cls=NumpyEncoder)
result_size = len(serialized_result.encode('utf-8'))

# Don't cache if result is too large
if result_size > self.max_result_size_bytes:
logger.warning(f"Result too large to cache: {result_size/1024/1024:.2f}MB > {self.max_result_size_mb}MB")
return None
# The size cap is enforced on the compressed payload in cache_result();
# result_size here is the raw size, kept for metadata and monitoring only.

now = datetime.now().astimezone()
expires_at = now + timedelta(hours=self.ttl_hours) # 2160 hours = 90 days = 3 months
Expand Down Expand Up @@ -266,8 +315,15 @@ def get_cached_result(self, query_type: str, term_id: str, **params) -> Optional
if isinstance(cached_field, list):
cached_field = cached_field[0]

# Parse the cached metadata and result
cached_data = json.loads(cached_field)
# Parse the cached metadata and result. A corrupt/undecodable entry
# (e.g. a truncated gz: payload) raises here; purge it so the next
# call repopulates, rather than leaving a permanent cache miss.
try:
cached_data = json.loads(_decode_cache_field(cached_field))
except (ValueError, TypeError):
logger.warning(f"Corrupt cache entry for {query_type}({term_id}); clearing it")
self._clear_expired_cache_document(cache_doc_id)
return None

# Check package version before anything else so stale cache is rejected early.
# Only invalidate when the cached entry is OLDER than the current code
Expand Down Expand Up @@ -415,11 +471,24 @@ def cache_result(self, query_type: str, term_id: str, result: Any, **params) ->
# This ensures different query types for the same term have separate cache entries
cache_doc_id = f"vfb_query_{query_type}_{term_id}"

# Serialise then gzip the envelope. The size cap is enforced here, on
# the compressed payload that is actually stored.
cache_field = _encode_cache_field(json.dumps(cached_data, cls=NumpyEncoder))
stored_size = len(cache_field.encode('utf-8'))
if stored_size > self.max_result_size_bytes:
logger.warning(
f"Compressed result too large to cache: "
f"{stored_size/1024/1024:.2f}MB > {self.max_result_size_mb}MB "
f"(raw {cached_data.get('result_size', 0)/1024/1024:.1f}MB) "
f"for {query_type}({term_id})"
)
return False

cache_doc = {
"id": cache_doc_id,
"original_term_id": term_id,
"query_type": query_type,
"cache_data": json.dumps(cached_data, cls=NumpyEncoder),
"cache_data": cache_field,
"cached_at": cached_data["cached_at"],
"expires_at": cached_data["expires_at"]
}
Expand All @@ -430,11 +499,11 @@ def cache_result(self, query_type: str, term_id: str, result: Any, **params) ->
data=json.dumps([cache_doc]),
headers={"Content-Type": "application/json"},
params={"commit": "true"}, # Immediate commit for availability
timeout=int(os.getenv('VFBQUERY_SOLR_WRITE_TIMEOUT', '30'))
timeout=int(os.getenv('VFBQUERY_SOLR_WRITE_TIMEOUT', '60'))
)

if response.status_code == 200:
logger.info(f"Cached {query_type} for {term_id} as {cache_doc_id}, size: {cached_data['result_size']/1024:.1f}KB")
logger.info(f"Cached {query_type} for {term_id} as {cache_doc_id}, raw {cached_data['result_size']/1024:.1f}KB, stored {stored_size/1024:.1f}KB")
return True
else:
logger.error(f"Failed to cache result: HTTP {response.status_code} - {response.text}")
Expand Down Expand Up @@ -564,7 +633,7 @@ def get_cache_age(self, query_type: str, term_id: str, **params) -> Optional[Dic
if isinstance(cached_field, list):
cached_field = cached_field[0]

cached_data = json.loads(cached_field)
cached_data = json.loads(_decode_cache_field(cached_field))

cached_at = datetime.fromisoformat(cached_data["cached_at"].replace('Z', '+00:00'))
expires_at = datetime.fromisoformat(cached_data["expires_at"].replace('Z', '+00:00'))
Expand Down Expand Up @@ -635,7 +704,7 @@ def cleanup_expired_entries(self) -> int:
cached_field = doc["cache_data"]
if isinstance(cached_field, list):
cached_field = cached_field[0]
cached_data = json.loads(cached_field)
cached_data = json.loads(_decode_cache_field(cached_field))
expires_at = datetime.fromisoformat(cached_data["expires_at"].replace('Z', '+00:00'))

if expires_at and now > expires_at:
Expand Down Expand Up @@ -716,7 +785,7 @@ def get_cache_stats(self) -> Dict[str, Any]:
if isinstance(cached_field, list):
cached_field = cached_field[0]

cached_data = json.loads(cached_field)
cached_data = json.loads(_decode_cache_field(cached_field))
total_size += len(cached_field)

# Get timestamps from document fields or cache_data
Expand Down
55 changes: 55 additions & 0 deletions tests/test_gzip_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Unit tests for gzip-compressed Solr cache payloads (no network)."""
import json
from vfbquery.solr_result_cache import (
_encode_cache_field, _decode_cache_field, _CACHE_GZIP_PREFIX, SolrResultCache,
)


def test_roundtrip_compresses_and_restores():
    """Encoding adds the marker, shrinks the payload, and decodes back losslessly."""
    envelope = json.dumps({"result": {"rows": list(range(5000))}, "cached_at": "x"})
    stored = _encode_cache_field(envelope)

    assert stored.startswith(_CACHE_GZIP_PREFIX)
    assert len(stored) < len(envelope)
    assert _decode_cache_field(stored) == envelope


def test_decode_returns_raw_string_on_corrupt_gz_payload():
    """A marked but undecodable value is returned unchanged instead of raising."""
    # Returning the raw string makes the caller's json.loads fail, so the
    # entry ends up being treated as invalid.
    corrupt_values = [
        _CACHE_GZIP_PREFIX + "!!!not-base64!!!",
        _CACHE_GZIP_PREFIX + "AAAA",
    ]
    for value in corrupt_values:
        assert _decode_cache_field(value) == value


def test_decode_handles_legacy_plain_json_and_list_shape():
    """Plain JSON passes through, and list-wrapped values are unwrapped first."""
    plain = json.dumps({"result": 1})
    compressed = _encode_cache_field(plain)

    # Bare legacy string, list-wrapped legacy string, list-wrapped gz payload.
    for stored in (plain, [plain], [compressed]):
        assert _decode_cache_field(stored) == plain


def test_cap_is_enforced_on_compressed_not_raw_size():
    """A payload whose raw JSON exceeds the cap still fits once compressed."""
    # Small cap + a highly compressible payload keeps this fast and
    # memory-light: the raw JSON lands above the cap while the gzip+base64
    # form stays below it.
    cap_mb = 1
    cap = cap_mb * 1024 * 1024
    cache = SolrResultCache(max_result_size_mb=cap_mb)
    assert cache.max_result_size_bytes == cap

    rows = ["x" * 100] * 50000  # ~5 MB raw once serialised, compresses hard
    payload = json.dumps({"result": {"rows": rows}})
    raw = len(payload.encode("utf-8"))
    compressed = len(_encode_cache_field(payload).encode("utf-8"))

    assert raw > cap, f"raw {raw} should exceed cap {cap}"
    assert compressed < cap, f"compressed {compressed} should be under cap {cap}"


def test_env_override(monkeypatch):
    """VFBQUERY_MAX_RESULT_MB sets the cap when no explicit size is passed."""
    monkeypatch.setenv("VFBQUERY_MAX_RESULT_MB", "250")
    cache = SolrResultCache()
    assert cache.max_result_size_mb == 250


def test_create_metadata_no_longer_rejects_large_raw():
    """Metadata is still produced when the raw result is larger than the cap."""
    cache = SolrResultCache(max_result_size_mb=1)
    oversized_result = {"rows": list(range(200000))}

    metadata = cache._create_cache_metadata(oversized_result)

    assert metadata is not None
    one_mb = 1024 * 1024
    assert metadata["result_size"] > one_mb
Loading