Skip to content

Commit 74cebf1

Browse files
Alex Wangwangyb-A
authored andcommitted
fix: race condition for id generator
1 parent a0dc691 commit 74cebf1

2 files changed

Lines changed: 89 additions & 3 deletions

File tree

packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import contextvars
56
import hashlib
67
import os
78
import re
@@ -12,6 +13,12 @@
1213

1314
HASHED_ID_PATTERN = re.compile(r"^[0-9a-f]{16}$")
1415

16+
# Scoping the pending span ID to the execution context ensures concurrent
17+
# operations cannot consume each other's deterministic span ID.
18+
_next_span_id: contextvars.ContextVar[int | None] = contextvars.ContextVar(
19+
"next_span_id", default=None
20+
)
21+
1522

1623
def _parse_xray_root_trace_id(trace_header: str | None) -> str | None:
1724
"""Parse the Root trace ID from an X-Ray trace header string.
@@ -83,7 +90,6 @@ class DeterministicIdGenerator(RandomIdGenerator):
8390
"""
8491

8592
def __init__(self, fallback_id_generator: IdGenerator | None = None) -> None:
86-
self._next_span_id: int | None = None
8793
self._execution_trace_id: int | None = None
8894
self._fallback_id_generator = fallback_id_generator or RandomIdGenerator()
8995

@@ -92,7 +98,7 @@ def set_next_span_id(self, span_id: int | None) -> None:
9298
9399
After one span is created, it resets to random.
94100
"""
95-
self._next_span_id = span_id
101+
_next_span_id.set(span_id)
96102

97103
def set_trace_id(
98104
self, execution_arn: str, start_timestamp: datetime | None
@@ -113,5 +119,8 @@ def generate_trace_id(self) -> int:
113119

114120
def generate_span_id(self) -> int:
115121
"""Generate a 64-bit span ID."""
116-
span_id, self._next_span_id = self._next_span_id, None
122+
span_id = _next_span_id.get()
123+
# Consume once: the deterministic ID applies only to the next span
124+
# created in this context; subsequent spans fall back to random.
125+
_next_span_id.set(None)
117126
return span_id or self._fallback_id_generator.generate_span_id()

packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import asyncio
6+
import threading
57
from datetime import UTC, datetime
68

79
from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator
@@ -203,3 +205,78 @@ def test_deterministic_id_generator_prefers_next_span_id_over_fallback():
203205
assert generator.generate_span_id() == deterministic_span_id
204206
# Subsequent calls fall back to the provided generator.
205207
assert generator.generate_span_id() == int("b" * 16, 16)
208+
209+
210+
def test_pending_span_id_is_isolated_across_threads():
211+
"""Verify a span ID set in one thread is not consumed by another thread.
212+
213+
The pending span ID is stored in a ContextVar, so each worker thread has
214+
its own value. Without this isolation a concurrent operation could steal
215+
another operation's deterministic span ID, producing the wrong span ID.
216+
"""
217+
random_span_id = int("f" * 16, 16)
218+
fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=random_span_id)
219+
generator = DeterministicIdGenerator(fallback_id_generator=fallback)
220+
221+
# The main thread sets a deterministic span ID but never consumes it.
222+
main_deterministic_span_id = int("1" * 16, 16)
223+
generator.set_next_span_id(main_deterministic_span_id)
224+
225+
barrier = threading.Barrier(2)
226+
results: dict[str, int] = {}
227+
228+
def worker(name: str, span_id: int) -> None:
229+
# Each worker starts with a fresh context (default None), so it must
230+
# not see the main thread's pending span ID.
231+
barrier.wait()
232+
results[f"{name}-before-set"] = generator.generate_span_id()
233+
generator.set_next_span_id(span_id)
234+
results[f"{name}-after-set"] = generator.generate_span_id()
235+
236+
worker_a_span_id = int("2" * 16, 16)
237+
worker_b_span_id = int("3" * 16, 16)
238+
thread_a = threading.Thread(target=worker, args=("a", worker_a_span_id))
239+
thread_b = threading.Thread(target=worker, args=("b", worker_b_span_id))
240+
thread_a.start()
241+
thread_b.start()
242+
thread_a.join()
243+
thread_b.join()
244+
245+
# Workers never observed the main thread's value; they fell back to random.
246+
assert results["a-before-set"] == random_span_id
247+
assert results["b-before-set"] == random_span_id
248+
# Each worker consumed only its own deterministic span ID.
249+
assert results["a-after-set"] == worker_a_span_id
250+
assert results["b-after-set"] == worker_b_span_id
251+
# The main thread's pending span ID was untouched by the workers.
252+
assert generator.generate_span_id() == main_deterministic_span_id
253+
254+
255+
def test_pending_span_id_is_isolated_across_async_tasks():
256+
"""Verify a span ID set in one async task is not consumed by another.
257+
258+
Each asyncio task runs with its own copied context, so the pending span ID
259+
stays scoped to the task that set it even across await boundaries on the
260+
same thread.
261+
"""
262+
fallback_span_id = int("e" * 16, 16)
263+
fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=fallback_span_id)
264+
generator = DeterministicIdGenerator(fallback_id_generator=fallback)
265+
266+
task_a_span_id = int("4" * 16, 16)
267+
task_b_span_id = int("5" * 16, 16)
268+
269+
async def task(span_id: int) -> int:
270+
generator.set_next_span_id(span_id)
271+
# Yield control so the other task interleaves between set and consume.
272+
await asyncio.sleep(0)
273+
return generator.generate_span_id()
274+
275+
async def main() -> tuple[int, int]:
276+
return await asyncio.gather(task(task_a_span_id), task(task_b_span_id))
277+
278+
result_a, result_b = asyncio.run(main())
279+
280+
# Despite interleaving, each task consumed only its own deterministic ID.
281+
assert result_a == task_a_span_id
282+
assert result_b == task_b_span_id

0 commit comments

Comments
 (0)