-
Notifications
You must be signed in to change notification settings - Fork 332
Expand file tree
/
Copy pathOverheadController.java
More file actions
371 lines (319 loc) · 12.3 KB
/
OverheadController.java
File metadata and controls
371 lines (319 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
package com.datadog.iast.overhead;
import static com.datadog.iast.overhead.OverheadContext.globalMap;
import static datadog.trace.api.iast.IastDetectionMode.UNLIMITED;
import com.datadog.iast.IastRequestContext;
import com.datadog.iast.IastSystem;
import com.datadog.iast.model.VulnerabilityType;
import com.datadog.iast.util.NonBlockingSemaphore;
import datadog.trace.api.Config;
import datadog.trace.api.gateway.RequestContext;
import datadog.trace.api.gateway.RequestContextSlot;
import datadog.trace.api.iast.IastContext;
import datadog.trace.api.iast.VulnerabilityTypes;
import datadog.trace.api.telemetry.LogCollector;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
import datadog.trace.bootstrap.instrumentation.api.Tags;
import datadog.trace.util.AgentTaskScheduler;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Controls the overhead introduced by IAST analysis.
 *
 * <p>Implementations decide which requests are analyzed (request sampling plus a cap on the number
 * of concurrently analyzed requests) and track quotas for the operations executed while analyzing
 * them.
 */
public interface OverheadController {

  /**
   * Tries to select the current request for IAST analysis.
   *
   * @return {@code true} if the request was selected and a concurrency permit was obtained; that
   *     permit is returned through {@link #releaseRequest()}
   */
  boolean acquireRequest();

  /** Resets the periodic state of the controller (e.g. the global overhead context). */
  void reset();

  /**
   * Releases a request permit previously obtained through {@link #acquireRequest()}.
   *
   * @return the value reported by the underlying semaphore on release; presumably the number of
   *     permits available afterwards (TODO confirm against {@code NonBlockingSemaphore})
   */
  int releaseRequest();

  /**
   * Checks, without consuming anything, whether {@code operation} still has quota in the overhead
   * context resolved for {@code span}.
   */
  boolean hasQuota(Operation operation, @Nullable AgentSpan span);

  /**
   * Consumes quota for {@code operation}; equivalent to {@code consumeQuota(operation, span,
   * null)}.
   */
  boolean consumeQuota(Operation operation, @Nullable AgentSpan span);

  /**
   * Consumes quota for {@code operation} in the overhead context resolved for {@code span}.
   *
   * @param operation the operation requesting quota
   * @param span the active span, used to resolve the overhead context and the request endpoint
   * @param type the detected vulnerability type, or {@code null}; when present the implementation
   *     may decide to skip the vulnerability based on per-endpoint counters
   * @return {@code true} if quota was consumed and the operation may proceed
   */
  boolean consumeQuota(
      Operation operation, @Nullable AgentSpan span, @Nullable VulnerabilityType type);

  /**
   * Builds a controller from the tracer configuration.
   *
   * <p>The global overhead context is used as a fallback for spans without an IAST request context
   * only when the IAST context mode is {@link IastContext.Mode#GLOBAL}.
   */
  static OverheadController build(final Config config, final AgentTaskScheduler scheduler) {
    return build(
        config.getIastRequestSampling(),
        config.getIastMaxConcurrentRequests(),
        config.getIastContextMode() == IastContext.Mode.GLOBAL,
        scheduler);
  }

  /**
   * Builds a controller from explicit settings.
   *
   * @param requestSampling percentage of requests to analyze; non-positive, NaN or values of 100
   *     and above mean "analyze every request"
   * @param maxConcurrentRequests maximum number of requests analyzed at the same time, or {@code
   *     UNLIMITED}
   * @param globalFallback whether spans with a request context but no IAST context fall back to
   *     the global overhead context
   * @param scheduler scheduler used to run {@link #reset()} periodically; when {@code null} no
   *     reset is scheduled
   * @return the controller, wrapped in a debug-logging adapter when {@code IastSystem.DEBUG} is on
   */
  static OverheadController build(
      final float requestSampling,
      final int maxConcurrentRequests,
      final boolean globalFallback,
      final AgentTaskScheduler scheduler) {
    final OverheadControllerImpl result =
        new OverheadControllerImpl(
            requestSampling, maxConcurrentRequests, globalFallback, scheduler);
    return IastSystem.DEBUG ? new OverheadControllerDebugAdapter(result) : result;
  }

  /** Decorator that delegates every call and logs its outcome at debug level. */
  class OverheadControllerDebugAdapter implements OverheadController {

    // NOTE(review): kept non-final and package-private; presumably so tests can replace the
    // logger. Confirm before tightening the modifiers.
    static Logger LOGGER = LoggerFactory.getLogger(OverheadController.class);

    private final OverheadControllerImpl delegate;

    public OverheadControllerDebugAdapter(final OverheadControllerImpl delegate) {
      this.delegate = delegate;
    }

    @Override
    public boolean acquireRequest() {
      final boolean result = delegate.acquireRequest();
      if (LOGGER.isDebugEnabled()) {
        final int available = delegate.availableRequests.available();
        LOGGER.debug(
            "acquireRequest: acquired={}, availableRequests={}, span={}",
            result,
            available,
            AgentTracer.activeSpan());
      }
      return result;
    }

    @Override
    public int releaseRequest() {
      final int result = delegate.releaseRequest();
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug(
            "releaseRequest: availableRequests={}, span={}", result, AgentTracer.activeSpan());
      }
      return result;
    }

    @Override
    public boolean hasQuota(final Operation operation, @Nullable final AgentSpan span) {
      final boolean result = delegate.hasQuota(operation, span);
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug(
            "hasQuota: operation={}, result={}, availableQuota={}, span={}",
            operation,
            result,
            getAvailableQuota(span),
            span);
      }
      return result;
    }

    @Override
    public boolean consumeQuota(final Operation operation, @Nullable final AgentSpan span) {
      return consumeQuota(operation, span, null);
    }

    @Override
    public boolean consumeQuota(
        final Operation operation,
        @Nullable final AgentSpan span,
        @Nullable final VulnerabilityType type) {
      final boolean result = delegate.consumeQuota(operation, span, type);
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug(
            "consumeQuota: operation={}, result={}, availableQuota={}, span={}, type={}",
            operation,
            result,
            getAvailableQuota(span),
            span,
            type);
      }
      return result;
    }

    @Override
    public void reset() {
      delegate.reset();
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("reset: span={}", AgentTracer.activeSpan());
      }
    }

    /** Returns the remaining quota of the context resolved for {@code span}, or -1 if none. */
    private int getAvailableQuota(@Nullable final AgentSpan span) {
      final OverheadContext context = delegate.getContext(span);
      return context == null ? -1 : context.getAvailableQuota();
    }
  }

  /** Default implementation: percentage sampling, a concurrency semaphore and quota contexts. */
  class OverheadControllerImpl implements OverheadController {

    private static final Logger LOGGER = LoggerFactory.getLogger(OverheadControllerImpl.class);

    /** Period of the scheduled {@link #reset()} calls. */
    private static final int RESET_PERIOD_SECONDS = 30;

    /** Time without any acquired request after which a starved semaphore is reported. */
    private static final long STARVATION_THRESHOLD_MILLIS = TimeUnit.HOURS.toMillis(1);

    /** Sampling percentage in [1, 100], added to {@link #cumulativeCounter} on every request. */
    private final int sampling;

    /**
     * Fallback to use the global context instance when no IAST context is present in the active
     * span
     */
    private final boolean useGlobalAsFallback;

    /** Permits for concurrently analyzed requests. */
    final NonBlockingSemaphore availableRequests;

    /** Running sum of {@link #sampling}; a request is sampled when it crosses a multiple of 100. */
    final AtomicLong cumulativeCounter;

    /**
     * Time of the last successful {@link #acquireRequest()}, or {@link Long#MAX_VALUE} when none
     * has happened yet or starvation has just been reported.
     */
    private volatile long lastAcquiredTimestamp = Long.MAX_VALUE;

    /** Context returned by {@link #getContext} when no request-specific context applies. */
    final OverheadContext globalContext =
        new OverheadContext(Config.get().getIastVulnerabilitiesPerRequest(), true);

    public OverheadControllerImpl(
        final float requestSampling,
        final int maxConcurrentRequests,
        final boolean useGlobalAsFallback,
        final AgentTaskScheduler taskScheduler) {
      this.sampling = computeSamplingParameter(requestSampling);
      availableRequests = maxConcurrentRequests(maxConcurrentRequests);
      cumulativeCounter = new AtomicLong(sampling);
      this.useGlobalAsFallback = useGlobalAsFallback;
      if (taskScheduler != null) {
        taskScheduler.scheduleAtFixedRate(
            this::reset, 2 * RESET_PERIOD_SECONDS, RESET_PERIOD_SECONDS, TimeUnit.SECONDS);
      }
    }

    /**
     * Samples the request and, if selected, tries to take a concurrency permit.
     *
     * <p>Each call advances the counter by {@code sampling}; the request is sampled whenever the
     * counter crosses a multiple of 100, which selects roughly {@code sampling} requests out of
     * every 100.
     */
    @Override
    public boolean acquireRequest() {
      final long prevValue = cumulativeCounter.getAndAdd(sampling);
      final long newValue = prevValue + sampling;
      if (newValue / 100 == prevValue / 100 + 1) {
        // Sample request
        final boolean acquired = availableRequests.acquire();
        if (acquired) {
          lastAcquiredTimestamp = System.currentTimeMillis();
        }
        return acquired;
      }
      // Skipped by sampling
      return false;
    }

    @Override
    public int releaseRequest() {
      return availableRequests.release();
    }

    @Override
    public boolean hasQuota(final Operation operation, @Nullable final AgentSpan span) {
      return operation.hasQuota(getContext(span));
    }

    @Override
    public boolean consumeQuota(final Operation operation, @Nullable final AgentSpan span) {
      return consumeQuota(operation, span, null);
    }

    /**
     * Consumes quota for {@code operation}. For request (non-global) contexts the vulnerability is
     * first checked against per-endpoint counters and may be skipped without consuming quota.
     */
    @Override
    public boolean consumeQuota(
        final Operation operation,
        @Nullable final AgentSpan span,
        @Nullable final VulnerabilityType type) {
      final OverheadContext ctx = getContext(span);
      if (ctx == null) {
        return false;
      }
      if (ctx.isGlobal()) {
        // The global context has no per-endpoint deduplication.
        return operation.consumeQuota(ctx);
      }
      if (operation.hasQuota(ctx)) {
        String method = null;
        String path = null;
        if (span != null) {
          // Guard against a missing local root span instead of failing with an NPE.
          final AgentSpan rootSpan = span.getLocalRootSpan();
          final Object methodTag = rootSpan == null ? null : rootSpan.getTag(Tags.HTTP_METHOD);
          method = (methodTag == null) ? "" : methodTag.toString();
          final Object routeTag = rootSpan == null ? null : rootSpan.getTag(Tags.HTTP_ROUTE);
          path = (routeTag == null) ? getHttpRouteFromRequestContext(span) : routeTag.toString();
        }
        if (!maybeSkipVulnerability(ctx, type, method, path)) {
          return operation.consumeQuota(ctx);
        }
      }
      return false;
    }

    /**
     * Method to be called when a vulnerability of a certain type is detected. Implements the
     * RFC-1029 algorithm.
     *
     * <p>On the first detection for an endpoint within the request, a snapshot of the global
     * per-endpoint counters is stored in the context's copy map. Each detection then increments
     * the request counter for its type, and the vulnerability is skipped while that counter is
     * still below the snapshot value.
     *
     * @param ctx the overhead context for the current request
     * @param type the type of vulnerability detected
     * @param httpMethod the HTTP method of the request (e.g., GET, POST)
     * @param httpPath the HTTP path of the request
     * @return true if the vulnerability should be skipped, false otherwise
     */
    private boolean maybeSkipVulnerability(
        @Nullable final OverheadContext ctx,
        @Nullable final VulnerabilityType type,
        @Nullable final String httpMethod,
        @Nullable final String httpPath) {
      if (ctx == null || type == null || ctx.getRequestMap() == null || ctx.getCopyMap() == null) {
        return false;
      }
      final int numberOfVulnerabilities = VulnerabilityTypes.STRINGS.length;
      final String currentEndpoint = httpMethod + " " + httpPath;
      AtomicIntegerArray requestArray = ctx.getRequestMap().get(currentEndpoint);
      final int[] copyArray;
      if (requestArray == null) {
        // First detection for this endpoint in the request: snapshot the global counters.
        final AtomicIntegerArray globalArray =
            globalMap.computeIfAbsent(
                currentEndpoint, k -> new AtomicIntegerArray(numberOfVulnerabilities));
        copyArray = toIntArray(globalArray);
        ctx.getCopyMap().put(currentEndpoint, copyArray);
        requestArray =
            ctx.getRequestMap()
                .computeIfAbsent(
                    currentEndpoint, k -> new AtomicIntegerArray(numberOfVulnerabilities));
      } else {
        copyArray = ctx.getCopyMap().get(currentEndpoint);
      }
      final int typeIndex = type.type();
      final int counter = requestArray.getAndIncrement(typeIndex);
      final int storedCounter = copyArray == null ? 0 : copyArray[typeIndex];
      return counter < storedCounter;
    }

    /** Copies the current values of {@code atomic} into a plain array. */
    private static int[] toIntArray(final AtomicIntegerArray atomic) {
      final int length = atomic.length();
      final int[] result = new int[length];
      for (int i = 0; i < length; i++) {
        result[i] = atomic.get(i);
      }
      return result;
    }

    /**
     * Resolves the overhead context for {@code span}.
     *
     * @return the IAST request context's overhead context when present; otherwise {@code null} if
     *     the span has a request context and global fallback is disabled; otherwise the global
     *     context (also used when {@code span} or its request context is {@code null})
     */
    @Nullable
    public OverheadContext getContext(@Nullable final AgentSpan span) {
      final RequestContext requestContext = span != null ? span.getRequestContext() : null;
      if (requestContext != null) {
        final IastRequestContext iastRequestContext =
            requestContext.getData(RequestContextSlot.IAST);
        if (iastRequestContext != null) {
          return iastRequestContext.getOverheadContext();
        }
        if (!useGlobalAsFallback) {
          return null;
        }
      }
      return globalContext;
    }

    /**
     * Returns the route stored in the IAST request context of {@code span}, or {@code ""} when
     * there is no span, request context, IAST context or route.
     */
    @Nullable
    public String getHttpRouteFromRequestContext(@Nullable final AgentSpan span) {
      String httpRoute = null;
      final RequestContext requestContext = span != null ? span.getRequestContext() : null;
      if (requestContext != null) {
        final IastRequestContext iastRequestContext =
            requestContext.getData(RequestContextSlot.IAST);
        if (iastRequestContext != null) {
          httpRoute = iastRequestContext.getRoute();
        }
      }
      return httpRoute == null ? "" : httpRoute;
    }

    /**
     * Converts a sampling percentage into the per-request counter increment.
     *
     * @param pct requested sampling percentage
     * @return a value in [1, 100]; 100 for NaN, non-positive or {@code >= 100} inputs
     */
    static int computeSamplingParameter(final float pct) {
      if (Float.isNaN(pct) || pct <= 0 || pct >= 100) {
        // We don't support disabling IAST by setting it, so invalid or out-of-range values
        // fall back to 100%.
        // TODO: We probably want a warning here.
        return 100;
      }
      // A value in (0, 1) would truncate to 0; a zero increment never crosses a multiple of 100
      // in acquireRequest, which would silently disable IAST. Sample at least 1%.
      return Math.max(1, (int) pct);
    }

    /** Creates the semaphore bounding concurrently analyzed requests. */
    static NonBlockingSemaphore maxConcurrentRequests(final int max) {
      return max == UNLIMITED
          ? NonBlockingSemaphore.unlimited()
          : NonBlockingSemaphore.withPermitCount(max);
    }

    /**
     * Resets the global context and reports, once per starvation episode, when no request permit
     * could be acquired for a long time.
     */
    @Override
    public void reset() {
      globalContext.reset();
      // Read the volatile field once so both checks use the same value.
      final long lastAcquired = lastAcquiredTimestamp;
      if (lastAcquired != Long.MAX_VALUE
          && System.currentTimeMillis() - lastAcquired > STARVATION_THRESHOLD_MILLIS) {
        // If no request has been acquired for more than 1h and no permits are available, we might
        // have lost request end events, leaving IAST unable to acquire new requests.
        // We report it to telemetry for further investigation.
        if (availableRequests.available() == 0) {
          LOGGER.debug(
              LogCollector.SEND_TELEMETRY,
              "IAST cannot acquire new requests, end of request events might be missing.");
          // Once starved, do not report this again, unless this is recovered and then starved
          // again.
          lastAcquiredTimestamp = Long.MAX_VALUE;
        }
      }
    }
  }
}