44from __future__ import annotations
55
66import logging
7- from typing import TYPE_CHECKING , Optional , Tuple
7+ from collections .abc import Mapping
8+ from typing import TYPE_CHECKING , Any , Optional , Tuple , cast
89
910from celery import registry # pylint: disable=no-name-in-module
1011from celery .app .task import Task
1819if TYPE_CHECKING :
1920 from contextlib import AbstractContextManager
2021
22+ ContextTuple = Tuple [Span , "AbstractContextManager[Span]" , Optional [object ]]
23+ ContextDict = dict [tuple [str , bool ], ContextTuple ]
24+
2125logger = logging .getLogger (__name__ )
2226
2327# Celery Context key
4852
4953
5054# pylint:disable=too-many-branches
51- def set_attributes_from_context (span , context ):
55+ def set_attributes_from_context (
56+ span : Span ,
57+ context : Mapping [str , Any ],
58+ ) -> None :
5259 """Helper to extract meta values from a Celery Context"""
5360 if not span .is_recording ():
5461 return
@@ -144,7 +151,7 @@ def attach_context(
144151 if task is None :
145152 return
146153
147- ctx_dict = getattr (task , CTX_KEY , None )
154+ ctx_dict = cast ( Optional [ ContextDict ], getattr (task , CTX_KEY , None ) )
148155
149156 if ctx_dict is None :
150157 ctx_dict = {}
@@ -153,12 +160,17 @@ def attach_context(
153160 ctx_dict [(task_id , is_publish )] = (span , activation , token )
154161
155162
156- def detach_context (task , task_id , is_publish = False ) -> None :
163+ def detach_context (
164+ task : Optional [Task ], task_id : str , is_publish : bool = False
165+ ) -> None :
157166 """Helper to remove `Span`, `ContextManager` and context token in a
158167 Celery task when it's propagated.
159168 This function handles tasks where no values are attached to the `Task`.
160169 """
161- span_dict = getattr (task , CTX_KEY , None )
170+ if task is None :
171+ return
172+
173+ span_dict = cast (Optional [ContextDict ], getattr (task , CTX_KEY , None ))
162174 if span_dict is None :
163175 return
164176
@@ -167,27 +179,30 @@ def detach_context(task, task_id, is_publish=False) -> None:
167179
168180
169181def retrieve_context (
170- task , task_id , is_publish = False
171- ) -> Optional [Tuple [ Span , AbstractContextManager [ Span ], Optional [ object ]] ]:
182+ task : Optional [ Task ] , task_id : str , is_publish : bool = False
183+ ) -> Optional [ContextTuple ]:
172184 """Helper to retrieve an active `Span`, `ContextManager` and context token
173185 stored in a `Task` instance
174186 """
175- span_dict = getattr (task , CTX_KEY , None )
187+ if task is None :
188+ return None
189+
190+ span_dict = cast (Optional [ContextDict ], getattr (task , CTX_KEY , None ))
176191 if span_dict is None :
177192 return None
178193
179194 # See note in `attach_context` for key info
180195 return span_dict .get ((task_id , is_publish ), None )
181196
182197
183- def retrieve_task (kwargs ) :
198+ def retrieve_task (kwargs : Mapping [ str , Any ]) -> Optional [ Task ] :
184199 task = kwargs .get ("task" )
185200 if task is None :
186201 logger .debug ("Unable to retrieve task from signal arguments" )
187- return task
202+ return cast ( Optional [ Task ], task )
188203
189204
190- def retrieve_task_from_sender (kwargs ) :
205+ def retrieve_task_from_sender (kwargs : Mapping [ str , Any ]) -> Optional [ Task ] :
191206 sender = kwargs .get ("sender" )
192207 if sender is None :
193208 logger .debug ("Unable to retrieve the sender from signal arguments" )
@@ -199,30 +214,31 @@ def retrieve_task_from_sender(kwargs):
199214 if sender is None :
200215 logger .debug ("Unable to retrieve the task from sender=%s" , sender )
201216
202- return sender
217+ return cast ( Optional [ Task ], sender )
203218
204219
205- def retrieve_task_id (kwargs ) :
220+ def retrieve_task_id (kwargs : Mapping [ str , Any ]) -> Optional [ str ] :
206221 task_id = kwargs .get ("task_id" )
207222 if task_id is None :
208223 logger .debug ("Unable to retrieve task_id from signal arguments" )
209- return task_id
224+ return cast ( Optional [ str ], task_id )
210225
211226
212- def retrieve_task_id_from_request (kwargs ) :
227+ def retrieve_task_id_from_request (kwargs : Mapping [ str , Any ]) -> Optional [ str ] :
213228 # retry signal does not include task_id as argument so use request argument
214229 request = kwargs .get ("request" )
215230 if request is None :
216231 logger .debug ("Unable to retrieve the request from signal arguments" )
232+ return None
217233
218- task_id = getattr (request , "id" )
234+ task_id = cast ( Optional [ str ], getattr (request , "id" , None ) )
219235 if task_id is None :
220236 logger .debug ("Unable to retrieve the task_id from the request" )
221237
222238 return task_id
223239
224240
225- def retrieve_task_id_from_message (kwargs ) :
241+ def retrieve_task_id_from_message (kwargs : Mapping [ str , Any ]) -> Optional [ str ] :
226242 """Helper to retrieve the `Task` identifier from the message `body`.
227243 This helper supports Protocol Version 1 and 2. The Protocol is well
228244 detailed in the official documentation:
@@ -232,12 +248,14 @@ def retrieve_task_id_from_message(kwargs):
232248 body = kwargs .get ("body" )
233249 if headers is not None and len (headers ) > 0 :
234250 # Protocol Version 2 (default from Celery 4.0)
235- return headers .get ("id" )
251+ return cast ( Optional [ str ], headers .get ("id" ) )
236252 # Protocol Version 1
237- return body .get ("id" )
253+ if body is None :
254+ return None
255+ return cast (Optional [str ], body .get ("id" ))
238256
239257
240- def retrieve_reason (kwargs ) :
258+ def retrieve_reason (kwargs : Mapping [ str , Any ]) -> Optional [ object ] :
241259 reason = kwargs .get ("reason" )
242260 if not reason :
243261 logger .debug ("Unable to retrieve the retry reason" )
0 commit comments