|
| 1 | +""" |
| 2 | +Shared memory measurement helpers for scan benchmarks. |
| 3 | +
|
| 4 | +Each helper is independent — pick the ones you need. Designed to be used |
| 5 | +from subprocess scan scripts (so tracemalloc + psutil polling are not |
| 6 | +contaminated by pytest's own allocations). |
| 7 | +
|
| 8 | +Three measurement angles, each answering a different question: |
| 9 | +
|
| 10 | + - ``RSSSampler``: actual OS-level RAM (catches Rust / lxml / yara that |
| 11 | + tracemalloc misses). Headline metrics: peak, end, retention (median |
| 12 | + over the last 25% of samples — the metric most sensitive to |
| 13 | + "stuck for the rest of the scan" pathology). |
| 14 | +
|
| 15 | + - ``event_census``: who is alive right now, grouped by event type, with |
| 16 | + HTTP_RESPONSE body bytes broken out separately. Answers |
| 17 | + "where is the memory going?". |
| 18 | +
|
| 19 | + - ``lineage_census``: walk every live event's parent chain back to the |
| 20 | + seed and bucket pinned events by seed. Answers "is the chain |
| 21 | + holding things alive?". |
| 22 | +""" |
| 23 | + |
| 24 | +import gc |
| 25 | +import json |
| 26 | +import threading |
| 27 | +import time |
| 28 | +import weakref |
| 29 | + |
| 30 | +import psutil |
| 31 | + |
| 32 | + |
class RSSSampler:
    """
    Background thread that polls process RSS at a fixed interval.

    Usage::

        sampler = RSSSampler(interval_s=0.2)
        sampler.start()
        # ... do work ...
        sampler.stop()
        m = sampler.metrics()  # peak_rss_mb, end_rss_mb, retention_rss_mb
    """

    def __init__(self, interval_s=0.2):
        # interval_s: seconds between RSS polls.
        self.interval_s = interval_s
        self.samples = []  # list of (elapsed_s, rss_mb)
        self._stop = threading.Event()
        self._thread = None
        self._proc = psutil.Process()

    def start(self):
        """Start the background polling thread (daemon, so it never blocks exit)."""
        # Clear the flag so a sampler can be restarted after stop();
        # previously a second start() produced a thread that exited immediately.
        self._stop.clear()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def _loop(self):
        t0 = time.monotonic()
        while not self._stop.is_set():
            rss_mb = self._proc.memory_info().rss / 1024 / 1024
            self.samples.append((time.monotonic() - t0, rss_mb))
            # wait() instead of sleep() so stop() interrupts the interval.
            self._stop.wait(self.interval_s)

    def stop(self):
        """Signal the polling thread to exit and join it (bounded wait)."""
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=1.0)

    def metrics(self):
        """
        Summarize collected samples.

        Returns a dict with ``peak_rss_mb``, ``end_rss_mb``,
        ``retention_rss_mb`` (median over the last 25% of samples),
        ``samples`` (count) and ``duration_s``. All zeros when no samples.
        """
        if not self.samples:
            return {
                "peak_rss_mb": 0.0,
                "end_rss_mb": 0.0,
                "retention_rss_mb": 0.0,
                "samples": 0,
                "duration_s": 0.0,
            }
        rss_values = [r for _, r in self.samples]
        peak_rss_mb = max(rss_values)
        end_rss_mb = rss_values[-1]
        # median over the last 25% of samples — robust to a single
        # transient spike at the tail and sensitive to baseline drift.
        # Clamp the slice start to len-1 so the quartile is never empty:
        # the old ``max(1, ...)`` raised IndexError when only one sample
        # had been collected (slice [1:] of a 1-element list is empty).
        last_quartile_start = min(int(len(rss_values) * 0.75), len(rss_values) - 1)
        last_quartile = sorted(rss_values[last_quartile_start:])
        retention_rss_mb = last_quartile[len(last_quartile) // 2]
        duration_s = self.samples[-1][0]
        return {
            "peak_rss_mb": round(peak_rss_mb, 2),
            "end_rss_mb": round(end_rss_mb, 2),
            "retention_rss_mb": round(retention_rss_mb, 2),
            "samples": len(self.samples),
            "duration_s": round(duration_s, 2),
        }
| 94 | + |
| 95 | + |
class LiveEventTracker:
    """
    Mid-scan event census via weakref, without scanning the full Python
    object graph.

    ``event_census()`` calls ``gc.get_objects()`` and walks every Python
    object — fine at scan-end (one-shot), too expensive to call every
    few hundred events. ``LiveEventTracker`` patches
    ``BaseEvent.__init__`` to register newcomers into a ``WeakSet``, so
    ``census()`` is O(live events) instead of O(every Python object).

    Usage::

        tracker = LiveEventTracker()
        tracker.install()
        # ... run scan, calling tracker.census() periodically ...
        tracker.uninstall()  # optional; not needed if the process exits
    """

    def __init__(self):
        self._events = weakref.WeakSet()  # live events; entries vanish on GC
        self._original_init = None  # saved unpatched BaseEvent.__init__
        self._patched = False

    def install(self):
        """Patch ``BaseEvent.__init__`` so new events self-register. Idempotent."""
        if self._patched:
            return
        from bbot.core.event.base import BaseEvent

        # Seed with any events that already exist (e.g. SCAN root_event
        # created before install).
        for candidate in gc.get_objects():
            if isinstance(candidate, BaseEvent):
                self._events.add(candidate)

        unpatched = BaseEvent.__init__
        registry = self._events

        def tracking_init(event, *args, **kwargs):
            unpatched(event, *args, **kwargs)
            try:
                registry.add(event)
            except TypeError:
                # Not weakref-able for some reason; skip silently.
                pass

        self._original_init = unpatched
        BaseEvent.__init__ = tracking_init
        self._patched = True

    def uninstall(self):
        """Restore the original ``BaseEvent.__init__``. No-op if not installed."""
        if not self._patched:
            return
        from bbot.core.event.base import BaseEvent

        BaseEvent.__init__ = self._original_init
        self._patched = False

    def census(self):
        """O(live events) census. Same shape as ``event_census()``."""
        counts = {}
        total_body_bytes = 0
        bodies_seen = 0
        for event in self._events:
            counts[event.type] = counts.get(event.type, 0) + 1
            if event.type != "HTTP_RESPONSE":
                continue
            payload = getattr(event, "data", None)
            body = payload.get("body") if isinstance(payload, dict) else None
            if body:
                bodies_seen += 1
                total_body_bytes += len(body)
        return {
            "live_events": sum(counts.values()),
            "by_type": counts,
            "http_response_body_mb": round(total_body_bytes / 1024 / 1024, 2),
            "http_response_with_body": bodies_seen,
        }
| 173 | + |
| 174 | + |
def event_census():
    """
    Walk live BaseEvent instances and classify by type + HTTP_RESPONSE body bytes.

    Returns a dict::

        {
            "live_events": int,
            "by_type": {"DNS_NAME": int, "HTTP_RESPONSE": int, ...},
            "http_response_body_mb": float,
            "http_response_with_body": int,
        }
    """
    from bbot.core.event.base import BaseEvent

    # Collect first so already-unreachable events don't inflate the census.
    gc.collect()
    counts = {}
    total_body_bytes = 0
    bodies_with_content = 0

    for candidate in gc.get_objects():
        if not isinstance(candidate, BaseEvent):
            continue
        counts[candidate.type] = counts.get(candidate.type, 0) + 1
        if candidate.type != "HTTP_RESPONSE":
            continue
        payload = getattr(candidate, "data", None)
        body = payload.get("body") if isinstance(payload, dict) else None
        if body:
            bodies_with_content += 1
            total_body_bytes += len(body)

    return {
        "live_events": sum(counts.values()),
        "by_type": counts,
        "http_response_body_mb": round(total_body_bytes / 1024 / 1024, 2),
        "http_response_with_body": bodies_with_content,
    }
| 212 | + |
| 213 | + |
def lineage_census(top_n=20):
    """
    Walk every live BaseEvent's parent chain back to the seed and bucket
    pinned events by seed.

    A long-lived seed pinning hundreds of events is the signature of the
    chain-retention pathology — this is the metric that proves a fix.

    Returns::

        {
            "seeds": [{"seed": str, "pinned_events": int,
                       "max_chain_depth": int,
                       "types_pinned": {type: count, ...}}, ...],  # top N
            "total_pinned_events": int,
            "max_chain_depth": int,
            "live_events_walked": int,
        }
    """
    from bbot.core.event.base import BaseEvent

    gc.collect()
    seeds = {}
    max_chain_depth = 0
    walked = 0
    # Hard cap on chain walking: a malformed parent cycle longer than one
    # node (A -> B -> A) would otherwise loop forever. Real chains are
    # orders of magnitude shorter than this.
    max_walk = 10_000

    for obj in gc.get_objects():
        if not isinstance(obj, BaseEvent):
            continue
        walked += 1
        depth = 0
        node = obj
        # Walk up via .parent. Root events are self-parented (node.parent is node).
        while node is not None and node is not getattr(node, "parent", None):
            depth += 1
            if depth > max_walk:
                # Cycle (or absurdly deep chain); treat current node as seed.
                break
            parent = getattr(node, "parent", None)
            if parent is None or parent is node:
                break
            node = parent
        if node is None:
            # Defensive only: node starts non-None and is never assigned None.
            continue
        max_chain_depth = max(max_chain_depth, depth)
        seed_data = getattr(node, "data", None)
        seed_data_str = str(seed_data)[:80] if seed_data is not None else "<no-data>"
        seed_key = f"{node.type}:{seed_data_str}"
        s = seeds.setdefault(
            seed_key,
            {"pinned_events": 0, "max_chain_depth": 0, "types_pinned": {}},
        )
        s["pinned_events"] += 1
        s["max_chain_depth"] = max(s["max_chain_depth"], depth)
        s["types_pinned"][obj.type] = s["types_pinned"].get(obj.type, 0) + 1

    sorted_seeds = sorted(
        ({"seed": k, **v} for k, v in seeds.items()),
        key=lambda x: -x["pinned_events"],
    )
    return {
        "seeds": sorted_seeds[:top_n],
        "total_pinned_events": sum(v["pinned_events"] for v in seeds.values()),
        "max_chain_depth": max_chain_depth,
        "live_events_walked": walked,
    }
| 277 | + |
| 278 | + |
def queue_residence(scan, tracker):
    """
    Classify live events by where they currently live in the scan
    pipeline.

    Two angles, since they answer different questions:

      - Per-module queue depth (``by_queue``): which module is the
        current bottleneck — events are piling up where queue depth is
        large.
      - In-pipeline vs. chain-only (``in_pipeline`` / ``chain_only``):
        of the live events, how many are still being processed by some
        module (``event._module_consumers > 0``) vs. how many are
        held alive *only* by the parent chain. The chain_only count is
        the chain-retention pathology made directly visible.

    Returns::

        {
            "in_pipeline": int,   # _module_consumers > 0
            "chain_only": int,    # _module_consumers == 0
            "queue_total": int,   # sum of all module queue depths
            "by_queue": {module_name: {"incoming": int, "outgoing": int}, ...},
        }
    """
    by_queue = {}
    for module in scan.modules.values():
        m_name = getattr(module, "name", str(module))
        in_q = getattr(module, "_incoming_event_queue", None)
        out_q = getattr(module, "_outgoing_event_queue", None)
        in_count = 0
        out_count = 0
        # asyncio.Queue / ShuffleQueue both expose the underlying ``_queue``
        # deque. Truthiness alone filters both None and a False sentinel
        # (the old extra ``is not False`` check was redundant and
        # inconsistent with the out_q branch).
        if in_q and hasattr(in_q, "_queue"):
            in_count = len(in_q._queue)
        if out_q and hasattr(out_q, "_queue"):
            out_count = len(out_q._queue)
        if in_count or out_count:
            by_queue[m_name] = {"incoming": in_count, "outgoing": out_count}

    in_pipeline = 0
    chain_only = 0
    # Iterate the WeakSet directly — O(live events).
    for ev in tracker._events:
        if getattr(ev, "_module_consumers", 0) > 0:
            in_pipeline += 1
        else:
            chain_only += 1

    return {
        "in_pipeline": in_pipeline,
        "chain_only": chain_only,
        "queue_total": sum(d["incoming"] + d["outgoing"] for d in by_queue.values()),
        "by_queue": by_queue,
    }
| 334 | + |
| 335 | + |
def emit_metrics_json(**metrics):
    """
    Print a single ``METRICS_JSON:`` line that ``test_scan_memory.py`` parses.

    Backward-compat: also prints the legacy ``PEAK_MB:`` line when
    ``peak_tracemalloc_mb`` is supplied, so any external readers still work.
    """
    payload = json.dumps(metrics)
    print("METRICS_JSON:" + payload)
    if "peak_tracemalloc_mb" in metrics:
        print("PEAK_MB:" + str(metrics["peak_tracemalloc_mb"]))
0 commit comments