Skip to content

Commit f2dbab6

Browse files
authored
ref: Add support for custom sampling context to span first (14) (#5628)
Custom sampling context allows folks to have arbitrary data accessible in the `traces_sampler` in order to make a sampling decision. The SDK sets custom sampling context as well in some integrations, for example, in ASGI frameworks the ASGI scope will be available. Previously, you could provide custom sampling context as an argument to the `start_span` function. In the spirit of keeping the new `start_span` API minimal, we'll be moving `custom_sampling_context` to the propagation context and providing a dedicated API function to set it. In this PR, it's a scope method (`scope.set_custom_sampling_context()`). We can (and probably should) promote it to top-level API at some point in the future.
1 parent f825898 commit f2dbab6

File tree

3 files changed

+98
-20
lines changed

3 files changed

+98
-20
lines changed

sentry_sdk/scope.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,13 @@ def get_active_propagation_context(self) -> "PropagationContext":
696696
isolation_scope._propagation_context = PropagationContext()
697697
return isolation_scope._propagation_context
698698

699+
def set_custom_sampling_context(
700+
self, custom_sampling_context: "dict[str, Any]"
701+
) -> None:
702+
self.get_active_propagation_context()._set_custom_sampling_context(
703+
custom_sampling_context
704+
)
705+
699706
def clear(self) -> None:
700707
"""Clears the entire scope."""
701708
self._level: "Optional[LogLevelStr]" = None

sentry_sdk/tracing_utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ class PropagationContext:
417417
"parent_span_id",
418418
"parent_sampled",
419419
"baggage",
420+
"custom_sampling_context",
420421
)
421422

