Skip to content

Commit 5f793df

Browse files
authored
[StreamingDataLoader, 4/N] feat: Introduce sample pre-allocation for dynamic streaming (#16)
## Background In PR #9, we introduced initial support for the `StreamingDataLoader` interface. Currently, the system assumes prompts are pre-loaded into the TransferQueue. However, a critical use case involves generation workers put both prompts and responses into `TransferQueue` on the run (e.g., `rollout_buffer` mechanism in [Slime](https://github.com/THUDM/slime/blob/main/slime_plugins/rollout_buffer/README.md)). Since TransferQueue supports dynamic expansion, if the producer has not yet pushed any data to the TransferQueue, the TransferQueue appears empty. Consequently, the consumer's `check_consumption_status` API incorrectly assumes no data is available and prematurely terminates the data retrieval iteration. ## Solution This PR introduces a new environment variable, `TQ_PRE_ALLOC_SAMPLE_NUM`, to handle sample pre-allocation in TransferQueue. - **Mechanism**: When set (typically to `global_batch_size`), the controller pre-allocates a fixed number of global indexes before data production begins. - **Effect**: The `check_consumption_status` API now accounts for these pre-allocated slots. This ensures the `StreamingDataLoader` waits for the pending data instead of exiting immediately when the TransferQueue is temporarily empty. ## Other Changes Deprecate `TQ_INIT_SAMPLE_NUM`, `TQ_INIT_FIELD_NUM`, `TQ_SAMPLE_MIN_EXPANSION_SIZE` and `TQ_SAMPLE_MIN_EXPANSION_SIZE` for simplicity. --- CC: @NINGBENZHE --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 6f44413 commit 5f793df

4 files changed

Lines changed: 286 additions & 38 deletions

File tree

tests/test_controller_data_partitions.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,3 +653,172 @@ def test_get_consumption_status_parameter():
653653
print("✓ get_consumption_status with mask works")
654654

655655
print("Consumption status mask parameter tests passed!\n")
656+
657+
658+
def test_pre_allocated_indexes_basic():
659+
"""Test basic pre-allocated indexes functionality in DataPartitionStatus."""
660+
from transfer_queue.controller import DataPartitionStatus
661+
662+
print("Testing pre-allocated indexes basic functionality...")
663+
664+
partition = DataPartitionStatus(partition_id="prealloc_test")
665+
666+
# Initially, pre_allocated_global_indexes should be empty
667+
assert len(partition.pre_allocated_global_indexes) == 0
668+
assert partition.total_samples_num == 0
669+
670+
print("✓ Initial state correct")
671+
672+
# Register pre-allocated indexes
673+
pre_allocated = [0, 1, 2, 3, 4]
674+
partition.register_pre_allocated_indexes(pre_allocated)
675+
676+
assert partition.pre_allocated_global_indexes == set(pre_allocated)
677+
# global_indexes should still be empty until retrieved
678+
assert partition.total_samples_num == 0
679+
680+
print("✓ Pre-allocated indexes registered")
681+
682+
# activate pre-allocated indexes
683+
retrieved = partition.activate_pre_allocated_indexes(3)
684+
685+
assert len(retrieved) == 3
686+
assert set(retrieved) == {0, 1, 2}
687+
assert partition.global_indexes == {0, 1, 2}
688+
assert partition.pre_allocated_global_indexes == {3, 4}
689+
assert partition.total_samples_num == 3
690+
691+
print("✓ Pre-allocated indexes activate & retrieved correctly")
692+
693+
# Activate remaining indexes
694+
retrieved = partition.activate_pre_allocated_indexes(5)
695+
696+
assert len(retrieved) == 2 # Only 2 remaining
697+
assert set(retrieved) == {3, 4}
698+
assert partition.global_indexes == {0, 1, 2, 3, 4}
699+
assert partition.pre_allocated_global_indexes == set()
700+
assert partition.total_samples_num == 5
701+
702+
print("✓ All pre-allocated indexes retrieved")
703+
704+
print("Pre-allocated indexes basic tests passed!\n")
705+
706+
707+
def test_pre_allocated_indexes_consumption_status():
708+
"""Test that pre-allocated indexes are included in consumption status."""
709+
import torch
710+
711+
from transfer_queue.controller import DataPartitionStatus
712+
713+
print("Testing pre-allocated indexes in consumption status...")
714+
715+
partition = DataPartitionStatus(partition_id="consumption_test")
716+
717+
# Register pre-allocated indexes
718+
partition.register_pre_allocated_indexes([0, 1, 2, 3, 4])
719+
720+
# Get consumption status - should include pre-allocated indexes
721+
global_index, consumption_status = partition.get_consumption_status("test_task", mask=True)
722+
723+
# global_index should include all pre-allocated indexes
724+
assert torch.equal(global_index, torch.tensor([0, 1, 2, 3, 4], dtype=torch.long))
725+
# All consumption statuses should be 0 (not consumed yet)
726+
assert torch.all(consumption_status == 0)
727+
728+
print("✓ Consumption status includes pre-allocated indexes")
729+
730+
# Mark some samples as consumed
731+
partition.mark_consumed("test_task", [0, 2, 4])
732+
733+
# Get consumption status again
734+
global_index, consumption_status = partition.get_consumption_status("test_task", mask=True)
735+
736+
assert consumption_status[0].item() == 1 # consumed
737+
assert consumption_status[1].item() == 0 # not consumed
738+
assert consumption_status[2].item() == 1 # consumed
739+
assert consumption_status[3].item() == 0 # not consumed
740+
assert consumption_status[4].item() == 1 # consumed
741+
742+
print("✓ Marked consumed works with pre-allocated indexes")
743+
744+
print("Pre-allocated indexes consumption status tests passed!\n")
745+
746+
747+
def test_pre_allocated_indexes_in_scan_data_status():
748+
"""Test that pre-allocated indexes affect scan_data_status behavior."""
749+
from transfer_queue.controller import DataPartitionStatus
750+
751+
print("Testing pre-allocated indexes in scan_data_status...")
752+
753+
partition = DataPartitionStatus(partition_id="scan_test")
754+
755+
# Register pre-allocated indexes (5 samples)
756+
partition.register_pre_allocated_indexes([0, 1, 2, 3, 4])
757+
758+
# Before any production, scan should return empty (no samples produced yet)
759+
ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task")
760+
assert ready == []
761+
762+
print("✓ Scan returns empty before production")
763+
764+
# Now produce some samples (0, 2, 4)
765+
partition.update_production_status(
766+
global_indices=[0, 2, 4],
767+
field_names=["input_ids"],
768+
dtypes={i: {"input_ids": "torch.int32"} for i in [0, 2, 4]},
769+
shapes={i: {"input_ids": (32,)} for i in [0, 2, 4]},
770+
)
771+
772+
# Scan should return produced and unconsumed samples
773+
ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task")
774+
assert set(ready) == {0, 2, 4}
775+
776+
print("✓ Scan returns produced samples correctly")
777+
778+
# Mark sample 2 as consumed
779+
partition.mark_consumed("test_task", [2])
780+
781+
# Scan should now return only 0 and 4
782+
ready = partition.scan_data_status(field_names=["input_ids"], task_name="test_task")
783+
assert set(ready) == {0, 4}
784+
785+
print("✓ Scan respects consumption status")
786+
787+
print("Pre-allocated indexes scan_data_status tests passed!\n")
788+
789+
790+
def test_pre_allocated_indexes_mixed_with_dynamic():
791+
"""Test mixing pre-allocated indexes with dynamically allocated ones."""
792+
from transfer_queue.controller import DataPartitionStatus
793+
794+
print("Testing mixed pre-allocated and dynamic indexes...")
795+
796+
partition = DataPartitionStatus(partition_id="mixed_test")
797+
798+
# Register 3 pre-allocated indexes
799+
partition.register_pre_allocated_indexes([0, 1, 2])
800+
801+
# Simulate adding more samples (indexes 5, 6, 7)
802+
# This would happen when producer calls update_production_status
803+
partition.update_production_status(
804+
global_indices=[5, 6, 7],
805+
field_names=["input_ids"],
806+
dtypes={i: {"input_ids": "torch.int32"} for i in [5, 6, 7]},
807+
shapes={i: {"input_ids": (32,)} for i in [5, 6, 7]},
808+
)
809+
810+
# Now global_indexes should only contain dynamically generated in (5,6,7)
811+
assert partition.global_indexes == {5, 6, 7}
812+
assert partition.total_samples_num == 3
813+
814+
# all pre-allocated
815+
retrieved = partition.activate_pre_allocated_indexes(3)
816+
assert set(retrieved) == {0, 1, 2}
817+
818+
# Now global_indexes should have both pre-allocated (0,1,2) and dynamic (5,6,7)
819+
assert partition.global_indexes == {0, 1, 2, 5, 6, 7}
820+
assert partition.total_samples_num == 6
821+
822+
print("✓ Mixed pre-allocated and dynamic indexes work correctly")
823+
824+
print("Mixed indexes tests passed!\n")

0 commit comments

Comments
 (0)