
Commit 2cb07bc

Merge pull request #3083 from blacklanternsecurity/scan-memory-benchmarks
Scan memory benchmarks: RSS sampling, mid-scan census, parallel-chains workload
2 parents ce8a832 + 688dcfa commit 2cb07bc

6 files changed: 910 additions & 33 deletions

Lines changed: 345 additions & 0 deletions
@@ -0,0 +1,345 @@
"""
Shared memory measurement helpers for scan benchmarks.

Each helper is independent — pick the ones you need. Designed to be used
from subprocess scan scripts (so tracemalloc + psutil polling are not
contaminated by pytest's own allocations).

Three measurement angles, each answering a different question:

- ``RSSSampler``: actual OS-level RAM (catches Rust / lxml / yara that
  tracemalloc misses). Headline metrics: peak, end, retention (median
  over the last 25% of samples — the metric most sensitive to
  "stuck for the rest of the scan" pathology).

- ``event_census``: who is alive right now, grouped by event type, with
  HTTP_RESPONSE body bytes broken out separately. Answers
  "where is the memory going?".

- ``lineage_census``: walk every live event's parent chain back to the
  seed and bucket pinned events by seed. Answers "is the chain
  holding things alive?".
"""

import gc
import json
import threading
import time
import weakref

import psutil