422423
def __init__(
@@ -450,6 +451,8 @@ def __init__(
450451
if baggage is None and dynamic_sampling_context is not None:
451452
self.baggage = Baggage(dynamic_sampling_context)
452453

454+
self.custom_sampling_context: "Optional[dict[str, Any]]" = None
455+
453456
@classmethod
454457
def from_incoming_data(
455458
cls, incoming_data: "Dict[str, Any]"
@@ -537,6 +540,11 @@ def update(self, other_dict: "Dict[str, Any]") -> None:
537540
except AttributeError:
538541
pass
539542

543+
def _set_custom_sampling_context(
544+
self, custom_sampling_context: "dict[str, Any]"
545+
) -> None:
546+
self.custom_sampling_context = custom_sampling_context
547+
540548
def __repr__(self) -> str:
541549
return "<PropagationContext _trace_id={} _span_id={} parent_span_id={} parent_sampled={} baggage={}>".format(
542550
self._trace_id,
@@ -1413,13 +1421,18 @@ def _make_sampling_decision(
14131421
traces_sampler_defined = callable(client.options.get("traces_sampler"))
14141422
if traces_sampler_defined:
14151423
sampling_context = {
1416-
"name": name,
1417-
"trace_id": propagation_context.trace_id,
1418-
"parent_span_id": propagation_context.parent_span_id,
1419-
"parent_sampled": propagation_context.parent_sampled,
1420-
"attributes": dict(attributes) if attributes else {},
1424+
"span_context": {
1425+
"name": name,
1426+
"trace_id": propagation_context.trace_id,
1427+
"parent_span_id": propagation_context.parent_span_id,
1428+
"parent_sampled": propagation_context.parent_sampled,
1429+
"attributes": dict(attributes) if attributes else {},
1430+
},
14211431
}
14221432

1433+
if propagation_context.custom_sampling_context:
1434+
sampling_context.update(propagation_context.custom_sampling_context)
1435+
14231436
sample_rate = client.options["traces_sampler"](sampling_context)
14241437
else:
14251438
if propagation_context.parent_sampled is not None:

tests/tracing/test_span_streaming.py

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_span_sampled_when_created(sentry_init, capture_envelopes):
120120
# at start_span() time
121121

122122
def traces_sampler(sampling_context):
123-
assert "delayed_attribute" not in sampling_context["attributes"]
123+
assert "delayed_attribute" not in sampling_context["span_context"]["attributes"]
124124
return 1.0
125125

126126
sentry_init(
@@ -169,9 +169,11 @@ def test_start_span_attributes(sentry_init, capture_envelopes):
169169

170170
def test_start_span_attributes_in_traces_sampler(sentry_init, capture_envelopes):
171171
def traces_sampler(sampling_context):
172-
assert "attributes" in sampling_context
173-
assert "my_attribute" in sampling_context["attributes"]
174-
assert sampling_context["attributes"]["my_attribute"] == "my_value"
172+
assert "attributes" in sampling_context["span_context"]
173+
assert "my_attribute" in sampling_context["span_context"]["attributes"]
174+
assert (
175+
sampling_context["span_context"]["attributes"]["my_attribute"] == "my_value"
176+
)
175177
return 1.0
176178

177179
sentry_init(
@@ -202,16 +204,16 @@ def test_sampling_context(sentry_init, capture_envelopes):
202204
def traces_sampler(sampling_context):
203205
nonlocal received_trace_id
204206

205-
assert "trace_id" in sampling_context
206-
received_trace_id = sampling_context["trace_id"]
207+
assert "trace_id" in sampling_context["span_context"]
208+
received_trace_id = sampling_context["span_context"]["trace_id"]
207209

208-
assert "parent_span_id" in sampling_context
209-
assert sampling_context["parent_span_id"] is None
210+
assert "parent_span_id" in sampling_context["span_context"]
211+
assert sampling_context["span_context"]["parent_span_id"] is None
210212

211-
assert "parent_sampled" in sampling_context
212-
assert sampling_context["parent_sampled"] is None
213+
assert "parent_sampled" in sampling_context["span_context"]
214+
assert sampling_context["span_context"]["parent_sampled"] is None
213215

214-
assert "attributes" in sampling_context
216+
assert "attributes" in sampling_context["span_context"]
215217

216218
return 1.0
217219

@@ -233,6 +235,62 @@ def traces_sampler(sampling_context):
233235
assert len(spans) == 1
234236

235237

238+
def test_custom_sampling_context(sentry_init):
239+
class MyClass: ...
240+
241+
my_class = MyClass()
242+
243+
def traces_sampler(sampling_context):
244+
assert "class" in sampling_context
245+
assert "string" in sampling_context
246+
assert sampling_context["class"] == my_class
247+
assert sampling_context["string"] == "my string"
248+
return 1.0
249+
250+
sentry_init(
251+
traces_sampler=traces_sampler,
252+
_experiments={"trace_lifecycle": "stream"},
253+
)
254+
255+
sentry_sdk.get_current_scope().set_custom_sampling_context(
256+
{
257+
"class": my_class,
258+
"string": "my string",
259+
}
260+
)
261+
262+
with sentry_sdk.traces.start_span(name="span"):
263+
...
264+
265+
266+
def test_custom_sampling_context_update_to_context_value_persists(sentry_init):
267+
def traces_sampler(sampling_context):
268+
if sampling_context["span_context"]["attributes"]["first"] is True:
269+
assert sampling_context["custom_value"] == 1
270+
else:
271+
assert sampling_context["custom_value"] == 2
272+
return 1.0
273+
274+
sentry_init(
275+
traces_sampler=traces_sampler,
276+
_experiments={"trace_lifecycle": "stream"},
277+
)
278+
279+
sentry_sdk.traces.new_trace()
280+
281+
sentry_sdk.get_current_scope().set_custom_sampling_context({"custom_value": 1})
282+
283+
with sentry_sdk.traces.start_span(name="span", attributes={"first": True}):
284+
...
285+
286+
sentry_sdk.traces.new_trace()
287+
288+
sentry_sdk.get_current_scope().set_custom_sampling_context({"custom_value": 2})
289+
290+
with sentry_sdk.traces.start_span(name="span", attributes={"first": False}):
291+
...
292+
293+
236294
def test_span_attributes(sentry_init, capture_envelopes):
237295
sentry_init(
238296
traces_sample_rate=1.0,
@@ -305,10 +363,10 @@ class Class:
305363

306364
def test_traces_sampler_drops_span(sentry_init, capture_envelopes):
307365
def traces_sampler(sampling_context):
308-
assert "attributes" in sampling_context
309-
assert "drop" in sampling_context["attributes"]
366+
assert "attributes" in sampling_context["span_context"]
367+
assert "drop" in sampling_context["span_context"]["attributes"]
310368

311-
if sampling_context["attributes"]["drop"] is True:
369+
if sampling_context["span_context"]["attributes"]["drop"] is True:
312370
return 0.0
313371

314372
return 1.0
@@ -342,7 +400,7 @@ def test_traces_sampler_called_once_per_segment(sentry_init):
342400
def traces_sampler(sampling_context):
343401
nonlocal traces_sampler_called, span_name_in_traces_sampler
344402
traces_sampler_called += 1
345-
span_name_in_traces_sampler = sampling_context["name"]
403+
span_name_in_traces_sampler = sampling_context["span_context"]["name"]
346404
return 1.0
347405

348406
sentry_init(

0 commit comments

Comments
 (0)