Skip to content

Commit 61aadaf

Browse files
committed
fix partial pre-commit
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 3fe421d commit 61aadaf

4 files changed

Lines changed: 26 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ pretty = true
7373
ignore_missing_imports = true
7474
explicit_package_bases = true
7575
follow_imports = "skip"
76-
ignore_errors = false
76+
77+
# Blanket silence
78+
ignore_errors = true
7779

7880
# -------------------------------
7981
# tool.pytest - pytest config
@@ -83,6 +85,12 @@ filterwarnings = [
8385
"ignore:.*PyTorch API of nested tensors.*prototype.*:UserWarning",
8486
]
8587

88+
[[tool.mypy.overrides]]
89+
module = [
90+
"transfer_queue.*",
91+
]
92+
ignore_errors = false
93+
8694
# -------------------------------
8795
# tool.setuptools - Additional config
8896
# -------------------------------

transfer_queue/controller.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,17 @@ def release_indexes(self, partition_id: str, indexes_to_release: list[int]):
188188
if not partition_indexes:
189189
self.partition_to_indexes.pop(partition_id, None)
190190

191-
def get_indexes_for_partition(self, partition_id) -> set[int]:
191+
def get_indexes_for_partition(self, partition_id) -> list[int]:
192192
"""
193193
Get all global_indexes for the specified partition.
194194
195195
Args:
196196
partition_id: Partition ID
197197
198198
Returns:
199-
set: Set of global_indexes for this partition
199+
list: List of global_indexes for this partition
200200
"""
201-
return self.partition_to_indexes.get(partition_id, set()).copy()
201+
return list(self.partition_to_indexes.get(partition_id, set()).copy())
202202

203203

204204
@dataclass
@@ -216,7 +216,7 @@ class DataPartitionStatus:
216216

217217
# Production status tensor - dynamically expandable
218218
# Values: 0 = not produced, 1 = ready for consumption
219-
production_status: Optional[Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
219+
production_status: Tensor = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)
220220

221221
# Consumption status per task - task_name -> consumption_tensor
222222
# Each tensor tracks which samples have been consumed by that task
@@ -260,7 +260,7 @@ def allocated_samples_num(self) -> int:
260260

261261
# ==================== Dynamic Expansion Methods ====================
262262

263-
def ensure_samples_capacity(self, required_samples: int) -> bool:
263+
def ensure_samples_capacity(self, required_samples: int):
264264
"""
265265
Ensure the production status tensor has enough rows for the required samples.
266266
Dynamically expands if needed using unified minimum expansion size.
@@ -498,7 +498,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
498498
return partition_global_index, consumption_status
499499

500500
# ==================== Production Status Interface ====================
501-
def get_production_status_for_fields(self, field_names: list[str], mask: bool = False) -> tuple[Tensor, Tensor]:
501+
def get_production_status_for_fields(
502+
self, field_names: list[str], mask: bool = False
503+
) -> tuple[Optional[Tensor], Optional[Tensor]]:
502504
"""
503505
Check if all samples for specified fields are fully produced and ready.
504506
@@ -512,12 +514,12 @@ def get_production_status_for_fields(self, field_names: list[str], mask: bool =
512514
- Production status tensor for the specified task. 1 for ready, 0 for not ready.
513515
"""
514516
if self.production_status is None or field_names is None or len(field_names) == 0:
515-
return False
517+
return None, None
516518

517519
# Check if all requested fields are registered
518520
for field_name in field_names:
519521
if field_name not in self.field_name_mapping:
520-
return False
522+
return None, None
521523

522524
# Create column mask for requested fields
523525
col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
@@ -837,15 +839,15 @@ def list_partitions(self) -> list[str]:
837839

838840
# ==================== Partition Index Management API ====================
839841

840-
def get_partition_index_range(self, partition: DataPartitionStatus) -> set:
842+
def get_partition_index_range(self, partition: DataPartitionStatus) -> list[int]:
841843
"""
842844
Get all indexes for a specific partition.
843845
844846
Args:
845847
partition: Partition identifier
846848
847849
Returns:
848-
Set of indexes allocated to the partition
850+
List of indexes allocated to the partition
849851
"""
850852
return self.index_manager.get_indexes_for_partition(partition)
851853

@@ -980,6 +982,9 @@ def get_metadata(
980982
if mode == "fetch":
981983
# Find ready samples within current data partition and package into BatchMeta when reading
982984

985+
if batch_size is None:
986+
raise ValueError("must provide batch_size in fetch mode")
987+
983988
start_time = time.time()
984989
while True:
985990
# ready_for_consume_indexes: samples where all required fields are produced

transfer_queue/metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def get_all_custom_meta(self) -> dict[int, dict[str, Any]]:
265265
"""Get the entire custom meta dictionary"""
266266
return copy.deepcopy(self._custom_meta)
267267

268-
def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None):
268+
def update_custom_meta(self, new_custom_meta: Optional[dict[int, dict[str, Any]]]):
269269
"""Update custom meta with a new dictionary"""
270270
if new_custom_meta:
271271
self._custom_meta.update(new_custom_meta)

transfer_queue/storage/managers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class TransferQueueStorageManager(ABC):
6060
def __init__(self, config: dict[str, Any]):
6161
self.storage_manager_id = f"TQ_STORAGE_{uuid4().hex[:8]}"
6262
self.config = config
63-
self.controller_info = config.get("controller_info", None) # type: ZMQServerInfo
63+
self.controller_info = config.get("controller_info") # type: ZMQServerInfo
6464

6565
self.data_status_update_socket = None
6666
self.controller_handshake_socket = None

0 commit comments

Comments
 (0)