Skip to content

Commit c74d687

Browse files
committed
fix type comments
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 652d89f commit c74d687

3 files changed

Lines changed: 37 additions & 28 deletions

File tree

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def put(self, keys: list[str], values: list[Any]) -> None:
141141

142142
return None
143143

144-
def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]):
144+
def _put_tensors_thread_worker(self, batch_keys: list[str], batch_tensors: list[Tensor]) -> None:
145145
"""Worker thread for putting batch of tensors to MooncakeStore."""
146146

147147
batch_ptrs, batch_sizes, _contiguous_tensors = self._preprocess_tensors_for_put(batch_tensors)
@@ -261,7 +261,7 @@ def _get_bytes_thread_worker(self, batch_keys: list[str], indexes: list[int]) ->
261261

262262
return results, indexes
263263

264-
def clear(self, keys: list[str], custom_backend_meta=None):
264+
def clear(self, keys: list[str], custom_backend_meta: Optional[list[Any]] = None) -> None:
265265
"""Deletes multiple keys from MooncakeStore.
266266
267267
Args:
@@ -280,10 +280,10 @@ def close(self):
280280
self._store = None
281281

282282
@staticmethod
283-
def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[Any], list[Any], list[Tensor]]:
284-
ptr_list = []
285-
size_list = []
286-
tensor_list = [] # hold reference for the contiguous tensor
283+
def _preprocess_tensors_for_put(values: list[Tensor]) -> tuple[list[int], list[int], list[Tensor]]:
284+
ptr_list: list[int] = []
285+
size_list: list[int] = []
286+
tensor_list: list[Tensor] = [] # hold reference for the contiguous tensor
287287
for t in values:
288288
# TODO: support gpu direct rdma and use different data paths.
289289
# For GPU, it's more reasonable to perform data copy since

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def supports_put(self, value: Any) -> bool:
5757
"""Check if this strategy can store the given value."""
5858

5959
@abstractmethod
60-
def put(self, keys: list[str], values: list[Any]):
60+
def put(self, keys: list[str], values: list[Any]) -> None:
6161
"""Store key-value pairs using this strategy."""
6262

6363
@abstractmethod
@@ -73,7 +73,7 @@ def supports_clear(self, strategy_tag: Any) -> bool:
7373
"""Check if this strategy owns data identified by metadata."""
7474

7575
@abstractmethod
76-
def clear(self, keys: list[str]):
76+
def clear(self, keys: list[str]) -> None:
7777
"""Delete keys from storage."""
7878

7979

@@ -131,7 +131,7 @@ def supports_put(self, value: Any) -> bool:
131131
# Only contiguous NPU tensors are supported by this adapter.
132132
return value.is_contiguous()
133133

134-
def put(self, keys: list[str], values: list[Any]):
134+
def put(self, keys: list[str], values: list[Any]) -> None:
135135
"""Store NPU tensors in batches; deletes before overwrite."""
136136
for i in range(0, len(keys), self.KEYS_LIMIT):
137137
batch_keys = keys[i : i + self.KEYS_LIMIT]
@@ -169,22 +169,22 @@ def supports_clear(self, strategy_tag: str) -> bool:
169169
"""Matches 'DsTensorClient' strategy tag."""
170170
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
171171

172-
def clear(self, keys: list[str]):
172+
def clear(self, keys: list[str]) -> None:
173173
"""Delete NPU tensor keys in batches."""
174174
for i in range(0, len(keys), self.KEYS_LIMIT):
175175
batch = keys[i : i + self.KEYS_LIMIT]
176176
# Todo(dpj): Test call clear when no (key,value) put in ds
177177
self._ds_client.delete(batch)
178178

179-
def _create_empty_npu_tensorlist(self, shapes: list, dtypes: list):
179+
def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) -> list[Tensor]:
180180
"""
181181
Create a list of empty NPU tensors with given shapes and dtypes.
182182
183183
Args:
184184
shapes (list): List of tensor shapes (e.g., [(3,), (2, 4)])
185185
dtypes (list): List of torch dtypes (e.g., [torch.float32, torch.int64])
186186
Returns:
187-
list: List of uninitialized NPU tensors
187+
list[Tensor]: List of uninitialized NPU tensors
188188
"""
189189
tensors: list[Tensor] = []
190190
for shape, dtype in zip(shapes, dtypes, strict=True):
@@ -243,7 +243,7 @@ def supports_put(self, value: Any) -> bool:
243243
"""Accepts any Python object."""
244244
return True
245245

246-
def put(self, keys: list[str], values: list[Any]):
246+
def put(self, keys: list[str], values: list[Any]) -> None:
247247
"""Store objects via zero-copy serialization in batches."""
248248
for i in range(0, len(keys), self.PUT_KEYS_LIMIT):
249249
batch_keys = keys[i : i + self.PUT_KEYS_LIMIT]
@@ -267,7 +267,7 @@ def supports_clear(self, strategy_tag: str) -> bool:
267267
"""Matches 'KVClient' strategy tag."""
268268
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
269269

270-
def clear(self, keys: list[str]):
270+
def clear(self, keys: list[str]) -> None:
271271
"""Delete keys in batches."""
272272
for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
273273
batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
@@ -433,7 +433,13 @@ def put_task(strategy, indexes):
433433
strategy_tags[original_index] = tag
434434
return strategy_tags
435435

436-
def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=None) -> list[Any]:
436+
def get(
437+
self,
438+
keys: list[str],
439+
shapes: Optional[list[Any]] = None,
440+
dtypes: Optional[list[Any]] = None,
441+
custom_backend_meta: Optional[list[str]] = None,
442+
) -> list[Any]:
437443
"""Retrieves multiple values from remote storage with expected metadata.
438444
439445
Requires shape and dtype hints to reconstruct NPU tensors correctly.
@@ -472,7 +478,7 @@ def get_task(strategy, indexes):
472478
results[original_index] = value
473479
return results
474480

