1616import re
1717import requests
1818import hashlib
19+ import gzip
20+ import base64
1921import time
2022import threading
2123from datetime import datetime , timedelta
3638
3739logger = logging .getLogger (__name__ )
3840
41+ # --- Compressed cache payloads ---------------------------------------------
42+ # Large query results (e.g. AllAlignedImages for whole-brain templates) serialise
43+ # to hundreds of MB of JSON. Stored raw, they blew past the cache size cap, were
44+ # never cached (so recomputed every call), and bloated the Solr index. We gzip the
45+ # serialised envelope and base64-encode it so it still fits the text cache_data
46+ # field. The marker prefix lets readers transparently handle both legacy
47+ # plain-JSON entries and compressed ones, and shrinks the on-the-wire payload
48+ # ~10-15x on every read.
49+ _CACHE_GZIP_PREFIX = "gz:"
50+
51+
def _encode_cache_field(envelope_json: str) -> str:
    """Gzip and base64-encode a serialised cache envelope for the cache_data field.

    Args:
        envelope_json: JSON text of the cache envelope to store.

    Returns:
        The ``_CACHE_GZIP_PREFIX`` marker followed by the base64 (ASCII) text of
        the gzip-compressed UTF-8 bytes of ``envelope_json``.
    """
    raw_bytes = envelope_json.encode("utf-8")
    compressed = gzip.compress(raw_bytes, compresslevel=6)
    # base64 keeps the binary gzip stream representable as text.
    encoded = base64.b64encode(compressed).decode("ascii")
    return f"{_CACHE_GZIP_PREFIX}{encoded}"
56+
57+
def _decode_cache_field(cached_field) -> str:
    """Recover the serialised cache envelope from a stored cache_data value.

    Accepts either a bare value or a list (the first element is used; an empty
    list becomes ``""``). Values that are strings starting with
    ``_CACHE_GZIP_PREFIX`` are base64-decoded and gunzipped; anything else is
    treated as legacy plain JSON and returned unchanged.

    Args:
        cached_field: The raw cache_data value, as a string or list.

    Returns:
        The decompressed envelope text, or the (unwrapped) input when it is not
        a prefixed payload or when decoding fails.
    """
    value = cached_field
    if isinstance(value, list):
        value = value[0] if value else ""

    is_compressed = isinstance(value, str) and value.startswith(_CACHE_GZIP_PREFIX)
    if not is_compressed:
        return value

    try:
        packed = base64.b64decode(value[len(_CACHE_GZIP_PREFIX):])
        return gzip.decompress(packed).decode("utf-8")
    except Exception:
        # Deliberately best-effort: a corrupt or truncated payload must not
        # raise out of this helper. Handing back the raw "gz:..." string means
        # a subsequent json.loads on it fails, so the entry is handled as
        # invalid JSON by the caller.
        logger.warning("Failed to decode compressed cache payload; treating as invalid", exc_info=True)
        return value
