-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathmisc.py
More file actions
406 lines (319 loc) · 13.7 KB
/
misc.py
File metadata and controls
406 lines (319 loc) · 13.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
import orjson
import logging
import inspect
from bson import ObjectId
from functools import wraps
from types import UnionType
from inspect import signature
from pydantic import BaseModel, create_model
from datetime import datetime, timezone, timedelta
from typing import Any, Union, get_origin, get_args, get_type_hints, Annotated
from bbot_server.errors import BBOTServerValueError
# Logger for this module; the name is hard-coded rather than derived from __name__.
log = logging.getLogger("bbot_server.utils.misc")
def unwrap_type_annotation(type_anno) -> Any:
"""
Recursively unwrap type annotations to get the underlying type.
Handles:
- Direct types: returns as-is
- Annotated types: Annotated[Model, ...] -> Model
- Optional types: Model | None, Union[Model, None] -> Model
- Nested: Annotated[Model | None, ...] -> Model
Returns the unwrapped type, or None if only None was found.
"""
origin = get_origin(type_anno)
# Handle Annotated[Type, metadata...] - unwrap to get the actual type
if origin is Annotated:
args = get_args(type_anno)
if args:
return unwrap_type_annotation(args[0])
# Handle Union[Type1, Type2, ...] or Type | None
elif origin in (Union, UnionType):
for arg in get_args(type_anno):
# Skip None type
if arg is type(None):
continue
# Recursively unwrap this union member
return unwrap_type_annotation(arg)
# Base case: return the type as-is
return type_anno
def detect_translatable_function(fn):
    """
    Decide whether ``fn`` is eligible for human-friendly keyword translation.

    A function qualifies when, ignoring ``self`` and the return annotation,
    exactly one type hint remains and that hint (after unwrapping
    ``Annotated`` / ``Optional`` via ``unwrap_type_annotation``) is a pydantic
    ``BaseModel`` subclass.

    Only *annotated* parameters are counted, because the check is driven by
    ``typing.get_type_hints``; unannotated parameters are not seen at all.

    Returns:
        ``(param_name, model_class)`` when eligible, otherwise ``(None, None)``.
    """
    hints = get_type_hints(fn)
    candidates = [(name, hint) for name, hint in hints.items() if name not in ("self", "return")]
    if len(candidates) != 1:
        return None, None

    [(name, hint)] = candidates
    model_class = unwrap_type_annotation(hint)
    if not isinstance(model_class, type) or not issubclass(model_class, BaseModel):
        return None, None
    return name, model_class
def convert_human_args(fn, param_name, model_class, *args, **kwargs):
    """
    Normalize a call's arguments so ``fn`` receives a model for ``param_name``.

    - If ``param_name`` was passed as a keyword and is already a pydantic model,
      the arguments are returned untouched.
    - Else, if the arguments bind to ``fn``'s signature and ``param_name``
      (passed explicitly or filled from its default) is already a
      ``model_class`` instance, the arguments are returned untouched.
    - Otherwise ALL keyword arguments are treated as model fields: a
      ``model_class`` is built from them and becomes the only keyword argument,
      while the positional arguments are passed through unchanged.

    Returns:
        An ``(args, kwargs)`` pair to call ``fn`` with.
    """
    supplied = kwargs.get(param_name)
    if supplied and isinstance(supplied, BaseModel):
        return args, kwargs

    # Binding tells us whether the model arrived positionally; it fails (TypeError)
    # when the caller used field names that are not parameters of fn.
    try:
        bound = signature(fn).bind_partial(*args, **kwargs)
    except TypeError:
        bound = None

    if bound is not None:
        bound.apply_defaults()
        resolved = bound.arguments.get(param_name)
        if resolved and isinstance(resolved, model_class):
            return args, kwargs

    return args, {param_name: model_class(**kwargs)}