class RSSSampler:
    """
    Background thread that polls process RSS at a fixed interval.

    Usage::

        sampler = RSSSampler(interval_s=0.2)
        sampler.start()
        # ... do work ...
        sampler.stop()
        m = sampler.metrics()  # peak_rss_mb, end_rss_mb, retention_rss_mb
    """

    def __init__(self, interval_s=0.2):
        self.interval_s = interval_s
        self.samples = []  # list of (t, rss_mb)
        self._stop = threading.Event()
        self._thread = None
        self._proc = psutil.Process()

    def start(self):
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def _loop(self):
        t0 = time.monotonic()
        while not self._stop.is_set():
            rss_mb = self._proc.memory_info().rss / 1024 / 1024
            self.samples.append((time.monotonic() - t0, rss_mb))
            self._stop.wait(self.interval_s)

    def stop(self):
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=1.0)

    def metrics(self):
        if not self.samples:
            return {
                "peak_rss_mb": 0.0,
                "end_rss_mb": 0.0,
                "retention_rss_mb": 0.0,
                "samples": 0,
                "duration_s": 0.0,
            }
        rss_values = [r for _, r in self.samples]
        peak_rss_mb = max(rss_values)
        end_rss_mb = rss_values[-1]
        # median over the last 25% of samples — robust to a single
        # transient spike at the tail and sensitive to baseline drift.
        last_quartile_start = max(1, int(len(rss_values) * 0.75))
        last_quartile = sorted(rss_values[last_quartile_start:])
        retention_rss_mb = last_quartile[len(last_quartile) // 2]
        duration_s = self.samples[-1][0]
        return {
            "peak_rss_mb": round(peak_rss_mb, 2),
            "end_rss_mb": round(end_rss_mb, 2),
            "retention_rss_mb": round(retention_rss_mb, 2),
            "samples": len(self.samples),
            "duration_s": round(duration_s, 2),
        }


class LiveEventTracker:
    """
    Mid-scan event census via weakref, without scanning the full Python
    object graph.

    ``event_census()`` calls ``gc.get_objects()`` and walks every Python
    object — fine at scan-end (one-shot), too expensive to call every
    few hundred events. ``LiveEventTracker`` patches
    ``BaseEvent.__init__`` to register newcomers into a ``WeakSet``, so
    ``census()`` is O(live events) instead of O(every Python object).

    Usage::

        tracker = LiveEventTracker()
        tracker.install()
        # ... run scan, calling tracker.census() periodically ...
        tracker.uninstall()  # optional; not needed if the process exits
    """

    def __init__(self):
        self._events = weakref.WeakSet()
        self._original_init = None
        self._patched = False

    def install(self):
        if self._patched:
            return
        from bbot.core.event.base import BaseEvent

        # Seed with any events that already exist (e.g. SCAN root_event
        # created before install).
        for obj in gc.get_objects():
            if isinstance(obj, BaseEvent):
                self._events.add(obj)

        original_init = BaseEvent.__init__
        events_ref = self._events

        def patched_init(self_evt, *a, **kw):
            original_init(self_evt, *a, **kw)
            try:
                events_ref.add(self_evt)
            except TypeError:
                # Not weakref-able for some reason; skip silently.
                pass

        self._original_init = original_init
        BaseEvent.__init__ = patched_init
        self._patched = True

    def uninstall(self):
        if not self._patched:
            return
        from bbot.core.event.base import BaseEvent

        BaseEvent.__init__ = self._original_init
        self._patched = False

    def census(self):
        """O(live events) census. Same shape as ``event_census()``."""
        by_type = {}
        body_bytes = 0
        body_count = 0
        for obj in self._events:
            by_type[obj.type] = by_type.get(obj.type, 0) + 1
            if obj.type == "HTTP_RESPONSE":
                data = getattr(obj, "data", None)
                body = data.get("body") if isinstance(data, dict) else None
                if body:
                    body_count += 1
                    body_bytes += len(body)
        return {
            "live_events": sum(by_type.values()),
            "by_type": by_type,
            "http_response_body_mb": round(body_bytes / 1024 / 1024, 2),
            "http_response_with_body": body_count,
        }


def event_census():
    """
    Walk live BaseEvent instances and classify by type + HTTP_RESPONSE body bytes.

    Returns a dict::

        {
            "live_events": int,
            "by_type": {"DNS_NAME": int, "HTTP_RESPONSE": int, ...},
            "http_response_body_mb": float,
            "http_response_with_body": int,
        }
    """
    from bbot.core.event.base import BaseEvent

    gc.collect()
    by_type = {}
    body_bytes = 0
    body_count = 0

    for obj in gc.get_objects():
        if not isinstance(obj, BaseEvent):
            continue
        by_type[obj.type] = by_type.get(obj.type, 0) + 1
        if obj.type == "HTTP_RESPONSE":
            data = getattr(obj, "data", None)
            body = data.get("body") if isinstance(data, dict) else None
            if body:
                body_count += 1
                body_bytes += len(body)

    return {
        "live_events": sum(by_type.values()),
        "by_type": by_type,
        "http_response_body_mb": round(body_bytes / 1024 / 1024, 2),
        "http_response_with_body": body_count,
    }


def lineage_census(top_n=20):
    """
    Walk every live BaseEvent's parent chain back to the seed and bucket
    pinned events by seed.

    A long-lived seed pinning hundreds of events is the signature of the
    chain-retention pathology — this is the metric that proves a fix.

    Returns::

        {
            "seeds": [{"seed": str, "pinned_events": int,
                       "max_chain_depth": int,
                       "types_pinned": {type: count, ...}}, ...],  # top N
            "total_pinned_events": int,
            "max_chain_depth": int,
            "live_events_walked": int,
        }
    """
    from bbot.core.event.base import BaseEvent

    gc.collect()
    seeds = {}
    max_chain_depth = 0
    walked = 0

    for obj in gc.get_objects():
        if not isinstance(obj, BaseEvent):
            continue
        walked += 1
        depth = 0
        node = obj
        # Walk up via .parent. Root events are self-parented (node.parent is node).
        while node is not None and node is not getattr(node, "parent", None):
            depth += 1
            parent = getattr(node, "parent", None)
            if parent is None or parent is node:
                break
            node = parent
        if node is None:
            continue
        max_chain_depth = max(max_chain_depth, depth)
        seed_data = getattr(node, "data", None)
        seed_data_str = str(seed_data)[:80] if seed_data is not None else "<no-data>"
        seed_key = f"{node.type}:{seed_data_str}"
        s = seeds.setdefault(
            seed_key,
            {"pinned_events": 0, "max_chain_depth": 0, "types_pinned": {}},
        )
        s["pinned_events"] += 1
        s["max_chain_depth"] = max(s["max_chain_depth"], depth)
        s["types_pinned"][obj.type] = s["types_pinned"].get(obj.type, 0) + 1

    sorted_seeds = sorted(
        ({"seed": k, **v} for k, v in seeds.items()),
        key=lambda x: -x["pinned_events"],
    )
    return {
        "seeds": sorted_seeds[:top_n],
        "total_pinned_events": sum(v["pinned_events"] for v in seeds.values()),
        "max_chain_depth": max_chain_depth,
        "live_events_walked": walked,
    }


def queue_residence(scan, tracker):
    """
    Classify live events by where they currently live in the scan
    pipeline.

    Two angles, since they answer different questions:

    - Per-module queue depth (``by_queue``): which module is the
      current bottleneck — events are piling up where queue depth is
      large.

    - In-pipeline vs. chain-only (``in_pipeline`` / ``chain_only``):
      of the live events, how many are still being processed by some
      module (``event._module_consumers > 0``) vs. how many are
      held alive *only* by the parent chain. The chain_only count is
      the chain-retention pathology made directly visible.

    Returns::

        {
            "in_pipeline": int,  # _module_consumers > 0
            "chain_only": int,  # _module_consumers == 0
            "queue_total": int,  # sum of all module queue depths
            "by_queue": {module_name: {"incoming": int, "outgoing": int}, ...},
        }
    """
    by_queue = {}
    for module in scan.modules.values():
        m_name = getattr(module, "name", str(module))
        in_q = getattr(module, "_incoming_event_queue", None)
        out_q = getattr(module, "_outgoing_event_queue", None)
        in_count = 0
        out_count = 0
        # asyncio.Queue / ShuffleQueue both expose the underlying ``_queue`` deque.
        if in_q and in_q is not False and hasattr(in_q, "_queue"):
            in_count = len(in_q._queue)
        if out_q and hasattr(out_q, "_queue"):
            out_count = len(out_q._queue)
        if in_count or out_count:
            by_queue[m_name] = {"incoming": in_count, "outgoing": out_count}

    in_pipeline = 0
    chain_only = 0
    # Iterate the WeakSet directly — O(live events).
    for ev in tracker._events:
        if getattr(ev, "_module_consumers", 0) > 0:
            in_pipeline += 1
        else:
            chain_only += 1

    return {
        "in_pipeline": in_pipeline,
        "chain_only": chain_only,
        "queue_total": sum(d["incoming"] + d["outgoing"] for d in by_queue.values()),
        "by_queue": by_queue,
    }


def emit_metrics_json(**metrics):
    """
    Print a single ``METRICS_JSON:`` line that ``test_scan_memory.py`` parses.

    Backward-compat: also prints the legacy ``PEAK_MB:`` line when
    ``peak_tracemalloc_mb`` is supplied, so any external readers still work.
    """
    print(f"METRICS_JSON:{json.dumps(metrics)}")
    if "peak_tracemalloc_mb" in metrics:
        print(f"PEAK_MB:{metrics['peak_tracemalloc_mb']}")
