Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions bottlecap/src/lifecycle/invocation/span_inferrer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ impl SpanInferrer {
);
}

if let Some(dd_resource_key) = t.get_dd_resource_key(&aws_config.region) {
inferred_span
.meta
.insert("dd_resource_key".to_string(), dd_resource_key);
}

self.wrapped_inferred_span = wrapped_inferred_span;
self.span_pointers = span_pointers;

Expand Down Expand Up @@ -278,12 +284,15 @@ impl SpanInferrer {
invocation_span.service.clone(),
);
s.meta.insert("span.kind".to_string(), "server".to_string());
let appsec_enabled = self.config.serverless_appsec_enabled;
propagate_appsec(appsec_enabled, invocation_span, s);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does the json need to be propagated up for the inferred spans?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to address an issue raised by a customer who wants to be able to link attacks to API Gateway endpoints... Initially we wanted to do it backend-side, but we cannot guarantee the spans are in the same chunk so that was problematic... Copying the attack data to the inferred span makes this problem go away.


if let Some(ws) = &mut self.wrapped_inferred_span {
ws.trace_id = invocation_span.trace_id;
ws.error = invocation_span.error;
ws.meta
.insert(String::from("peer.service"), s.service.clone());
propagate_appsec(appsec_enabled, invocation_span, ws);

// The wrapper span should be the parent of the inferred span,
// therefore the `parent_id` of the inferred span should be the
Expand Down Expand Up @@ -325,6 +334,34 @@ impl SpanInferrer {
}
}

fn propagate_appsec(
serverless_appsec_enabled: bool,
invocation_span: &Span,
target_span: &mut Span,
) {
let has_appsec = invocation_span
.metrics
.get("_dd.appsec.enabled")
.copied()
.or(if serverless_appsec_enabled {
Some(1.0)
} else {
None
});

if let Some(enabled) = has_appsec {
target_span
.metrics
.insert("_dd.appsec.enabled".to_string(), enabled);
}

if let Some(json) = invocation_span.meta.get("_dd.appsec.json") {
target_span
.meta
.insert("_dd.appsec.json".to_string(), json.clone());
}
}