def human_friendly_kwargs(fn):
    """
    Decorator that lets a function taking a single pydantic model be called
    with that model's fields as plain keyword arguments.

    If ``fn`` qualifies (see ``detect_translatable_function``), the returned
    wrapper builds the model from the caller's kwargs whenever no model
    instance was supplied; otherwise ``fn`` is returned unchanged. Plain
    functions, coroutine functions and async generator functions are all
    supported, and the wrapper keeps the same kind as ``fn``.

    E.g. instead of doing:

        from bbot_server.modules.findings_models import FindingQuery
        query = FindingQuery(search="apache")
        bbserver.query_findings(query)

    You can do:

        bbserver.query_findings(search="apache")
    """
    param_name, model_class = detect_translatable_function(fn)
    if param_name is None:
        return fn

    def _translate(call_args, call_kwargs):
        # Shared by every wrapper flavor below.
        return convert_human_args(fn, param_name, model_class, *call_args, **call_kwargs)

    if inspect.iscoroutinefunction(fn):

        @wraps(fn)
        async def wrapper(*args, **kwargs):
            new_args, new_kwargs = _translate(args, kwargs)
            return await fn(*new_args, **new_kwargs)

    elif inspect.isasyncgenfunction(fn):

        @wraps(fn)
        async def wrapper(*args, **kwargs):
            new_args, new_kwargs = _translate(args, kwargs)
            async for item in fn(*new_args, **new_kwargs):
                yield item

    else:

        @wraps(fn)
        def wrapper(*args, **kwargs):
            new_args, new_kwargs = _translate(args, kwargs)
            return fn(*new_args, **new_kwargs)

    return wrapper
def utc_now() -> float:
    """Return the current time as a POSIX timestamp (float seconds), read from a UTC-aware clock."""
    aware_now = datetime.now(tz=timezone.utc)
    return aware_now.timestamp()
def seconds_to_human(seconds: float) -> str:
    """
    Convert a duration in seconds to a human-friendly string.

    Only non-zero units are included, from largest (days) to smallest
    (seconds); sub-second precision is dropped. A zero or sub-second duration
    renders as "0 seconds". Negative durations render as the magnitude
    prefixed with "-" (e.g. -90 -> "-1 minute, 30 seconds").

    Args:
        seconds: Number of seconds to convert (int or float, may be negative).

    Returns:
        Human-readable string like "2 days, 5 hours, 30 minutes".
    """
    negative = seconds < 0
    # timedelta normalizes negatives as "-1 day + positive remainder", which
    # would render e.g. -5 as "23 hours, 59 minutes, 55 seconds". Work on the
    # magnitude and re-apply the sign at the end instead.
    delta = timedelta(seconds=abs(seconds))

    days = delta.days
    hours, remainder = divmod(delta.seconds, 3600)
    minutes, secs = divmod(remainder, 60)

    parts = [
        f"{value} {unit}{'s' if value != 1 else ''}"
        for value, unit in ((days, "day"), (hours, "hour"), (minutes, "minute"))
        if value > 0
    ]
    # Include seconds if non-zero, or if every larger unit was zero (so the result is never empty).
    if secs > 0 or not parts:
        parts.append(f"{secs} second{'s' if secs != 1 else ''}")

    text = ", ".join(parts)
    # Only show the sign when something non-zero is displayed ("-0 seconds" is noise).
    if negative and (days or hours or minutes or secs):
        text = "-" + text
    return text
def timestamp_to_human(timestamp: float, include_hours: bool = True, *, tz=None) -> str:
    """
    Format a POSIX timestamp as a readable date (and optionally time) string.

    Args:
        timestamp: Seconds since the epoch (e.g. a value from ``utc_now()``).
        include_hours: If True, render "YYYY-MM-DD HH:MM:SS"; otherwise "YYYY-MM-DD".
        tz: Optional ``tzinfo`` to render in (e.g. ``timezone.utc``). The default,
            None, renders in the host's local time zone, as before.

    Returns:
        The formatted string. No time-zone suffix is appended.
    """
    format_str = "%Y-%m-%d %H:%M:%S" if include_hours else "%Y-%m-%d"
    return datetime.fromtimestamp(timestamp, tz=tz).strftime(format_str)
