Skip to content

Commit 34e9a50

Browse files
authored
[feat] Provide fine-grained production & consumption status retrieval (#8)
## Background In previous PR TransferQueue/TransferQueue#157, a coarse-grained status check was implemented. Through the following APIs in `TransferQueueClient`, a single boolean flag is returned to indicate whether all samples are produced or consumed. ```python3 async def async_check_consumption_status( self, task_name: str, partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, ) -> bool: """Check if all samples for current partition have been consumed by a specific task. Args: task_name: Name of the task to check consumption for partition_id: Partition id to check consumption status for socket: ZMQ async socket for message transmission (injected by decorator) Returns: bool: True if all samples have been consumed by the task, False otherwise Raises: RuntimeError: If communication fails or controller returns error response """ async def async_check_production_status( self, data_fields: list[str], partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, ) -> bool: """Check if all samples for current partition are ready (produced) for consumption. Args: data_fields: Data fields to check production status for partition_id: Partition id to check production status for socket: ZMQ async socket for message transmission (injected by decorator) Returns: bool: True if all samples have been produced and ready, False otherwise Raises: RuntimeError: If communication fails or controller returns error response\ """ ``` # Changes This PR introduces a fine-grained check mechanism. It provides the capability that allows users to retrieving the detailed production and consumption status of every individual sample. ```python3 async def async_get_consumption_status( self, task_name: str, partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, ) -> tuple[Optional[Tensor], Optional[Tensor]]: """Get consumption status for current partition in a specific task. Args: task_name: Name of the task to check consumption for partition_id: Partition id to check consumption status for socket: ZMQ async socket for message transmission (injected by decorator) Returns: Tuple of: - Partition global index tensor - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. """ async def async_get_production_status( self, data_fields: list[str], partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, ) -> tuple[Optional[Tensor], Optional[Tensor]]: """Get production status for current partition in a specific task. Args: data_fields: Data fields to check production status for partition_id: Partition id to check production status for socket: ZMQ async socket for message transmission (injected by decorator) Returns: Tuple of: - Partition global index tensor - Production status tensor for the specified task. 1 for ready, 0 for not ready. Raises: RuntimeError: If communication fails or controller returns error response """ ``` CC @NINGBENZHE @walterchenchn --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent ab06f87 commit 34e9a50

6 files changed

Lines changed: 548 additions & 104 deletions

File tree

tests/test_client.py

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,18 +112,20 @@ def _handle_requests(self):
112112
# Mock partition metadata response
113113
response_body = {"metadata": self._mock_batch_meta(request_msg.body)}
114114
response_type = ZMQRequestType.GET_PARTITION_META_RESPONSE
115-
elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
115+
elif request_msg.request_type == ZMQRequestType.GET_CONSUMPTION:
116116
# Mock consumption status check - all consumed
117117
response_body = {
118118
"partition_id": request_msg.body.get("partition_id"),
119-
"consumed": True,
119+
"global_index": torch.tensor([0, 1, 2]),
120+
"consumption_status": torch.tensor([1, 1, 1]),
120121
}
121122
response_type = ZMQRequestType.CONSUMPTION_RESPONSE
122-
elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION:
123+
elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION:
123124
# Mock production status check - all produced
124125
response_body = {
125126
"partition_id": request_msg.body.get("partition_id"),
126-
"produced": True,
127+
"global_index": torch.tensor([0, 1, 2]),
128+
"production_status": torch.tensor([[1, 1, 1], [1, 1, 1]]),
127129
}
128130
response_type = ZMQRequestType.PRODUCTION_RESPONSE
129131
elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS:
@@ -467,6 +469,52 @@ def test_check_production_status(client_setup):
467469
assert is_produced is True
468470

469471

472+
def test_get_consumption_status(client_setup):
473+
"""Test get_consumption_status - returns global_index and consumption_status tensors"""
474+
client, _, _ = client_setup
475+
476+
# Test synchronous get_consumption_status
477+
global_index, consumption_status = client.get_consumption_status(
478+
task_name="generate_sequences", partition_id="train_0"
479+
)
480+
481+
# Verify return types
482+
assert global_index is not None
483+
assert consumption_status is not None
484+
485+
# Verify global_index contains expected values
486+
assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long))
487+
488+
# Verify consumption_status (mock returns all consumed)
489+
expected_status = torch.tensor([1, 1, 1], dtype=torch.int8)
490+
assert torch.equal(consumption_status, expected_status)
491+
492+
print("✓ get_consumption_status returns correct global_index and consumption_status")
493+
494+
495+
def test_get_production_status(client_setup):
496+
"""Test get_production_status - returns global_index and production_status tensors"""
497+
client, _, _ = client_setup
498+
499+
# Test synchronous get_production_status
500+
global_index, production_status = client.get_production_status(
501+
data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
502+
)
503+
504+
# Verify return types
505+
assert global_index is not None
506+
assert production_status is not None
507+
508+
# Verify global_index contains expected values
509+
assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long))
510+
511+
# Verify production_status shape (mock returns 2x3 matrix)
512+
expected_status = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8)
513+
assert torch.equal(production_status, expected_status)
514+
515+
print("✓ get_production_status returns correct global_index and production_status")
516+
517+
470518
def test_get_partition_list(client_setup):
471519
"""Test partition list retrieval"""
472520
client, _, _ = client_setup
@@ -502,6 +550,54 @@ async def test_async_check_production_status(client_setup):
502550
assert is_produced is True
503551