475-
def clear(self, keys: list[str], custom_backend_meta=None):
481+
def clear(self, keys: list[str], custom_backend_meta: Optional[list[str]] = None) -> None:
476482
"""Deletes multiple keys from remote storage.
477483
478484
Args:
@@ -513,8 +519,8 @@ def _route_to_strategies(
513519
The order must correspond to the original keys.
514520
selector: A function that determines whether a strategy supports an item.
515521
Signature: `(strategy: StorageStrategy, item: Any) -> bool`.
516-
failback: If True, items that don't match any strategy will be ignored (not included in output).
517-
If False, a ValueError will be raised for any unmatched item.
522+
ignore_unmatched: If True, items that don't match any strategy will be ignored (not included in output).
523+
If False, a ValueError will be raised for any unmatched item.
518524
519525
Returns:
520526
A dictionary mapping each active strategy to a list of indexes in `items`

transfer_queue/storage/managers/base.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -401,19 +401,20 @@ def _generate_keys(field_names: list[str], global_indexes: list[int]) -> list[st
401401
return [pfx + sfx for sfx, pfx in itertools.product(keys_suffixes, keys_prefixes)]
402402

403403
@staticmethod
404-
def _generate_values(data: TensorDict) -> list[Tensor]:
404+
def _generate_values(data: TensorDict) -> list[Any]:
405405
"""
406-
Extract and flatten tensor values from a TensorDict in field-major order.
406+
Extract and flatten values from a TensorDict in field-major order.
407407
Values are ordered by sorted field names, then by row (sample) order within each field.
408408
This matches the key order generated by `_generate_keys`.
409409
410410
Args:
411-
data (TensorDict): Input data where keys are field names and values are tensors.
411+
data (TensorDict): Input data where keys are field names and values are tensors or any type
412+
wrapped by NonTensorStack.
412413
Returns:
413-
list[Tensor]: Flattened list of tensors, e.g.,
414-
[data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...]
414+
list[Any]: Flattened list of values, e.g.,
415+
[data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...]
415416
"""
416-
results: list[Tensor] = []
417+
results: list[Any] = []
417418
for field in sorted(data.keys()):
418419
field_data = data[field]
419420
if isinstance(field_data, Tensor) and field_data.is_nested:
@@ -457,17 +458,17 @@ def _get_executor(self) -> ThreadPoolExecutor:
457458
assert self._multi_threads_executor is not None
458459
return self._multi_threads_executor
459460

460-
def _merge_tensors_to_tensordict(self, metadata: BatchMeta, values: list[Tensor]) -> TensorDict:
461+
def _merge_tensors_to_tensordict(self, metadata: BatchMeta, values: list[Any]) -> TensorDict:
461462
"""
462463
Reconstruct a TensorDict from a list of values using metadata.
463464
The values list is assumed to be in the same order as keys generated by `_generate_keys`.
464465
According to field names and global indexes in metadata, this method can determine
465-
which dict key and which row this tensor belongs to. Then it reshapes the flat tensors list
466+
which dict key and which row this value belongs to. Then it reshapes the flat values list
466467
back into a structured TensorDict .
467468
468469
Args:
469470
metadata (BatchMeta): Metadata containing global indexes and field names.
470-
values (list[Tensor]): List of tensors in field-major order.
471+
values (list[Any]): List of values in field-major order.
471472
Returns:
472473
TensorDict: Reconstructed tensor dictionary with batch size equal to number of samples.
473474
"""
@@ -534,7 +535,9 @@ def process_field(field_idx: int):
534535
return TensorDict(merged_data, batch_size=num_samples)
535536

536537
@staticmethod
537-
def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta):
538+
def _get_shape_type_custom_backend_meta_list(
539+
metadata: BatchMeta,
540+
) -> tuple[list[torch.Size], list[torch.dtype], list[Any]]:
538541
"""
539542
Extract the expected shape, dtype, and custom_backend_meta for each field-sample pair in metadata.
540543
The order matches the key/value order: sorted by field name, then by global index.

0 commit comments

Comments
 (0)