pub fn extract_span_context(
payload_value: &Value,
propagator: Arc<impl Propagator>,
Expand Down Expand Up @@ -368,6 +405,7 @@ pub fn extract_generated_span_context(
#[cfg(test)]
mod tests {
use super::*;
use crate::lifecycle::invocation::triggers::test_utils::read_json_file;
use crate::traces::propagation::text_map_propagator::DatadogHeaderPropagator;
use serde_json::json;
use std::sync::Arc;
Expand Down Expand Up @@ -571,4 +609,117 @@ mod tests {
"Should have SQS as event source"
);
}

fn api_gateway_rest_payload() -> serde_json::Value {
let json = read_json_file("api_gateway_rest_event.json");
serde_json::from_str(&json).expect("Failed to deserialize API Gateway REST payload")
}

fn aws_config(region: &str) -> Arc<AwsConfig> {
Arc::new(AwsConfig {
region: region.to_string(),
aws_lwa_proxy_lambda_runtime_api: Some(String::new()),
runtime_api: String::new(),
function_name: String::new(),
sandbox_init_time: Instant::now(),
exec_wrapper: None,
initialization_type: "on-demand".into(),
})
}

#[test]
fn test_complete_inferred_spans_propagates_appsec_from_invocation() {
let payload = api_gateway_rest_payload();
let aws_config = aws_config("us-east-1");
let mut inferrer = SpanInferrer::new(Arc::new(Config::default()));

inferrer.infer_span(&payload, &aws_config);

let mut invocation_span = Span {
trace_id: 42,
span_id: 100,
service: "lambda-service".to_string(),
..Span::default()
};
if let Some(inferred_span) = &inferrer.inferred_span {
invocation_span.start = inferred_span.start;
}
invocation_span.duration = 1;
invocation_span
.metrics
.insert("_dd.appsec.enabled".to_string(), 1.0);
invocation_span.meta.insert(
"_dd.appsec.json".to_string(),
r#"{"triggers":["rule"]}"#.to_string(),
);

inferrer.complete_inferred_spans(&invocation_span);

let inferred_span = inferrer
.inferred_span
.as_ref()
.expect("Inferred span should still be present");

let appsec_enabled = inferred_span
.metrics
.get("_dd.appsec.enabled")
.copied()
.unwrap_or_default();
assert!(
(appsec_enabled - 1.0).abs() < f64::EPSILON,
"Expected appsec enabled metric to be 1.0"
);
assert_eq!(
inferred_span
.meta
.get("_dd.appsec.json")
.cloned()
.unwrap_or_default(),
r#"{"triggers":["rule"]}"#
);
}

#[test]
fn test_complete_inferred_spans_sets_appsec_when_enabled_in_config() {
let config = Config {
serverless_appsec_enabled: true,
..Config::default()
};
let mut inferrer = SpanInferrer::new(Arc::new(config));

let payload = api_gateway_rest_payload();
let aws_config = aws_config("us-east-1");
inferrer.infer_span(&payload, &aws_config);

let mut invocation_span = Span {
trace_id: 7,
service: "lambda-service".to_string(),
..Span::default()
};
if let Some(inferred_span) = &inferrer.inferred_span {
invocation_span.start = inferred_span.start;
}
invocation_span.duration = 1;

inferrer.complete_inferred_spans(&invocation_span);

let inferred_span = inferrer
.inferred_span
.as_ref()
.expect("Inferred span should still be present");

let appsec_enabled = inferred_span
.metrics
.get("_dd.appsec.enabled")
.copied()
.unwrap_or_default();
assert!(
(appsec_enabled - 1.0).abs() < f64::EPSILON,
"Expected appsec enabled metric to be 1.0"
);
assert!(
!inferred_span.meta.contains_key("_dd.appsec.json"),
"AppSec JSON should not be added when invocation span has none"
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl Trigger for APIGatewayHttpEvent {
span.name = "aws.httpapi".to_string();
span.service = service_name;
span.resource.clone_from(&resource);
span.r#type = "http".to_string();
span.r#type = "web".to_string();
span.start = start_time;
span.meta.extend(HashMap::from([
(
Expand All @@ -131,7 +131,6 @@ impl Trigger for APIGatewayHttpEvent {
"http.user_agent".to_string(),
self.request_context.http.user_agent.clone(),
),
("operation_name".to_string(), "aws.httpapi".to_string()),
(
"request_id".to_string(),
self.request_context.request_id.clone(),
Expand Down Expand Up @@ -200,6 +199,20 @@ impl Trigger for APIGatewayHttpEvent {
)
}

fn get_dd_resource_key(&self, region: &str) -> Option<String> {
if self.request_context.api_id.is_empty() {
return None;
}

let partition = get_aws_partition_by_region(region);
Some(format!(
"arn:{partition}:apigateway:{region}::/apis/{api_id}",
partition = partition,
region = region,
api_id = self.request_context.api_id
))
}

fn is_async(&self) -> bool {
self.headers
.get("x-amz-invocation-type")
Expand All @@ -220,6 +233,7 @@ impl ServiceNameResolver for APIGatewayHttpEvent {
"lambda_api_gateway"
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -309,7 +323,7 @@ mod tests {
"x02yirxc7a.execute-api.sa-east-1.amazonaws.com"
);
assert_eq!(span.resource, "GET /httpapi/get");
assert_eq!(span.r#type, "http");
assert_eq!(span.r#type, "web");
assert_eq!(
span.meta,
HashMap::from([
Expand All @@ -323,7 +337,6 @@ mod tests {
("http.protocol".to_string(), "HTTP/1.1".to_string()),
("http.source_ip".to_string(), "38.122.226.210".to_string()),
("http.user_agent".to_string(), "curl/7.64.1".to_string()),
("operation_name".to_string(), "aws.httpapi".to_string()),
("request_id".to_string(), "FaHnXjKCGjQEJ7A=".to_string()),
])
);
Expand Down Expand Up @@ -373,7 +386,7 @@ mod tests {
"9vj54we5ih.execute-api.sa-east-1.amazonaws.com"
);
assert_eq!(span.resource, "GET /user/{user_id}");
assert_eq!(span.r#type, "http");
assert_eq!(span.r#type, "web");
assert_eq!(
span.meta,
HashMap::from([
Expand All @@ -386,7 +399,6 @@ mod tests {
("http.protocol".to_string(), "HTTP/1.1".to_string()),
("http.source_ip".to_string(), "76.115.124.192".to_string()),
("http.user_agent".to_string(), "curl/8.1.2".to_string()),
("operation_name".to_string(), "aws.httpapi".to_string()),
("request_id".to_string(), "Ur2JtjEfGjQEPOg=".to_string()),
])
);
Expand Down Expand Up @@ -429,6 +441,18 @@ mod tests {
);
}

#[test]
fn test_get_dd_resource_key() {
let json = read_json_file("api_gateway_http_event.json");
let payload = serde_json::from_str(&json).expect("Failed to deserialize into Value");
let event =
APIGatewayHttpEvent::new(payload).expect("Failed to deserialize APIGatewayHttpEvent");
assert_eq!(
event.get_dd_resource_key("sa-east-1"),
Some("arn:aws:apigateway:sa-east-1::/apis/x02yirxc7a".to_string())
);
}

#[test]
fn test_resolve_service_name_with_representation_enabled() {
let json = read_json_file("api_gateway_http_event.json");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl Trigger for APIGatewayRestEvent {
span.name = "aws.apigateway".to_string();
span.service = service_name;
span.resource = resource;
span.r#type = "http".to_string();
span.r#type = "web".to_string();
span.start = start_time;
span.meta.extend(HashMap::from([
("endpoint".to_string(), self.request_context.path.clone()),
Expand All @@ -125,7 +125,6 @@ impl Trigger for APIGatewayRestEvent {
"http.user_agent".to_string(),
self.request_context.identity.user_agent.clone(),
),
("operation_name".to_string(), "aws.apigateway".to_string()),
(
"request_id".to_string(),
self.request_context.request_id.clone(),
Expand Down Expand Up @@ -187,6 +186,20 @@ impl Trigger for APIGatewayRestEvent {
)
}

fn get_dd_resource_key(&self, region: &str) -> Option<String> {
if self.request_context.api_id.is_empty() {
return None;
}

let partition = get_aws_partition_by_region(region);
Some(format!(
"arn:{partition}:apigateway:{region}::/restapis/{api_id}",
partition = partition,
region = region,
api_id = self.request_context.api_id
))
}

fn is_async(&self) -> bool {
self.headers
.get("x-amz-invocation-type")
Expand Down Expand Up @@ -327,7 +340,7 @@ mod tests {
assert_eq!(span.name, "aws.apigateway");
assert_eq!(span.service, "id.execute-api.us-east-1.amazonaws.com");
assert_eq!(span.resource, "GET /my/path");
assert_eq!(span.r#type, "http");
assert_eq!(span.r#type, "web");

assert_eq!(
span.meta,
Expand All @@ -342,7 +355,6 @@ mod tests {
("http.source_ip".to_string(), "IP".to_string()),
("http.user_agent".to_string(), "user-agent".to_string()),
("http.route".to_string(), "/path".to_string()),
("operation_name".to_string(), "aws.apigateway".to_string()),
("request_id".to_string(), "id=".to_string()),
])
);
Expand Down Expand Up @@ -389,7 +401,7 @@ mod tests {
"mcwkra0ya4.execute-api.sa-east-1.amazonaws.com"
);
assert_eq!(span.resource, "GET /dev/user/{user_id}/id/{id}");
assert_eq!(span.r#type, "http");
assert_eq!(span.r#type, "web");
let expected = HashMap::from([
("endpoint".to_string(), "/dev/user/42/id/50".to_string()),
(
Expand All @@ -402,7 +414,6 @@ mod tests {
("http.source_ip".to_string(), "76.115.124.192".to_string()),
("http.user_agent".to_string(), "curl/8.1.2".to_string()),
("http.route".to_string(), "/user/{id}".to_string()),
("operation_name".to_string(), "aws.apigateway".to_string()),
(
"request_id".to_string(),
"e16399f7-e984-463a-9931-745ba021a27f".to_string(),
Expand Down Expand Up @@ -454,6 +465,18 @@ mod tests {
);
}

#[test]
fn test_get_dd_resource_key() {
let json = read_json_file("api_gateway_rest_event.json");
let payload = serde_json::from_str(&json).expect("Failed to deserialize into Value");
let event =
APIGatewayRestEvent::new(payload).expect("Failed to deserialize APIGatewayRestEvent");
assert_eq!(
event.get_dd_resource_key("us-east-1"),
Some("arn:aws:apigateway:us-east-1::/restapis/id".to_string())
);
}

#[test]
fn test_resolve_service_name_with_representation_enabled() {
let json = read_json_file("api_gateway_rest_event.json");
Expand Down
Loading
Loading