@@ -188,17 +188,17 @@ def release_indexes(self, partition_id: str, indexes_to_release: list[int]):
188188 if not partition_indexes :
189189 self .partition_to_indexes .pop (partition_id , None )
190190
def get_indexes_for_partition(self, partition_id) -> list[int]:
    """
    Get all global_indexes for the specified partition.

    Args:
        partition_id: Partition ID

    Returns:
        list: A new list holding the global_indexes recorded for this
            partition, or an empty list when the partition has no entry.
            Element order follows the iteration order of the stored
            collection, so callers should not rely on it.
    """
    # list() already builds an independent container, so the stored
    # collection is protected from caller mutation without an extra copy.
    return list(self.partition_to_indexes.get(partition_id, ()))
202202
203203
204204@dataclass
@@ -216,7 +216,7 @@ class DataPartitionStatus:
216216
217217 # Production status tensor - dynamically expandable
218218 # Values: 0 = not produced, 1 = ready for consumption
219- production_status : Optional [ Tensor ] = torch .zeros (TQ_INIT_SAMPLE_NUM , TQ_INIT_FIELD_NUM , dtype = torch .int8 )
219+ production_status : Tensor = torch .zeros (TQ_INIT_SAMPLE_NUM , TQ_INIT_FIELD_NUM , dtype = torch .int8 )
220220
221221 # Consumption status per task - task_name -> consumption_tensor
222222 # Each tensor tracks which samples have been consumed by that task
@@ -260,7 +260,7 @@ def allocated_samples_num(self) -> int:
260260
261261 # ==================== Dynamic Expansion Methods ====================
262262
263- def ensure_samples_capacity (self , required_samples : int ) -> bool :
263+ def ensure_samples_capacity (self , required_samples : int ):
264264 """
265265 Ensure the production status tensor has enough rows for the required samples.
266266 Dynamically expands if needed using unified minimum expansion size.
@@ -498,7 +498,9 @@ def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Te
498498 return partition_global_index , consumption_status
499499
500500 # ==================== Production Status Interface ====================
501- def get_production_status_for_fields (self , field_names : list [str ], mask : bool = False ) -> tuple [Tensor , Tensor ]:
501+ def get_production_status_for_fields (
502+ self , field_names : list [str ], mask : bool = False
503+ ) -> tuple [Optional [Tensor ], Optional [Tensor ]]:
502504 """
503505 Check if all samples for specified fields are fully produced and ready.
504506
@@ -512,12 +514,12 @@ def get_production_status_for_fields(self, field_names: list[str], mask: bool =
512514 - Production status tensor for the specified task. 1 for ready, 0 for not ready.
513515 """
514516 if self .production_status is None or field_names is None or len (field_names ) == 0 :
515- return False
517+ return None , None
516518
517519 # Check if all requested fields are registered
518520 for field_name in field_names :
519521 if field_name not in self .field_name_mapping :
520- return False
522+ return None , None
521523
522524 # Create column mask for requested fields
523525 col_mask = torch .zeros (self .allocated_fields_num , dtype = torch .bool )
@@ -837,15 +839,15 @@ def list_partitions(self) -> list[str]:
837839
838840 # ==================== Partition Index Management API ====================
839841
def get_partition_index_range(self, partition: str) -> list[int]:
    """
    Get all indexes for a specific partition.

    Args:
        partition: Partition identifier. It is forwarded unchanged as the
            partition ID argument of
            ``index_manager.get_indexes_for_partition``.

    Returns:
        List of indexes allocated to the partition
    """
    # NOTE(review): the parameter was previously annotated as
    # DataPartitionStatus, but it is used purely as a partition ID lookup
    # key; confirm that no caller passes a status object here.
    return self.index_manager.get_indexes_for_partition(partition)
851853
@@ -980,6 +982,9 @@ def get_metadata(
980982 if mode == "fetch" :
981983 # Find ready samples within current data partition and package into BatchMeta when reading
982984
985+ if batch_size is None :
986+ raise ValueError ("must provide batch_size in fetch mode" )
987+
983988 start_time = time .time ()
984989 while True :
985990 # ready_for_consume_indexes: samples where all required fields are produced
0 commit comments