78+
79+
3980@dataclass
4081class CacheMetadata :
4182 """Metadata for cached results"""
@@ -67,7 +108,7 @@ class SolrResultCache:
67108 def __init__ (self ,
68109 cache_url : str = None ,
69110 ttl_hours : int = 2160 , # 3 months like VFB_connect
70- max_result_size_mb : int = 10 ):
111+ max_result_size_mb : int = None ):
71112 """
72113 Initialize SOLR result cache
73114
@@ -76,13 +117,23 @@ def __init__(self,
76117 VFBQUERY_SOLR_URL env var if set, otherwise the dedicated
77118 query-cache Solr (DEFAULT_CACHE_URL).
78119 ttl_hours: Time-to-live for cache entries in hours
79- max_result_size_mb: Maximum result size to cache in MB
120+ max_result_size_mb: Maximum *compressed* result size to cache in MB
121+ (default 100; override with the VFBQUERY_MAX_RESULT_MB env var)
80122 """
81123 if cache_url is None :
82124 cache_url = os .getenv ('VFBQUERY_SOLR_URL' , self .DEFAULT_CACHE_URL )
83125 self .cache_url = cache_url
84126 self .ttl_hours = ttl_hours
127+ if max_result_size_mb is None :
128+ raw = os .getenv ("VFBQUERY_MAX_RESULT_MB" , "100" )
129+ try :
130+ max_result_size_mb = int (raw )
131+ except ValueError :
132+ logger .warning ("Invalid VFBQUERY_MAX_RESULT_MB=%r; falling back to 100" , raw )
133+ max_result_size_mb = 100
85134 self .max_result_size_mb = max_result_size_mb
135+ # The cap is enforced on the COMPRESSED (gzip+base64) payload that is
136+ # actually stored, so 100 MB here corresponds to ~1-1.5 GB of raw JSON.
86137 self .max_result_size_bytes = max_result_size_mb * 1024 * 1024
87138
88139 # When Solr is unreachable, disable caching for a period (backoff).
@@ -108,10 +159,8 @@ def _create_cache_metadata(self, result: Any, **params) -> Optional[Dict[str, An
108159 serialized_result = json .dumps (result , cls = NumpyEncoder )
109160 result_size = len (serialized_result .encode ('utf-8' ))
110161
111- # Don't cache if result is too large
112- if result_size > self .max_result_size_bytes :
113- logger .warning (f"Result too large to cache: { result_size / 1024 / 1024 :.2f} MB > { self .max_result_size_mb } MB" )
114- return None
162+ # The size cap is enforced on the compressed payload in cache_result();
163+ # result_size here is the raw size, kept for metadata and monitoring only.
115164
116165 now = datetime .now ().astimezone ()
117166 expires_at = now + timedelta (hours = self .ttl_hours ) # 2160 hours = 90 days = 3 months
@@ -266,8 +315,15 @@ def get_cached_result(self, query_type: str, term_id: str, **params) -> Optional
266315 if isinstance (cached_field , list ):
267316 cached_field = cached_field [0 ]
268317
269- # Parse the cached metadata and result
270- cached_data = json .loads (cached_field )
318+ # Parse the cached metadata and result. A corrupt/undecodable entry
319+ # (e.g. a truncated gz: payload) raises here; purge it so the next
320+ # call repopulates, rather than leaving a permanent cache miss.
321+ try :
322+ cached_data = json .loads (_decode_cache_field (cached_field ))
323+ except (ValueError , TypeError ):
324+ logger .warning (f"Corrupt cache entry for { query_type } ({ term_id } ); clearing it" )
325+ self ._clear_expired_cache_document (cache_doc_id )
326+ return None
271327
272328 # Check package version before anything else so stale cache is rejected early.
273329 # Only invalidate when the cached entry is OLDER than the current code
@@ -415,11 +471,24 @@ def cache_result(self, query_type: str, term_id: str, result: Any, **params) ->
415471 # This ensures different query types for the same term have separate cache entries
416472 cache_doc_id = f"vfb_query_{ query_type } _{ term_id } "
417473
474+ # Serialise then gzip the envelope. The size cap is enforced here, on
475+ # the compressed payload that is actually stored.
476+ cache_field = _encode_cache_field (json .dumps (cached_data , cls = NumpyEncoder ))
477+ stored_size = len (cache_field .encode ('utf-8' ))
478+ if stored_size > self .max_result_size_bytes :
479+ logger .warning (
480+ f"Compressed result too large to cache: "
481+ f"{ stored_size / 1024 / 1024 :.2f} MB > { self .max_result_size_mb } MB "
482+ f"(raw { cached_data .get ('result_size' , 0 )/ 1024 / 1024 :.1f} MB) "
483+ f"for { query_type } ({ term_id } )"
484+ )
485+ return False
486+
418487 cache_doc = {
419488 "id" : cache_doc_id ,
420489 "original_term_id" : term_id ,
421490 "query_type" : query_type ,
422- "cache_data" : json . dumps ( cached_data , cls = NumpyEncoder ) ,
491+ "cache_data" : cache_field ,
423492 "cached_at" : cached_data ["cached_at" ],
424493 "expires_at" : cached_data ["expires_at" ]
425494 }
@@ -430,11 +499,11 @@ def cache_result(self, query_type: str, term_id: str, result: Any, **params) ->
430499 data = json .dumps ([cache_doc ]),
431500 headers = {"Content-Type" : "application/json" },
432501 params = {"commit" : "true" }, # Immediate commit for availability
433- timeout = int (os .getenv ('VFBQUERY_SOLR_WRITE_TIMEOUT' , '30 ' ))
502+ timeout = int (os .getenv ('VFBQUERY_SOLR_WRITE_TIMEOUT' , '60 ' ))
434503 )
435504
436505 if response .status_code == 200 :
437- logger .info (f"Cached { query_type } for { term_id } as { cache_doc_id } , size: { cached_data ['result_size' ]/ 1024 :.1f} KB" )
506+ logger .info (f"Cached { query_type } for { term_id } as { cache_doc_id } , raw { cached_data ['result_size' ]/ 1024 :.1f } KB, stored { stored_size / 1024 :.1f} KB" )
438507 return True
439508 else :
440509 logger .error (f"Failed to cache result: HTTP { response .status_code } - { response .text } " )
@@ -564,7 +633,7 @@ def get_cache_age(self, query_type: str, term_id: str, **params) -> Optional[Dic
564633 if isinstance (cached_field , list ):
565634 cached_field = cached_field [0 ]
566635
567- cached_data = json .loads (cached_field )
636+ cached_data = json .loads (_decode_cache_field ( cached_field ) )
568637
569638 cached_at = datetime .fromisoformat (cached_data ["cached_at" ].replace ('Z' , '+00:00' ))
570639 expires_at = datetime .fromisoformat (cached_data ["expires_at" ].replace ('Z' , '+00:00' ))
@@ -635,7 +704,7 @@ def cleanup_expired_entries(self) -> int:
635704 cached_field = doc ["cache_data" ]
636705 if isinstance (cached_field , list ):
637706 cached_field = cached_field [0 ]
638- cached_data = json .loads (cached_field )
707+ cached_data = json .loads (_decode_cache_field ( cached_field ) )
639708 expires_at = datetime .fromisoformat (cached_data ["expires_at" ].replace ('Z' , '+00:00' ))
640709
641710 if expires_at and now > expires_at :
@@ -716,7 +785,7 @@ def get_cache_stats(self) -> Dict[str, Any]:
716785 if isinstance (cached_field , list ):
717786 cached_field = cached_field [0 ]
718787
719- cached_data = json .loads (cached_field )
788+ cached_data = json .loads (_decode_cache_field ( cached_field ) )
720789 total_size += len (cached_field )
721790
722791 # Get timestamps from document fields or cache_data
0 commit comments