504552

553+
@pytest.mark.asyncio
554+
async def test_async_get_consumption_status(client_setup):
555+
"""Test async get_consumption_status - returns global_index and consumption_status tensors"""
556+
client, _, _ = client_setup
557+
558+
# Test async_get_consumption_status
559+
global_index, consumption_status = await client.async_get_consumption_status(
560+
task_name="generate_sequences", partition_id="train_0"
561+
)
562+
563+
# Verify return types
564+
assert global_index is not None
565+
assert consumption_status is not None
566+
567+
# Verify global_index contains expected values
568+
assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long))
569+
570+
# Verify consumption_status (mock returns all consumed)
571+
expected_status = torch.tensor([1, 1, 1], dtype=torch.int8)
572+
assert torch.equal(consumption_status, expected_status)
573+
574+
print("✓ async_get_consumption_status returns correct global_index and consumption_status")
575+
576+
577+
@pytest.mark.asyncio
578+
async def test_async_get_production_status(client_setup):
579+
"""Test async get_production_status - returns global_index and production_status tensors"""
580+
client, _, _ = client_setup
581+
582+
# Test async_get_production_status
583+
global_index, production_status = await client.async_get_production_status(
584+
data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
585+
)
586+
587+
# Verify return types
588+
assert global_index is not None
589+
assert production_status is not None
590+
591+
# Verify global_index contains expected values
592+
assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long))
593+
594+
# Verify production_status shape (mock returns 2x3 matrix)
595+
expected_status = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8)
596+
assert torch.equal(production_status, expected_status)
597+
598+
print("✓ async_get_production_status returns correct global_index and production_status")
599+
600+
505601
@pytest.mark.asyncio
506602
async def test_async_get_partition_list(client_setup):
507603
"""Test async partition list retrieval"""

tests/test_controller.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def test_controller_with_single_partition(self, ray_setup):
8989
field_names=metadata.field_names,
9090
dtypes=dtypes,
9191
shapes=shapes,
92+
custom_meta=None,
9293
)
9394
)
9495
assert success
@@ -97,13 +98,18 @@ def test_controller_with_single_partition(self, ray_setup):
9798
assert partition.production_status.size(0) == gbs * num_n_samples
9899

99100
# Test for get production status
100-
production_status = ray.get(
101+
global_index, production_status = ray.get(
101102
tq_controller.get_production_status.remote(
102103
partition_id=partition_id,
103104
data_fields=data_fields,
104105
)
105106
)
106-
assert production_status
107+
# Verify global_index contains all expected indexes
108+
assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long))
109+
# Verify all samples are produced for all fields (status should be 1)
110+
expected_production_status = torch.ones(gbs * num_n_samples, len(metadata.field_names), dtype=torch.int8)
111+
assert torch.equal(production_status, expected_production_status)
112+
print("✓ Get production status returns correct global_index and production_status")
107113

108114
# Total fields should match the number of fields we added
109115
assert partition.total_fields_num == len(data_fields)
@@ -126,14 +132,19 @@ def test_controller_with_single_partition(self, ray_setup):
126132

127133
print(f"✓ Updated production status for partition {partition_id}")
128134

129-
# Test for get consumption status
130-
consumption_status = ray.get(
135+
# Test for get consumption status BEFORE consumption
136+
global_index, consumption_status = ray.get(
131137
tq_controller.get_consumption_status.remote(
132138
partition_id=partition_id,
133139
task_name="generate_sequences",
134140
)
135141
)
136-
assert torch.equal(consumption_status, torch.zeros(gbs * num_n_samples))
142+
# Verify global_index
143+
assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long))
144+
# Verify all samples are NOT consumed yet (status should be 0)
145+
expected_consumption_status_before = torch.zeros(gbs * num_n_samples, dtype=torch.int8)
146+
assert torch.equal(consumption_status, expected_consumption_status_before)
147+
print("✓ Get consumption status returns correct global_index and status (before consumption)")
137148