def orjson_serializer(obj: Any) -> Any:
    """
    ``default=`` hook that lets orjson serialize types it does not support natively.

    Currently handles Mongo ``ObjectId`` values, which are converted with ``str()``.

    Raises:
        TypeError: for any other type. orjson's ``default`` contract is to raise
            for unhandled objects. Returning the object unchanged (as this hook
            used to do) makes orjson call the hook again on the same object until
            it gives up with an opaque "default serializer exceeds recursion
            limit" error. orjson reports the TypeError as ``orjson.JSONEncodeError``
            (a TypeError subclass), so callers see the same exception type as before.
    """
    if isinstance(obj, ObjectId):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
def smart_encode(obj: Any) -> bytes:
# handle both python and pydantic objects, as well as strings
if isinstance(obj, BaseModel):
return obj.model_dump_json().encode()
elif isinstance(obj, str):
return obj.encode()
elif isinstance(obj, bytes):
return obj
else:
return orjson.dumps(obj, default=orjson_serializer)
def combine_pydantic_models(models, model_name, base_model=BaseModel):
    """
    Build a new pydantic model class whose fields are the union of several models' fields.

    Fields from ``base_model`` are collected first, then those of each model in
    ``models`` in order. A field name that appears more than once must carry the
    same annotation every time; the first definition seen is the one kept.

    Args:
        models: iterable of pydantic model classes to merge.
        model_name: name of the generated model class.
        base_model: class the generated model inherits from (default ``BaseModel``);
            its own fields take part in the merge as well.

    Returns:
        The newly created model class.

    Raises:
        ValueError: if an entry in ``models`` has no ``model_fields`` attribute, or
            if the same field name is declared with different annotations.
    """
    merged = {name: (info.annotation, info) for name, info in base_model.model_fields.items()}

    for model in models:
        try:
            fields = model.model_fields
        except AttributeError as e:
            raise ValueError(f'Model {model.__name__} has no attribute "model_fields"') from e

        for name, info in fields.items():
            if name not in merged:
                merged[name] = (info.annotation, info)
                continue
            existing_annotation = merged[name][0]
            if info.annotation != existing_annotation:
                raise ValueError(
                    f'Field "{name}" on {model.__name__} already exists, but with a different annotation: ({existing_annotation} vs {info.annotation})'
                )

    # create_model accepts (annotation, FieldInfo) tuples as field definitions.
    return create_model(model_name, __base__=base_model, **merged)
