-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathplugin.py
More file actions
499 lines (430 loc) · 20.5 KB
/
Copy pathplugin.py
File metadata and controls
499 lines (430 loc) · 20.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
"""OpenTelemetry instrumentation plugin for AWS Durable Execution SDK."""
from __future__ import annotations
import datetime
import logging
import threading
from typing import TYPE_CHECKING, Any
from aws_durable_execution_sdk_python.lambda_service import OperationType
from aws_durable_execution_sdk_python.plugin import (
DurableInstrumentationPlugin,
InvocationEndInfo,
InvocationStartInfo,
OperationEndInfo,
OperationStartInfo,
UserFunctionEndInfo,
UserFunctionOutcome,
UserFunctionStartInfo,
)
from opentelemetry import context, trace
from opentelemetry.context import Context
from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider
from opentelemetry.sdk.trace.sampling import TraceIdRatioBased
from opentelemetry.trace import (
Link,
Span,
SpanContext,
StatusCode,
TraceFlags,
Tracer,
)
from aws_durable_execution_sdk_python_otel.context_extractors import (
ContextExtractor,
xray_context_extractor,
)
from aws_durable_execution_sdk_python_otel.deterministic_id_generator import (
DeterministicIdGenerator,
operation_id_to_span_id,
)
from aws_durable_execution_sdk_python_otel.logger import OtelEnrichedLogger
if TYPE_CHECKING:
from aws_durable_execution_sdk_python.types import LoggerInterface
logger = logging.getLogger(__name__)
def _to_otel_timestamp(dt: datetime.datetime | None) -> int | None:
"""Convert a datetime to OTel timestamp (nanoseconds since epoch), or None."""
if dt is None:
dt = datetime.datetime.now(datetime.UTC)
return int(dt.timestamp() * 1_000_000_000)
class DurableExecutionOtelPlugin(DurableInstrumentationPlugin):
    """OpenTelemetry instrumentation plugin for durable executions.

    The plugin creates spans for Lambda invocations, durable operations, and
    user-function attempts. Trace IDs are derived from the durable execution ARN
    and execution start time so each replay or resumed invocation contributes to
    the same trace.

    Operation IDs are converted into deterministic span IDs. The first observed
    span for an operation uses that deterministic ID; later continuation spans
    use newly generated span IDs and link back to the deterministic span ID so
    trace viewers can relate retries and replay-created terminal spans to the
    original logical operation.

    Args:
        trace_provider: OpenTelemetry tracer provider used to create spans.
        context_extractor: Optional extractor for upstream context. Defaults to
            AWS X-Ray header extraction.
        sampling_rate: Ratio used by ``TraceIdRatioBased`` sampling.
        instrument_name: Instrumentation scope name registered with the tracer.
        enrich_logger: When true (default), ``wrap_logger`` wraps the execution
            logger in an ``OtelEnrichedLogger``.
    """

    DEFAULT_INSTRUMENT_NAME = "aws-durable-execution-sdk-python"

    # Operation types whose spans are driven by the user-function callbacks
    # rather than by the operation start/end callbacks.
    _USER_FUNCTION_OPERATION_TYPES = (OperationType.CONTEXT, OperationType.STEP)

    def __init__(
        self,
        trace_provider: SdkTracerProvider,
        context_extractor: ContextExtractor | None = None,
        sampling_rate: float = 1.0,
        instrument_name: str = DEFAULT_INSTRUMENT_NAME,
        enrich_logger: bool = True,
    ) -> None:
        """Initialize the plugin with an OpenTelemetry tracer provider.

        The provided tracer provider is configured with this plugin's
        deterministic ID generator and sampling strategy so spans for a durable
        execution share stable trace and logical operation identifiers.
        """
        self._enrich_logger = enrich_logger
        self._context_extractor: ContextExtractor = (
            context_extractor or xray_context_extractor
        )
        self._provider = trace_provider
        # A ProxyTracerProvider (the API default from trace.get_tracer_provider()
        # before an SDK provider is configured) has no id_generator; fall back to
        # None so DeterministicIdGenerator uses its own default generator.
        self._id_generator: DeterministicIdGenerator = DeterministicIdGenerator(
            fallback_id_generator=getattr(self._provider, "id_generator", None)
        )
        # Both settings are applied before get_tracer() below so the tracer is
        # created from the reconfigured provider.
        self._provider.id_generator = self._id_generator
        self._provider.sampler = TraceIdRatioBased(sampling_rate)
        self._tracer: Tracer = self._provider.get_tracer(instrument_name)

        # Per-invocation state, reset in on_invocation_end.
        self._execution_arn = ""
        self._extracted_context: Context | None = None
        # Maps operation ID (None for the invocation/root span) to its active span.
        self._operation_spans: dict[str | None, Span] = {}
        self._operation_spans_lock = threading.RLock()

    def wrap_logger(self, logger: LoggerInterface) -> LoggerInterface | None:
        """Wrap the execution logger to inject OTel trace context.

        When enrich_logger is enabled (default), returns an OtelEnrichedLogger
        that adds trace_id, span_id, and trace_sampled to every log message.
        Idempotent: returns None if the logger is already an OtelEnrichedLogger.

        Args:
            logger: The current logger interface from the execution context.

        Returns:
            An OtelEnrichedLogger wrapping the input, or None if disabled or
            already wrapped.
        """
        if not self._enrich_logger or isinstance(logger, OtelEnrichedLogger):
            return None
        return OtelEnrichedLogger(inner=logger, plugin=self)

    def _set_span(self, operation_id: str | None, span: Span) -> None:
        """Register the active span for an operation ID."""
        with self._operation_spans_lock:
            self._operation_spans[operation_id] = span

    def _delete_span(self, operation_id: str | None) -> None:
        """Remove the active span for an operation ID if one is stored."""
        with self._operation_spans_lock:
            self._operation_spans.pop(operation_id, None)

    def _get_span(self, operation_id: str | None) -> Span | None:
        """Return the active span for an operation ID, if present."""
        with self._operation_spans_lock:
            return self._operation_spans.get(operation_id)

    def get_current_span_context(self) -> SpanContext | None:
        """Return the span context to use for log correlation.

        Resolution order:

        1. The span attached to the OTel context of the calling thread. Inside a
           step or child context this is the active operation span (attached in
           on_user_function_start); after an operation's user code finishes it
           is the enclosing operation span, or the invocation span for a
           top-level operation (attached in on_user_function_end).
        2. The invocation span from the plugin registry. This covers threads
           whose OTel context holds no valid span, e.g. top-level handler code
           before any user-function callback has attached a span there.

        Returns:
            A valid SpanContext, or None if no span is active.
        """
        span_context = trace.get_current_span().get_span_context()
        if span_context and span_context.is_valid:
            return span_context
        invocation_span = self._get_span(None)
        if invocation_span:
            invocation_context = invocation_span.get_span_context()
            if invocation_context and invocation_context.is_valid:
                return invocation_context
        return None

    # ------------------------------------------------------------------
    # Context resolution
    # ------------------------------------------------------------------

    def _resolve_parent_span(self, parent_id: str | None = None) -> Span:
        """Resolve the active parent span for a durable operation.

        ``parent_id`` is ``None`` for root-level durable operations beneath the
        invocation span. For child operations, the parent operation must already
        have an active span in the current invocation.

        Raises:
            ValueError: If the requested parent span is not active.
        """
        existing_span = self._get_span(parent_id)
        if existing_span is not None:
            return existing_span
        raise ValueError("No parent span found")

    def _start_span(
        self,
        operation_id: str | None,
        name: str,
        attributes: dict[str, str | int],
        start_time: datetime.datetime | None = None,
        parent_span: Span | None = None,
        existed: bool = False,
    ) -> Span:
        """Start and store a span for an invocation or durable operation.

        Args:
            operation_id: Durable operation ID. ``None`` is used for the root
                invocation span.
            name: Span display name.
            attributes: Span attributes.
            start_time: Optional durable start timestamp; ``None`` means now.
            parent_span: Active parent span. When omitted, the extracted
                upstream context is used as the parent.
            existed: Whether the logical operation already had a previous span.
                Continuation spans link back to the deterministic span ID for
                the operation while using a fresh generated span ID.

        Returns:
            The started OpenTelemetry span.

        Raises:
            ValueError: If ``existed`` is true but ``operation_id`` is empty.
        """
        logger.info(
            "starting a span: operation_id=%s, name=%s, parent_span=%s",
            operation_id,
            name,
            parent_span,
        )
        # set_next_span_id() configures shared ID-generator state that the
        # following start_span() presumably consumes, so both happen under the
        # same lock as the registry update.
        with self._operation_spans_lock:
            if existed:
                if not operation_id:
                    raise ValueError("operation id is required")
                deterministic_span_id = operation_id_to_span_id(operation_id)
                links = [
                    Link(
                        context=SpanContext(
                            trace_id=self._id_generator.generate_trace_id(),
                            span_id=deterministic_span_id,
                            is_remote=False,
                            trace_flags=TraceFlags(TraceFlags.SAMPLED),
                        )
                    )
                ]
                # Continuation span: let the generator pick a fresh span ID.
                self._id_generator.set_next_span_id(None)
            else:
                links = []
                self._id_generator.set_next_span_id(
                    operation_id_to_span_id(operation_id) if operation_id else None
                )

            if parent_span is None:
                # Root span: parented directly on the extracted upstream context.
                parent_context = self._extracted_context
            else:
                parent_context = trace.set_span_in_context(
                    parent_span, self._extracted_context
                )

            span = self._tracer.start_span(
                name=name,
                attributes=attributes,
                start_time=_to_otel_timestamp(start_time),
                context=parent_context,
                links=links,
            )
            self._operation_spans[operation_id] = span
        logger.info("started a span: %s", span)
        return span

    def _end_span(
        self, operation_id: str | None, end_timestamp: datetime.datetime | None = None
    ) -> None:
        """End and unregister the active span for an operation ID.

        Does nothing if no span is registered for ``operation_id``.

        Args:
            operation_id: Durable operation ID, or ``None`` for the invocation
                span.
            end_timestamp: Optional durable end timestamp to use as the span end
                time. When omitted, OpenTelemetry uses the current time.
        """
        logger.info("ending a span for operation: %s", operation_id)
        with self._operation_spans_lock:
            span = self._operation_spans.pop(operation_id, None)
        if span is None:
            return
        # The span is not going to be populated if it has the same end_time and
        # start_time; callers use _nonzero_end_time to avoid that.
        end_time = (
            _to_otel_timestamp(end_timestamp) if end_timestamp is not None else None
        )
        span.end(end_time=end_time)
        logger.info("ended otel span: %s", span)

    @staticmethod
    def _nonzero_end_time(
        start_time: datetime.datetime | None, end_time: datetime.datetime | None
    ) -> datetime.datetime | None:
        """Return ``end_time``, moved 1µs later when it equals ``start_time``.

        Keeps operations whose durable start and end timestamps coincide from
        producing zero-length spans. ``None`` is returned unchanged.
        """
        if end_time is not None and end_time == start_time:
            return end_time + datetime.timedelta(microseconds=1)
        return end_time

    @staticmethod
    def _mark_failed(span: Span, error: Any) -> None:
        """Set ERROR status on ``span`` and record an exception for ``error``.

        ``error`` is the callback payload's error object (read via its
        ``message`` and ``type`` attributes) or a falsy value when none was
        provided; the exception text falls back to "Unknown error".
        """
        message = error.message if error else None
        error_type = error.type if error else None
        span.set_status(StatusCode.ERROR, message or "")
        span.record_exception(Exception(message or error_type or "Unknown error"))

    # ------------------------------------------------------------------
    # Plugin lifecycle callbacks
    # ------------------------------------------------------------------

    def on_invocation_start(self, info: InvocationStartInfo) -> None:
        """Called at the start of each invocation. Creates the invocation span."""
        logger.info("Invocation started: %s", info)
        self._execution_arn = info.execution_arn or ""
        self._extracted_context = self._context_extractor(info)
        # Trace ID is derived from the execution ARN and start time so every
        # invocation of the same execution lands in one trace.
        self._id_generator.set_trace_id(self._execution_arn, info.start_time)
        self._start_span(
            operation_id=None,
            name="invocation",
            attributes=self._extract_attributes(info),
        )

    def on_invocation_end(self, info: InvocationEndInfo) -> None:
        """Called at the end of each invocation. Ends all spans and flushes."""
        logger.info("Invocation ended: %s", info)
        end_time = info.end_time

        # End every operation span still open in this invocation first, then
        # the invocation span (key None). `is not None` rather than truthiness
        # so a span registered under an empty operation ID is not dropped
        # without being ended.
        with self._operation_spans_lock:
            pending_ids = [op_id for op_id in self._operation_spans if op_id is not None]
        for op_id in pending_ids:
            self._end_span(op_id, end_time)
        self._end_span(None, end_time)

        # Clear all per-invocation state to prevent leaks across warm Lambda reuses.
        self._execution_arn = ""
        self._extracted_context = None
        with self._operation_spans_lock:
            self._operation_spans.clear()

        # Flush before Lambda freeze.
        if hasattr(self._provider, "force_flush"):
            self._provider.force_flush()

    def on_operation_start(self, info: OperationStartInfo) -> None:
        """Called when an operation begins. Creates a span for the operation.

        CONTEXT and STEP operations are skipped here; their spans are created
        by ``on_user_function_start``.

        Raises:
            ValueError: If the operation's parent span is not active.
        """
        logger.info("Operation started: %s", info)
        if info.operation_type in self._USER_FUNCTION_OPERATION_TYPES:
            return
        parent_span = self._resolve_parent_span(info.parent_id)
        self._start_span(
            operation_id=info.operation_id,
            name=info.name or info.operation_id,
            attributes=self._extract_attributes(info),
            start_time=info.start_time,
            parent_span=parent_span,
        )

    def on_operation_end(self, info: OperationEndInfo) -> None:
        """Called when an operation reaches a terminal durable status.

        Non-user-function operations are started by ``on_operation_start``. If
        an operation end is observed without a matching in-memory span, this
        invocation is completing an operation that began earlier, so a short
        continuation span is created and linked to the deterministic logical
        operation span before being ended.

        Raises:
            ValueError: If a continuation span is needed but the operation's
                parent span is not active.
        """
        logger.info("Operation ended: %s", info)
        if info.operation_type in self._USER_FUNCTION_OPERATION_TYPES:
            # Tracked by on_user_function_end.
            return

        span = self._get_span(info.operation_id)
        if not span:
            # Not started in this invocation: create a continuation span that
            # links back to the operation's deterministic span.
            # NOTE(review): this span starts at the current time while its end
            # comes from info.end_time below; if info.end_time predates this
            # invocation the span ends before it starts. Confirm the intended
            # timestamps.
            parent_span = self._resolve_parent_span(info.parent_id)
            span = self._start_span(
                operation_id=info.operation_id,
                name=info.name or info.operation_id,
                attributes=self._extract_attributes(info),
                start_time=datetime.datetime.now(datetime.UTC),
                parent_span=parent_span,
                existed=True,
            )

        if info.error:
            self._mark_failed(span, info.error)
        else:
            span.set_status(StatusCode.OK)

        self._end_span(
            info.operation_id,
            self._nonzero_end_time(info.start_time, info.end_time),
        )

    def on_user_function_start(self, info: UserFunctionStartInfo) -> None:
        """Called when a context or step operation starts user code.

        This callback runs inside the thread that executes user code so the
        started span can be attached to the OpenTelemetry context for any
        instrumentation used by that code. Attempts after the first are emitted
        as continuation spans linked to the logical operation span.

        Args:
            info: Information about the operation attempt.

        Raises:
            RuntimeError: If called for an operation type other than CONTEXT
                or STEP.
            ValueError: If the operation's parent span is not active.
        """
        logger.info("User function started: %s", info)
        if info.operation_type not in self._USER_FUNCTION_OPERATION_TYPES:
            raise RuntimeError(
                "on_user_function_start should only be called for CONTEXT and STEP operations"
            )
        parent_span = self._resolve_parent_span(info.parent_id)
        span = self._start_span(
            operation_id=info.operation_id,
            name=info.name or info.operation_id,
            attributes=self._extract_attributes(info),
            start_time=info.start_time,
            parent_span=parent_span,
            existed=info.attempt != 1,
        )
        # NOTE(review): the token returned by context.attach() is discarded, so
        # this context is never detached; confirm that is intended for threads
        # reused across invocations.
        context.attach(trace.set_span_in_context(span, self._extracted_context))

    def on_user_function_end(self, info: UserFunctionEndInfo) -> None:
        """Called when a context or step operation finishes user code.

        This callback records the final attempt status, captures exceptions for
        failed attempts, and ends the span that was attached in
        ``on_user_function_start``.

        Args:
            info: Information about the operation attempt.

        Raises:
            RuntimeError: If called for an operation type other than CONTEXT or
                STEP, or without a matching ``on_user_function_start``.
        """
        logger.info("User function ended: %s", info)
        if info.operation_type not in self._USER_FUNCTION_OPERATION_TYPES:
            raise RuntimeError(
                "on_user_function_end should only be called for CONTEXT and STEP operations"
            )
        span = self._get_span(info.operation_id)
        if not span:
            raise RuntimeError(
                "on_user_function_end called without matching on_user_function_start"
            )

        # Includes durable.attempt.outcome, which is how a PENDING outcome is
        # surfaced: the status itself stays at the default UNSET.
        span.set_attributes(self._extract_attributes(info))
        if info.outcome is UserFunctionOutcome.FAILED:
            self._mark_failed(span, info.error)
        elif info.outcome is UserFunctionOutcome.SUCCEEDED:
            span.set_status(StatusCode.OK)

        self._end_span(
            info.operation_id,
            self._nonzero_end_time(info.start_time, info.end_time),
        )

        # Restore the enclosing operation span as current so code that runs
        # after this operation (e.g. between steps in a child context)
        # correlates to its enclosing operation, not the operation that just
        # ended. For a top-level operation (parent_id is None) this is the
        # invocation span; for a nested operation it is the parent context span.
        parent_span = self._get_span(info.parent_id) or self._get_span(None)
        if parent_span:
            context.attach(
                trace.set_span_in_context(parent_span, self._extracted_context)
            )

    def _extract_attributes(self, info: Any) -> dict[str, str | int]:
        """Extract durable execution fields as OpenTelemetry span attributes.

        Each optional field is included only when ``info`` has it and it is not
        None.

        Args:
            info: Invocation, operation, or user-function callback payload.

        Returns:
            A dictionary of durable execution attributes suitable for a span.
        """
        attributes: dict[str, str | int] = {
            "durable.execution.arn": self._execution_arn,
        }
        operation_id = getattr(info, "operation_id", None)
        if operation_id is not None:
            attributes["durable.operation.id"] = operation_id
        operation_type = getattr(info, "operation_type", None)
        if operation_type is not None:
            attributes["durable.operation.type"] = operation_type.value
        name = getattr(info, "name", None)
        if name is not None:
            attributes["durable.operation.name"] = name
        attempt = getattr(info, "attempt", None)
        if attempt is not None:
            attributes["durable.attempt.number"] = attempt
        outcome = getattr(info, "outcome", None)
        if outcome is not None:
            attributes["durable.attempt.outcome"] = outcome.value
        return attributes