Skip to content

Commit 067d0d2

Browse files
manuel-alvarez-alvarezdevflow.devflow-routing-intake
andauthored
feat(ai-guard): expose tag probabilities in SDK responses (#11063)
feat(ai-guard): expose tag probabilities in SDK responses Merge branch 'master' into malvarez/ai-guard-attach-tag-probabilities Merge branch 'master' into malvarez/ai-guard-attach-tag-probabilities Co-authored-by: devflow.devflow-routing-intake <devflow.devflow-routing-intake@kubernetes.us1.ddbuild.io>
1 parent 6564199 commit 067d0d2

File tree

4 files changed

+52
-11
lines changed

4 files changed

+52
-11
lines changed

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ public BadConfigurationException(final String message) {
7373
static final String META_STRUCT_MESSAGES = "messages";
7474
static final String META_STRUCT_CATEGORIES = "attack_categories";
7575
static final String META_STRUCT_SDS = "sds";
76+
static final String META_STRUCT_TAG_PROBS = "tag_probs";
7677

7778
public static void install() {
7879
final Config config = Config.get();
@@ -258,13 +259,18 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
258259
final List<String> tags = (List<String>) result.get("tags");
259260
@SuppressWarnings("unchecked")
260261
final List<?> sdsFindings = (List<?>) result.get("sds_findings");
262+
@SuppressWarnings("unchecked")
263+
final Map<String, Number> tagProbs = (Map<String, Number>) result.get("tag_probs");
261264
span.setTag(ACTION_TAG, action);
262265
if (reason != null) {
263266
span.setTag(REASON_TAG, reason);
264267
}
265268
if (tags != null && !tags.isEmpty()) {
266269
metaStruct.put(META_STRUCT_CATEGORIES, tags);
267270
}
271+
if (tagProbs != null && !tagProbs.isEmpty()) {
272+
metaStruct.put(META_STRUCT_TAG_PROBS, tagProbs);
273+
}
268274
if (sdsFindings != null && !sdsFindings.isEmpty()) {
269275
metaStruct.put(META_STRUCT_SDS, sdsFindings);
270276
}
@@ -273,9 +279,9 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
273279
WafMetricCollector.get().aiGuardRequest(action, shouldBlock);
274280
if (shouldBlock) {
275281
span.setTag(BLOCKED_TAG, true);
276-
throw new AIGuardAbortError(action, reason, tags, sdsFindings);
282+
throw new AIGuardAbortError(action, reason, tags, tagProbs, sdsFindings);
277283
}
278-
return new Evaluation(action, reason, tags, sdsFindings);
284+
return new Evaluation(action, reason, tags, tagProbs, sdsFindings);
279285
}
280286
} catch (AIGuardAbortError e) {
281287
span.addThrowable(e);

dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class AIGuardInternalTests extends DDSpecification {
168168
return mockResponse(
169169
request,
170170
200,
171-
[data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], is_blocking_enabled: suite.blocking]]]
171+
[data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], tag_probs: suite.tagProbabilities ?: [:], is_blocking_enabled: suite.blocking]]]
172172
)
173173
}
174174
}
@@ -210,12 +210,14 @@ class AIGuardInternalTests extends DDSpecification {
210210
error.action == suite.action
211211
error.reason == suite.reason
212212
error.tags == suite.tags
213+
error.tagProbabilities == suite.tagProbabilities
213214
error.sds == []
214215
} else {
215216
error == null
216217
eval.action == suite.action
217218
eval.reason == suite.reason
218219
eval.tags == suite.tags
220+
eval.tagProbabilities == suite.tagProbabilities
219221
eval.sds == []
220222
}
221223
assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false')
@@ -555,6 +557,9 @@ class AIGuardInternalTests extends DDSpecification {
555557
if (suite.tags) {
556558
assert meta.attack_categories == suite.tags
557559
}
560+
if (suite.tagProbabilities) {
561+
assert meta.tag_probs == suite.tagProbabilities
562+
}
558563
final receivedMessages = snakeCaseJson(meta.messages)
559564
final expectedMessages = snakeCaseJson(suite.messages)
560565
JSONAssert.assertEquals(expectedMessages, receivedMessages, JSONCompareMode.NON_EXTENSIBLE)
@@ -774,15 +779,17 @@ class AIGuardInternalTests extends DDSpecification {
774779
private final AIGuard.Action action
775780
private final String reason
776781
private final List<String> tags
782+
private final Map<String, Double> tagProbabilities
777783
private final boolean blocking
778784
private final String description
779785
private final String target
780786
private final List<AIGuard.Message> messages
781787

782-
TestSuite(AIGuard.Action action, String reason, List<String> tags, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
788+
TestSuite(AIGuard.Action action, String reason, Map<String, Double> tagProbabilities, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
783789
this.action = action
784790
this.reason = reason
785-
this.tags = tags
791+
this.tags = new ArrayList<>(tagProbabilities.keySet())
792+
this.tagProbabilities = tagProbabilities
786793
this.blocking = blocking
787794
this.description = description
788795
this.target = target
@@ -791,9 +798,9 @@ class AIGuardInternalTests extends DDSpecification {
791798

792799
static List<TestSuite> build() {
793800
def actionValues = [
794-
[ALLOW, 'Go ahead', []],
795-
[DENY, 'Nope', ['deny_everything', 'test_deny']],
796-
[ABORT, 'Kill it with fire', ['alarm_tag', 'abort_everything']]
801+
[ALLOW, 'Go ahead', [:]],
802+
[DENY, 'Nope', ['deny_everything': 0.2D, 'test_deny': 0.8D]],
803+
[ABORT, 'Kill it with fire', ['alarm_tag': 0.1D, 'abort_everything': 0.9D]]
797804
]
798805
def blockingValues = [true, false]
799806
def suiteValues = [

dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import java.util.Collections;
66
import java.util.List;
77
import java.util.Locale;
8+
import java.util.Map;
89
import java.util.Objects;
910
import javax.annotation.Nonnull;
1011
import javax.annotation.Nullable;
@@ -69,14 +70,20 @@ public static class AIGuardAbortError extends RuntimeException {
6970
private final Action action;
7071
private final String reason;
7172
private final List<String> tags;
73+
private final Map<String, Number> tagProbs;
7274
private final List<?> sds;
7375

7476
public AIGuardAbortError(
75-
final Action action, final String reason, final List<String> tags, final List<?> sds) {
77+
final Action action,
78+
final String reason,
79+
final List<String> tags,
80+
final Map<String, Number> tagProbs,
81+
final List<?> sds) {
7682
super(reason);
7783
this.action = action;
7884
this.reason = reason;
7985
this.tags = tags;
86+
this.tagProbs = tagProbs != null ? tagProbs : Collections.emptyMap();
8087
this.sds = sds != null ? sds : Collections.emptyList();
8188
}
8289

@@ -92,6 +99,10 @@ public List<String> getTags() {
9299
return tags;
93100
}
94101

102+
public Map<String, Number> getTagProbabilities() {
103+
return tagProbs;
104+
}
105+
95106
public List<?> getSds() {
96107
return sds;
97108
}
@@ -156,6 +167,7 @@ public static class Evaluation {
156167
final Action action;
157168
final String reason;
158169
final List<String> tags;
170+
final Map<String, Number> tagProbs;
159171
final List<?> sds;
160172

161173
/**
@@ -164,13 +176,19 @@ public static class Evaluation {
164176
* @param action the recommended action for the evaluated content
165177
* @param reason human-readable explanation for the decision
166178
* @param tags list of tags associated with the evaluation (e.g. indirect-prompt-injection)
179+
* @param tagProbs map of tags associated to their probability
167180
* @param sds list of Sensitive Data Scanner findings
168181
*/
169182
public Evaluation(
170-
final Action action, final String reason, final List<String> tags, final List<?> sds) {
183+
final Action action,
184+
final String reason,
185+
final List<String> tags,
186+
final Map<String, Number> tagProbs,
187+
final List<?> sds) {
171188
this.action = action;
172189
this.reason = reason;
173190
this.tags = tags;
191+
this.tagProbs = tagProbs;
174192
this.sds = sds != null ? sds : Collections.emptyList();
175193
}
176194

@@ -201,6 +219,15 @@ public List<String> getTags() {
201219
return tags;
202220
}
203221

222+
/**
223+
* Returns a map from tag to their probability (e.g. [indirect-prompt-injection: 0.25])
224+
*
225+
* @return map of tag probabilities.
226+
*/
227+
public Map<String, Number> getTagProbabilities() {
228+
return tagProbs;
229+
}
230+
204231
/**
205232
* Returns the list of Sensitive Data Scanner findings.
206233
*

dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW;
44
import static java.util.Collections.emptyList;
5+
import static java.util.Collections.emptyMap;
56

67
import datadog.trace.api.aiguard.AIGuard.Evaluation;
78
import datadog.trace.api.aiguard.AIGuard.Message;
@@ -13,6 +14,6 @@ public final class NoOpEvaluator implements Evaluator {
1314

1415
@Override
1516
public Evaluation evaluate(final List<Message> messages, final Options options) {
16-
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList(), emptyList());
17+
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList(), emptyMap(), emptyList());
1718
}
1819
}

0 commit comments

Comments
 (0)