# fmt: off
# 20260417: removed $jsonSchema because it may reveal internal/private fields
# Whitelist of "$"-prefixed keys that _sanitize_mongo_query accepts anywhere in a
# user-supplied query (it checks every dict key recursively).
ALLOWED_QUERY_OPERATORS = {
    # Query Operators (excluding $where, $expr)
    "$eq", "$gt", "$gte", "$in", "$lt", "$lte", "$ne", "$nin",
    "$and", "$not", "$nor", "$or",
    "$exists", "$type",
    "$mod", "$search", "$text", "$regex",
    "$geoIntersects", "$geoWithin", "$near", "$nearSphere",
    "$all", "$elemMatch", "$size",
    "$bitsAllClear", "$bitsAllSet", "$bitsAnyClear", "$bitsAnySet",
    "$comment",
    # Sub-specifiers that only have meaning inside operators already allowed above.
    # Without them the sanitizer rejects e.g. {"$regex": ..., "$options": "i"} or
    # {"$near": {"$geometry": ...}}, leaving those operators only half usable.
    # $regex flags
    "$options",
    # $text options ($search is listed above)
    "$language", "$caseSensitive", "$diacriticSensitive",
    # Geometry specifiers for $geoWithin / $geoIntersects / $near / $nearSphere
    "$geometry", "$maxDistance", "$minDistance",
    "$box", "$center", "$centerSphere", "$polygon",
}
# fmt: on
def _sanitize_mongo_query(data: Any) -> Any:
    """
    Sanitize a MongoDB query (for find() or a $match stage) using a whitelist.

    Walks ``data`` recursively through dicts, lists and tuples. Every dict key is
    whitespace-stripped, and any key starting with "$" must be listed in
    ALLOWED_QUERY_OPERATORS. Returns a sanitized copy: dicts become plain dicts
    with stripped keys, lists/tuples become lists, other values pass through.

    Raises:
        BBOTServerValueError: if an unauthorized operator or a non-string key is found.
    """
    if isinstance(data, dict):
        sanitized = {}
        for key, value in data.items():
            if not isinstance(key, str):
                raise BBOTServerValueError(f"Invalid MongoDB query key: {key!r}")
            # The stripped form is both the one checked and the one emitted.
            key = key.strip()
            if key.startswith("$") and key not in ALLOWED_QUERY_OPERATORS:
                raise BBOTServerValueError(f"Unauthorized MongoDB query operator: {key}")
            sanitized[key] = _sanitize_mongo_query(value)
        return sanitized
    # BSON encodes tuples as arrays just like lists, so they must be walked too;
    # otherwise {"$or": ({"$where": ...},)} would slip past the whitelist.
    if isinstance(data, (list, tuple)):
        return [_sanitize_mongo_query(item) for item in data]
    return data
