@@ -48,13 +48,16 @@ def add(x, y):
4848---
4949"""
5050
51+ from __future__ import annotations
52+
5153import logging
54+ from collections .abc import Collection , Iterable
5255from timeit import default_timer
53- from typing import Collection , Iterable
5456
5557from billiard import VERSION
5658from billiard .einfo import ExceptionInfo
5759from celery import signals # pylint: disable=no-name-in-module
60+ from celery .worker .request import Request # pylint: disable=no-name-in-module
5861
5962from opentelemetry import context as context_api
6063from opentelemetry import trace
@@ -88,8 +91,8 @@ def add(x, y):
8891_TASK_NAME_KEY = "celery.task_name"
8992
9093
91- class CeleryGetter (Getter ):
92- def get (self , carrier , key ) :
94+ class CeleryGetter (Getter [ Request ] ):
95+ def get (self , carrier : Request , key : str ) -> list [ str ] | None :
9396 value = getattr (carrier , key , None )
9497 if value is None :
9598 return None
@@ -98,25 +101,25 @@ def get(self, carrier, key):
98101 # of ints). The TextMapPropagator contract requires string
99102 # values, so coerce anything that isn't already a string.
100103 if isinstance (value , str ):
101- value = (value ,)
102- elif isinstance (value , Iterable ):
103- value = tuple (
104- str (v ) if not isinstance (v , str ) else v for v in value
105- )
106- else :
107- value = (str (value ),)
108- return value
104+ return [value ]
105+ if isinstance (value , Iterable ):
106+ return [str (v ) if not isinstance (v , str ) else v for v in value ]
107+ return [str (value )]
109108
110- def keys (self , carrier ) :
109+ def keys (self , carrier : Request ) -> list [ str ] :
111110 return []
112111
113112
114113celery_getter = CeleryGetter ()
115114
116115
117116class CeleryInstrumentor (BaseInstrumentor ):
118- metrics = None
119- task_id_to_start_time = {}
117+ def __init__ (self ):
118+ super ().__init__ ()
119+ if not hasattr (self , "metrics" ):
120+ self .metrics = None
121+ if not hasattr (self , "task_id_to_start_time" ):
122+ self .task_id_to_start_time = {}
120123
121124 def instrumentation_dependencies (self ) -> Collection [str ]:
122125 return _instruments
@@ -139,6 +142,7 @@ def _instrument(self, **kwargs):
139142 schema_url = "https://opentelemetry.io/schemas/1.11.0" ,
140143 )
141144
145+ self .task_id_to_start_time = {}
142146 self .create_celery_metrics (meter )
143147
144148 signals .task_prerun .connect (self ._trace_prerun , weak = False )
@@ -159,6 +163,7 @@ def _uninstrument(self, **kwargs):
159163 signals .after_task_publish .disconnect (self ._trace_after_publish )
160164 signals .task_failure .disconnect (self ._trace_failure )
161165 signals .task_retry .disconnect (self ._trace_retry )
166+ self .task_id_to_start_time = {}
162167
163168 def _trace_prerun (self , * args , ** kwargs ):
164169 task = utils .retrieve_task (kwargs )
@@ -213,6 +218,7 @@ def _trace_postrun(self, *args, **kwargs):
213218 self .update_task_duration_time (task_id )
214219 labels = {"task" : task .name , "worker" : task .request .hostname }
215220 self ._record_histograms (task_id , labels )
221+ self .task_id_to_start_time .pop (task_id , None )
216222 # if the process sending the task is not instrumented
217223 # there's no incoming context and no token to detach
218224 if token is not None :
0 commit comments