44from __future__ import annotations
55
66import logging
7- from typing import TYPE_CHECKING , Optional , Tuple
7+ from collections .abc import Mapping
8+ from typing import TYPE_CHECKING , Any , Optional , Protocol , cast
89
910from celery import registry # pylint: disable=no-name-in-module
1011from celery .app .task import Task
1819if TYPE_CHECKING :
1920 from contextlib import AbstractContextManager
2021
22+ ContextKey = tuple [str , bool ]
23+ ContextTuple = tuple [Span , "AbstractContextManager[Span]" , object | None ]
24+ ContextDict = dict [ContextKey , ContextTuple ]
25+
26+
27+ class ContextCarrier (Protocol ):
28+ def get (self , key : str , default : Any = None ) -> Any : ...
29+
30+
2131logger = logging .getLogger (__name__ )
2232
2333# Celery Context key
4858
4959
5060# pylint:disable=too-many-branches
51- def set_attributes_from_context (span , context ):
61+ def set_attributes_from_context (
62+ span : Span ,
63+ context : ContextCarrier ,
64+ ) -> None :
5265 """Helper to extract meta values from a Celery Context"""
5366 if not span .is_recording ():
5467 return
@@ -144,7 +157,7 @@ def attach_context(
144157 if task is None :
145158 return
146159
147- ctx_dict = getattr (task , CTX_KEY , None )
160+ ctx_dict = cast ( Optional [ ContextDict ], getattr (task , CTX_KEY , None ) )
148161
149162 if ctx_dict is None :
150163 ctx_dict = {}
@@ -153,12 +166,17 @@ def attach_context(
153166 ctx_dict [(task_id , is_publish )] = (span , activation , token )
154167
155168
156- def detach_context (task , task_id , is_publish = False ) -> None :
169+ def detach_context (
170+ task : Optional [Task ], task_id : str , is_publish : bool = False
171+ ) -> None :
157172 """Helper to remove `Span`, `ContextManager` and context token in a
158173 Celery task when it's propagated.
159174 This function handles tasks where no values are attached to the `Task`.
160175 """
161- span_dict = getattr (task , CTX_KEY , None )
176+ if task is None :
177+ return
178+
179+ span_dict = cast (Optional [ContextDict ], getattr (task , CTX_KEY , None ))
162180 if span_dict is None :
163181 return
164182
@@ -167,27 +185,30 @@ def detach_context(task, task_id, is_publish=False) -> None:
167185
168186
169187def retrieve_context (
170- task , task_id , is_publish = False
171- ) -> Optional [Tuple [ Span , AbstractContextManager [ Span ], Optional [ object ]] ]:
188+ task : Optional [ Task ] , task_id : str , is_publish : bool = False
189+ ) -> Optional [ContextTuple ]:
172190 """Helper to retrieve an active `Span`, `ContextManager` and context token
173191 stored in a `Task` instance
174192 """
175- span_dict = getattr (task , CTX_KEY , None )
193+ if task is None :
194+ return None
195+
196+ span_dict = cast (Optional [ContextDict ], getattr (task , CTX_KEY , None ))
176197 if span_dict is None :
177198 return None
178199
179200 # See note in `attach_context` for key info
180201 return span_dict .get ((task_id , is_publish ), None )
181202
182203
183- def retrieve_task (kwargs ) :
204+ def retrieve_task (kwargs : Mapping [ str , Any ]) -> Optional [ Task ] :
184205 task = kwargs .get ("task" )
185206 if task is None :
186207 logger .debug ("Unable to retrieve task from signal arguments" )
187- return task
208+ return cast ( Optional [ Task ], task )
188209
189210
190- def retrieve_task_from_sender (kwargs ) :
211+ def retrieve_task_from_sender (kwargs : Mapping [ str , Any ]) -> Optional [ Task ] :
191212 sender = kwargs .get ("sender" )
192213 if sender is None :
193214 logger .debug ("Unable to retrieve the sender from signal arguments" )
@@ -199,30 +220,31 @@ def retrieve_task_from_sender(kwargs):
199220 if sender is None :
200221 logger .debug ("Unable to retrieve the task from sender=%s" , sender )
201222
202- return sender
223+ return cast ( Optional [ Task ], sender )
203224
204225
205- def retrieve_task_id (kwargs ) :
226+ def retrieve_task_id (kwargs : Mapping [ str , Any ]) -> Optional [ str ] :
206227 task_id = kwargs .get ("task_id" )
207228 if task_id is None :
208229 logger .debug ("Unable to retrieve task_id from signal arguments" )
209- return task_id
230+ return cast ( Optional [ str ], task_id )
210231
211232
212- def retrieve_task_id_from_request (kwargs ) :
233+ def retrieve_task_id_from_request (kwargs : Mapping [ str , Any ]) -> Optional [ str ] :
213234 # retry signal does not include task_id as argument so use request argument
214235 request = kwargs .get ("request" )
215236 if request is None :
216237 logger .debug ("Unable to retrieve the request from signal arguments" )
238+ return None
217239
218- task_id = getattr (request , "id" )
240+ task_id = cast ( Optional [ str ], getattr (request , "id" , None ) )
219241 if task_id is None :
220242 logger .debug ("Unable to retrieve the task_id from the request" )
221243
222244 return task_id
223245
224246
225- def retrieve_task_id_from_message (kwargs ) :
247+ def retrieve_task_id_from_message (kwargs : Mapping [ str , Any ]) -> Optional [ str ] :
226248 """Helper to retrieve the `Task` identifier from the message `body`.
227249 This helper supports Protocol Version 1 and 2. The Protocol is well
228250 detailed in the official documentation:
@@ -232,12 +254,14 @@ def retrieve_task_id_from_message(kwargs):
232254 body = kwargs .get ("body" )
233255 if headers is not None and len (headers ) > 0 :
234256 # Protocol Version 2 (default from Celery 4.0)
235- return headers .get ("id" )
257+ return cast ( Optional [ str ], headers .get ("id" ) )
236258 # Protocol Version 1
237- return body .get ("id" )
259+ if body is None :
260+ return None
261+ return cast (Optional [str ], body .get ("id" ))
238262
239263
240- def retrieve_reason (kwargs ) :
264+ def retrieve_reason (kwargs : Mapping [ str , Any ]) -> Optional [ object ] :
241265 reason = kwargs .get ("reason" )
242266 if not reason :
243267 logger .debug ("Unable to retrieve the retry reason" )
0 commit comments