# fmt: off
# 20260417: removed $unionWith because it may allow fetches to other collections
# Whitelist of "$"-prefixed keys that _sanitize_mongo_aggregation accepts anywhere in a
# user-supplied pipeline. Stage names and expression operators share one set because the
# sanitizer checks every dict key recursively. Some names appear under several
# categories below; duplicates in a set literal are harmless.
ALLOWED_AGG_OPERATORS = {
    # We intentionally exclude $match because it's automatically added and sanitized separately
    # Aggregation Pipeline Stages (excluding $out, $merge, $lookup, $graphLookup)
    # NOTE(review): $collStats, $indexStats, $listSessions and $planCacheStats are
    # diagnostic/metadata stages rather than data-shaping ones and may expose server or
    # other-query information if they can actually run here; confirm they are meant to be
    # user-reachable.
    "$addFields", "$bucket", "$bucketAuto", "$collStats", "$count",
    "$densify", "$documents", "$facet", "$fill", "$geoNear",
    "$group", "$indexStats", "$limit", "$listSessions",
    "$planCacheStats", "$project", "$redact",
    "$replaceRoot", "$replaceWith", "$sample", "$search", "$searchMeta",
    "$set", "$setWindowFields", "$skip", "$sort", "$sortByCount",
    "$unset", "$unwind",
    # Aggregation Expression Operators (excluding $function, $accumulator)
    # Arithmetic
    "$abs", "$add", "$ceil", "$divide", "$exp", "$floor", "$ln",
    "$log", "$log10", "$mod", "$multiply", "$pow", "$round",
    "$sqrt", "$subtract", "$trunc",
    # Array
    "$arrayElemAt", "$arrayToObject", "$concatArrays", "$filter",
    "$first", "$in", "$indexOfArray", "$isArray", "$last", "$map",
    "$objectToArray", "$range", "$reduce", "$reverseArray", "$size",
    "$slice", "$sortArray", "$zip",
    # Boolean
    "$and", "$not", "$or",
    # Comparison
    "$cmp", "$eq", "$gt", "$gte", "$lt", "$lte", "$ne",
    # Conditional
    "$cond", "$ifNull", "$switch",
    # Data Size
    "$binarySize", "$bsonSize",
    # Date
    "$dateAdd", "$dateDiff", "$dateFromParts", "$dateFromString",
    "$dateSubtract", "$dateToParts", "$dateToString", "$dateTrunc",
    "$dayOfMonth", "$dayOfWeek", "$dayOfYear", "$hour",
    "$isoDayOfWeek", "$isoWeek", "$isoWeekYear", "$millisecond",
    "$minute", "$month", "$second", "$toDate", "$week", "$year",
    # Diagnostic
    "$getField", "$rand", "$sampleRate", "$tsIncrement", "$tsSecond",
    # Literal
    "$literal",
    # Miscellaneous
    "$mergeObjects",
    # Object
    "$getField", "$mergeObjects", "$objectToArray", "$setField",
    # Set
    "$allElementsTrue", "$anyElementTrue", "$setDifference",
    "$setEquals", "$setIntersection", "$setIsSubset", "$setUnion",
    # String
    "$concat", "$dateFromString", "$dateToString", "$indexOfBytes",
    "$indexOfCP", "$ltrim", "$regexFind", "$regexFindAll",
    "$regexMatch", "$replaceAll", "$replaceOne", "$rtrim", "$split",
    "$strLenBytes", "$strLenCP", "$strcasecmp", "$substr",
    "$substrBytes", "$substrCP", "$toLower", "$toString", "$trim",
    "$toUpper",
    # Text Search
    "$meta",
    # Trigonometry
    "$sin", "$cos", "$tan", "$asin", "$acos", "$atan", "$atan2",
    "$asinh", "$acosh", "$atanh", "$degreesToRadians", "$radiansToDegrees",
    # Type
    "$convert", "$isNumber", "$toBool", "$toDate", "$toDecimal",
    "$toDouble", "$toInt", "$toLong", "$toObjectId", "$toString", "$type",
    # Accumulators (for $group)
    "$addToSet", "$avg", "$bottom", "$bottomN", "$count", "$first",
    "$firstN", "$last", "$lastN", "$max", "$maxN", "$mergeObjects",
    "$min", "$minN", "$push", "$stdDevPop", "$stdDevSamp", "$sum",
    "$top", "$topN",
    # Window Operators (for $setWindowFields)
    "$addToSet", "$avg", "$count", "$covariancePop", "$covarianceSamp",
    "$denseRank", "$derivative", "$documentNumber", "$expMovingAvg",
    "$first", "$integral", "$last", "$linearFill", "$locf", "$max",
    "$min", "$push", "$rank", "$shift", "$stdDevPop", "$stdDevSamp", "$sum",
    # Variable
    "$let"
}
# fmt: on
def _sanitize_mongo_aggregation(data: Any) -> Any:
    """
    Sanitize a MongoDB aggregation pipeline or expression using a whitelist.

    Walks ``data`` recursively through dicts, lists and tuples. Every dict key is
    whitespace-stripped, and any key starting with "$" (stage or expression
    operator) must be listed in ALLOWED_AGG_OPERATORS. Returns a sanitized copy:
    dicts become plain dicts with stripped keys, lists/tuples become lists, other
    values pass through.

    Raises:
        BBOTServerValueError: if an unauthorized operator/stage or a non-string key is found.
    """
    if isinstance(data, dict):
        sanitized = {}
        for key, value in data.items():
            if not isinstance(key, str):
                raise BBOTServerValueError(f"Invalid MongoDB aggregation key: {key!r}")
            # The stripped form is both the one checked and the one emitted.
            key = key.strip()
            if key.startswith("$") and key not in ALLOWED_AGG_OPERATORS:
                raise BBOTServerValueError(f"Unauthorized MongoDB aggregation operator: {key}")
            sanitized[key] = _sanitize_mongo_aggregation(value)
        return sanitized
    # BSON encodes tuples as arrays just like lists, so they must be walked too;
    # otherwise ({"$lookup": ...},) would slip past the whitelist.
    if isinstance(data, (list, tuple)):
        return [_sanitize_mongo_aggregation(item) for item in data]
    return data