138149
# Test get metadate in fetch mode
139150
gen_meta = ray.get(
@@ -153,14 +164,19 @@ def test_controller_with_single_partition(self, ray_setup):
153164
assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples))
154165
print("✓ Get metadata in fetch mode correct")
155166

156-
# Test for get consumption status
157-
consumption_status = ray.get(
167+
# Test for get consumption status AFTER consumption
168+
global_index, consumption_status = ray.get(
158169
tq_controller.get_consumption_status.remote(
159170
partition_id=partition_id,
160171
task_name="generate_sequences",
161172
)
162173
)
163-
assert torch.equal(consumption_status, torch.ones(gbs * num_n_samples))
174+
# Verify global_index
175+
assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long))
176+
# Verify all samples are consumed (status should be 1)
177+
expected_consumption_status_after = torch.ones(gbs * num_n_samples, dtype=torch.int8)
178+
assert torch.equal(consumption_status, expected_consumption_status_after)
179+
print("✓ Get consumption status returns correct global_index and status (after consumption)")
164180

165181
# Test get clear meta
166182
clear_meta = ray.get(
@@ -222,6 +238,19 @@ def test_controller_with_multi_partitions(self, ray_setup):
222238
)
223239
assert success
224240

241+
# Verify get production status returns correct data
242+
global_index_1, production_status_1 = ray.get(
243+
tq_controller.get_production_status.remote(
244+
partition_id=partition_id_1,
245+
data_fields=data_fields,
246+
)
247+
)
248+
expected_global_index_1 = torch.tensor(range(gbs_1 * num_n_samples_1), dtype=torch.long)
249+
assert torch.equal(global_index_1, expected_global_index_1)
250+
expected_production_status_1 = torch.ones(gbs_1 * num_n_samples_1, len(data_fields), dtype=torch.int8)
251+
assert torch.equal(production_status_1, expected_production_status_1)
252+
print("✓ Get production status for partition_1 returns correct global_index and status")
253+
225254
# Test get metadate in fetch mode
226255
gen_meta = ray.get(
227256
tq_controller.get_metadata.remote(
@@ -234,6 +263,18 @@ def test_controller_with_multi_partitions(self, ray_setup):
234263
)
235264
assert gen_meta
236265

266+
# Verify get consumption status after fetch (samples should be consumed)
267+
global_index_1_consumed, consumption_status_1 = ray.get(
268+
tq_controller.get_consumption_status.remote(
269+
partition_id=partition_id_1,
270+
task_name="generate_sequences",
271+
)
272+
)
273+
assert torch.equal(global_index_1_consumed, expected_global_index_1)
274+
expected_consumption_status_1 = torch.ones(gbs_1 * num_n_samples_1, dtype=torch.int8)
275+
assert torch.equal(consumption_status_1, expected_consumption_status_1)
276+
print("✓ Get consumption status for partition_1 returns correct global_index and status (after fetch)")
277+
237278
# Test get clear meta
238279
clear_meta = ray.get(
239280
tq_controller.get_metadata.remote(
@@ -282,6 +323,33 @@ def test_controller_with_multi_partitions(self, ray_setup):
282323
)
283324
assert success
284325

326+
# Verify get production status for partition_2
327+
global_index_2, production_status_2 = ray.get(
328+
tq_controller.get_production_status.remote(
329+
partition_id=partition_id_2,
330+
data_fields=data_fields,
331+
)
332+
)
333+
expected_global_index_2 = torch.tensor(
334+
range(part1_index_range, part2_index_range + part1_index_range), dtype=torch.long
335+
)
336+
assert torch.equal(global_index_2, expected_global_index_2)
337+
expected_production_status_2 = torch.ones(part2_index_range, len(data_fields), dtype=torch.int8)
338+
assert torch.equal(production_status_2, expected_production_status_2)
339+
print("✓ Get production status for partition_2 returns correct global_index and status")
340+
341+
# Verify get consumption status for partition_2 (before consumption - should be all zeros)
342+
global_index_2_consumed, consumption_status_2 = ray.get(
343+
tq_controller.get_consumption_status.remote(
344+
partition_id=partition_id_2,
345+
task_name="generate_sequences",
346+
)
347+
)
348+
assert torch.equal(global_index_2_consumed, expected_global_index_2)
349+
expected_consumption_status_2 = torch.zeros(part2_index_range, dtype=torch.int8)
350+
assert torch.equal(consumption_status_2, expected_consumption_status_2)
351+
print("✓ Get consumption status for partition_2 returns correct global_index and status (before consumption)")
352+
285353
# Clear partition 1
286354
partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
287355
assert partition_index_range_1

0 commit comments

Comments
 (0)