Skip to content

Commit fbea611

Browse files
authored
Merge branch 'master' into dougqh/utf8-cache-fixes
2 parents adee794 + 15141ae commit fbea611

5 files changed

Lines changed: 132 additions & 5 deletions

File tree

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@ public BadConfigurationException(final String message) {
7878
static final String META_STRUCT_SDS = "sds";
7979
static final String META_STRUCT_TAG_PROBS = "tag_probs";
8080

81+
/**
82+
* Anomaly detection tags copied from the local root span onto every {@code ai_guard} span with
83+
* the {@code ai_guard.} prefix, so the AI Guard backend can correlate AI Guard requests with the
84+
* request context (client IP, user, session) without depending on the local root span.
85+
*/
86+
static final String[] ANOMALY_DETECTION_TAGS = {
87+
Tags.HTTP_CLIENT_IP, Tags.NETWORK_CLIENT_IP, Tags.HTTP_USER_AGENT, "usr.id", "usr.session_id"
88+
};
89+
8190
public static void install() {
8291
final Config config = Config.get();
8392
final String apiKey = config.getApiKey();
@@ -241,6 +250,16 @@ private static void applyClientIpTags(final AgentSpan localRootSpan) {
241250
}
242251
}
243252

253+
private static void copyAnomalyDetectionTags(
254+
final AgentSpan span, final AgentSpan localRootSpan) {
255+
for (final String tag : ANOMALY_DETECTION_TAGS) {
256+
final Object value = localRootSpan.getTag(tag);
257+
if (value != null) {
258+
span.setTag("ai_guard." + tag, value.toString());
259+
}
260+
}
261+
}
262+
244263
@Override
245264
public Evaluation evaluate(final List<Message> messages, final Options options) {
246265
if (messages == null || messages.isEmpty()) {
@@ -258,6 +277,9 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
258277
localRootSpan.setTag(Tags.AI_GUARD_KEEP, true);
259278
localRootSpan.setTag(EVENT_TAG, true);
260279
applyClientIpTags(localRootSpan);
280+
// copyAnomalyDetectionTags MUST run after applyClientIpTags, to make
281+
// sure client IP tags were populated.
282+
copyAnomalyDetectionTags(span, localRootSpan);
261283
}
262284
try (final AgentScope scope = tracer.activateSpan(span)) {
263285
final Message last = messages.get(messages.size() - 1);

dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,15 +279,15 @@ class AIGuardInternalTests extends DDSpecification {
279279
final requestContext = Mock(RequestContext)
280280
localRootSpan.getRequestContext() >> requestContext
281281
requestContext.getClientIpAddressData() >> new ClientIpAddressData('4.4.4.4', '2.3.4.5')
282+
localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> null
283+
localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> null
282284
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
283285

284286
when:
285287
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
286288

287289
then:
288-
1 * localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> null
289290
1 * localRootSpan.setTag(Tags.NETWORK_CLIENT_IP, '4.4.4.4')
290-
1 * localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> null
291291
1 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, '2.3.4.5')
292292
}
293293

@@ -296,15 +296,15 @@ class AIGuardInternalTests extends DDSpecification {
296296
final requestContext = Mock(RequestContext)
297297
localRootSpan.getRequestContext() >> requestContext
298298
requestContext.getClientIpAddressData() >> new ClientIpAddressData('4.4.4.4', '2.3.4.5')
299+
localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> '9.9.9.9'
300+
localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '8.8.8.8'
299301
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
300302

301303
when:
302304
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
303305

304306
then:
305-
1 * localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> '9.9.9.9'
306307
0 * localRootSpan.setTag(Tags.NETWORK_CLIENT_IP, _)
307-
1 * localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '8.8.8.8'
308308
0 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, _)
309309
}
310310

@@ -336,6 +336,46 @@ class AIGuardInternalTests extends DDSpecification {
336336
0 * localRootSpan.setTag(Tags.HTTP_CLIENT_IP, _)
337337
}
338338

339+
void 'test evaluate copies anomaly detection tags from local root span to ai_guard span'() {
340+
given:
341+
localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '1.2.3.4'
342+
localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> '5.6.7.8'
343+
localRootSpan.getTag(Tags.HTTP_USER_AGENT) >> 'curl/8.0'
344+
localRootSpan.getTag('usr.id') >> 'u-123'
345+
localRootSpan.getTag('usr.session_id') >> 's-456'
346+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
347+
348+
when:
349+
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
350+
351+
then:
352+
1 * span.setTag('ai_guard.http.client_ip', '1.2.3.4')
353+
1 * span.setTag('ai_guard.network.client.ip', '5.6.7.8')
354+
1 * span.setTag('ai_guard.http.useragent', 'curl/8.0')
355+
1 * span.setTag('ai_guard.usr.id', 'u-123')
356+
1 * span.setTag('ai_guard.usr.session_id', 's-456')
357+
}
358+
359+
void 'test evaluate skips missing anomaly detection tags'() {
360+
given:
361+
localRootSpan.getTag(Tags.HTTP_CLIENT_IP) >> '1.2.3.4'
362+
localRootSpan.getTag(Tags.NETWORK_CLIENT_IP) >> null
363+
localRootSpan.getTag(Tags.HTTP_USER_AGENT) >> null
364+
localRootSpan.getTag('usr.id') >> 'u-123'
365+
localRootSpan.getTag('usr.session_id') >> null
366+
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'It is fine']]])
367+
368+
when:
369+
aiguard.evaluate(TOOL_CALL, AIGuard.Options.DEFAULT)
370+
371+
then:
372+
1 * span.setTag('ai_guard.http.client_ip', '1.2.3.4')
373+
1 * span.setTag('ai_guard.usr.id', 'u-123')
374+
0 * span.setTag('ai_guard.network.client.ip', _)
375+
0 * span.setTag('ai_guard.http.useragent', _)
376+
0 * span.setTag('ai_guard.usr.session_id', _)
377+
}
378+
339379
void 'test evaluate with API errors'() {
340380
given:
341381
final errors = [[status: 400, title: 'Bad request']]

dd-smoke-tests/appsec/springboot/build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ dependencies {
2121
implementation(group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.6.0')
2222
implementation group: 'com.h2database', name: 'h2', version: '2.1.212'
2323

24+
// Used by AIGuardController to set user/session tags on the local root span via the active OT span
25+
implementation group: 'io.opentracing', name: 'opentracing-api', version: '0.32.0'
26+
implementation group: 'io.opentracing', name: 'opentracing-util', version: '0.32.0'
27+
2428
// file upload
2529
implementation group: 'commons-fileupload', name: 'commons-fileupload', version: '1.5'
2630

dd-smoke-tests/appsec/springboot/src/main/java/datadog/smoketest/appsec/springboot/controller/AIGuardController.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import datadog.trace.api.aiguard.AIGuard.Evaluation;
99
import datadog.trace.api.aiguard.AIGuard.Message;
1010
import datadog.trace.api.aiguard.AIGuard.Options;
11+
import datadog.trace.api.interceptor.MutableSpan;
12+
import io.opentracing.Span;
13+
import io.opentracing.util.GlobalTracer;
1114
import java.util.Collections;
1215
import java.util.HashMap;
1316
import java.util.List;
@@ -26,7 +29,19 @@
2629
public class AIGuardController {
2730

2831
@GetMapping(value = "/allow")
29-
public ResponseEntity<?> allow() {
32+
public ResponseEntity<?> allow(
33+
@RequestHeader(name = "X-User-Id", required = false) final String userId,
34+
@RequestHeader(name = "X-Session-Id", required = false) final String sessionId) {
35+
final Span activeSpan = GlobalTracer.get().activeSpan();
36+
if (activeSpan instanceof MutableSpan) {
37+
final MutableSpan rootSpan = ((MutableSpan) activeSpan).getLocalRootSpan();
38+
if (userId != null && !userId.isEmpty()) {
39+
rootSpan.setTag("usr.id", userId);
40+
}
41+
if (sessionId != null && !sessionId.isEmpty()) {
42+
rootSpan.setTag("usr.session_id", sessionId);
43+
}
44+
}
3045
final Evaluation result =
3146
AIGuard.evaluate(
3247
asList(

dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,52 @@ class AIGuardSmokeTest extends AbstractAppSecServerSmokeTest {
167167
rootSpan.meta.get('network.client.ip') != publicIp
168168
}
169169

170+
void 'anomaly detection tags are copied from the local root span to the ai_guard span'() {
171+
given:
172+
final publicIp = '5.6.7.9'
173+
final userId = 'u12345'
174+
final sessionId = 's12345'
175+
final userAgent = 'AIGuardSmokeTest/1.0'
176+
final request = new Request.Builder()
177+
.url("http://localhost:${httpPort}/aiguard/allow")
178+
.header('X-Forwarded-For', publicIp)
179+
.header('X-User-Id', userId)
180+
.header('X-Session-Id', sessionId)
181+
.header('User-Agent', userAgent)
182+
.get()
183+
.build()
184+
185+
when:
186+
final response = client.newCall(request).execute()
187+
188+
then:
189+
response.code() == 200
190+
191+
and:
192+
waitForTraceCount(2) // /aiguard/allow + internal /aiguard/evaluate mock
193+
final aiGuardSpan = traces*.spans
194+
?.flatten()
195+
?.find { it.resource == 'ai_guard' } as DecodedSpan
196+
aiGuardSpan != null
197+
final rootSpan = traces*.spans
198+
?.flatten()
199+
?.find { it.traceId == aiGuardSpan.traceId && it.parentId == 0 } as DecodedSpan
200+
rootSpan != null
201+
202+
// Tags must match what is on the root span
203+
aiGuardSpan.meta.get('ai_guard.http.client_ip') == rootSpan.meta.get('http.client_ip')
204+
aiGuardSpan.meta.get('ai_guard.network.client.ip') == rootSpan.meta.get('network.client.ip')
205+
aiGuardSpan.meta.get('ai_guard.http.useragent') == rootSpan.meta.get('http.useragent')
206+
aiGuardSpan.meta.get('ai_guard.usr.id') == rootSpan.meta.get('usr.id')
207+
aiGuardSpan.meta.get('ai_guard.usr.session_id') == rootSpan.meta.get('usr.session_id')
208+
209+
// And carry the expected values
210+
aiGuardSpan.meta.get('ai_guard.http.client_ip') == publicIp
211+
aiGuardSpan.meta.get('ai_guard.http.useragent') == userAgent
212+
aiGuardSpan.meta.get('ai_guard.usr.id') == userId
213+
aiGuardSpan.meta.get('ai_guard.usr.session_id') == sessionId
214+
}
215+
170216
void 'test multimodal content parts evaluation'() {
171217
given:
172218
def request = new Request.Builder()

0 commit comments

Comments
 (0)