Skip to content

Commit 6f0fc06

Browse files
authored
Merge branch 'main' into mooncake_dev
2 parents c74d687 + 0a5176b commit 6f0fc06

23 files changed

Lines changed: 787 additions & 424 deletions

recipe/simple_use_case/single_controller_demo.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,24 @@ def compute_loss(data1, _data2):
5353

5454
def compute_reward(response_ids: torch.Tensor) -> TensorDict:
    """Simulate a reward model that scores each token position in the response.

    Sleeps for one second, then draws random float32 scores; the contents of
    ``response_ids`` are not read.

    Args:
        response_ids: Response token tensor; its first dimension is used as
            the batch size of the returned TensorDict.

    Returns:
        A TensorDict with an ``"rm_score"`` field whose shape matches
        ``response_ids`` (i.e. one scalar per response token).
    """
    time.sleep(1)
    num_samples = response_ids.size(0)
    scores = torch.randn_like(response_ids, dtype=torch.float32)
    return TensorDict({"rm_score": scores}, batch_size=num_samples)
63+
64+
65+
def compute_advantage(rewards: torch.Tensor) -> TensorDict:
    """Simulate the process of computing advantage.

    Sleeps for one second, then draws random float32 values; the contents of
    ``rewards`` are not read.

    Args:
        rewards: Reward tensor; its first dimension is used as the batch size
            of the returned TensorDict.

    Returns:
        A TensorDict with an ``"advantage"`` field whose shape matches
        ``rewards`` (i.e. one scalar per reward).
    """
    time.sleep(1)
    advantage = torch.randn_like(rewards, dtype=torch.float32)
    return TensorDict({"advantage": advantage}, batch_size=rewards.size(0))
6374

6475

6576
class TrainingWorker:
@@ -89,7 +100,7 @@ def infer_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
89100
"""Simulate forward-only inference"""
90101
# 1. Pull data from storage
91102
data = tq.kv_batch_get_by_meta(meta=kv_meta)
92-
logger.info(f"compute_log_prob: got data {data}")
103+
logger.info(f"infer_batch: got data {data}")
93104

94105
# 2. Model forward
95106
output = compute_log_prob(data["prompt_ids"], data["response_ids"])
@@ -494,6 +505,13 @@ def fit(self):
494505
meta = tq.kv_batch_put(keys=meta.keys, partition_id=meta.partition_id, fields=reward_output)
495506
logger.info(f"demo reward KVBatchMeta: {meta}")
496507

508+
# ========================= Compute advantage =========================
509+
meta.fields = ["response_ids", "ref_log_prob", "old_log_prob", "rm_score"]
510+
advantage_data = tq.kv_batch_get_by_meta(meta=meta)
511+
advantage_output = compute_advantage(advantage_data["rm_score"])
512+
meta = tq.kv_batch_put(keys=meta.keys, partition_id=meta.partition_id, fields=advantage_output)
513+
logger.info(f"demo advantage KVBatchMeta: {meta}")
514+
497515
# ========================= Update actor =========================
498516
meta.fields = [
499517
"input_ids",

tests/test_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def client_setup(mock_controller, mock_storage):
418418
client.initialize_storage_manager(manager_type="SimpleStorage", config=config)
419419

420420
# Mock all storage manager methods to avoid real ZMQ operations
421-
async def mock_put_data(data, metadata):
421+
async def mock_put_data(data, metadata, data_parser=None):
422422
pass # Just pretend to store the data
423423

424424
async def mock_get_data(metadata):
@@ -511,7 +511,7 @@ def test_single_controller_multiple_storages():
511511
client.initialize_storage_manager(manager_type="SimpleStorage", config=config)
512512

513513
# Mock all storage manager methods to avoid real ZMQ operations
514-
async def mock_put_data(data, metadata):
514+
async def mock_put_data(data, metadata, data_parser=None):
515515
pass # Just pretend to store the data
516516

517517
async def mock_get_data(metadata):

tests/test_simple_storage_unit.py

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,14 @@ def __init__(self, storage_put_get_address):
3434
self.socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
3535
self.socket.connect(storage_put_get_address)
3636

37-
def send_put(self, client_id, global_indexes, field_data, data_parser=None):
    """Send a PUT_DATA request over the socket and return the deserialized reply.

    Args:
        client_id: Identifier embedded in the sender id as
            ``mock_client_<client_id>``.
        global_indexes: Sent in the request body under ``"global_indexes"``.
        field_data: Sent in the request body under ``"data"``.
        data_parser: Optional value sent under ``"data_parser"``. The key is
            left out of the body entirely when this is ``None``.

    Returns:
        The reply frames deserialized with ``ZMQMessage.deserialize``.
    """
    body = {"global_indexes": global_indexes, "data": field_data}
    if data_parser is not None:
        body["data_parser"] = data_parser
    msg = ZMQMessage.create(
        request_type=ZMQRequestType.PUT_DATA,
        sender_id=f"mock_client_{client_id}",
        body=body,
    )
    self.socket.send_multipart(msg.serialize())
    return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))
@@ -434,3 +437,177 @@ def test_storage_unit_data_capacity_uses_active_keys():
434437
assert len(storage._active_keys) == 2
435438
storage.put_data({"f": [4]}, global_indexes=[3])
436439
assert storage._active_keys == {0, 1, 3}
440+
441+
442+
def test_storage_unit_data_parser(storage_setup):
    """Test data_parser functionality in SimpleStorageUnit.

    Writes two columns:
    - normal_data: regular tensors, should remain unchanged
    - data_to_be_parsed: list of shape descriptors (list of ints)

    data_parser converts shape descriptors into random tensors of those shapes.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def create_data_by_shape_parser(field_data):
        if "data_to_be_parsed" in field_data:
            shapes = field_data["data_to_be_parsed"]
            field_data["data_to_be_parsed"] = [torch.randn(shape) for shape in shapes]
        return field_data

    # Prepare data: normal_data is a batch tensor, data_to_be_parsed is a list of shape lists
    field_data = {
        "normal_data": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        "data_to_be_parsed": [[2, 3], [1, 4], [3, 2]],
    }
    global_indexes = [0, 1, 2]

    # try/finally so the client is closed even when an assertion fails.
    try:
        # Put with data_parser
        response = client.send_put(0, global_indexes, field_data, data_parser=create_data_by_shape_parser)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"Put failed: {response.body}"

        # Get back
        response = client.send_get(0, global_indexes, ["normal_data", "data_to_be_parsed"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        result = response.body["data"]

        # Verify normal_data is unchanged
        torch.testing.assert_close(result["normal_data"][0], torch.tensor([1.0, 2.0]))
        torch.testing.assert_close(result["normal_data"][1], torch.tensor([3.0, 4.0]))
        torch.testing.assert_close(result["normal_data"][2], torch.tensor([5.0, 6.0]))

        # Verify data_to_be_parsed shapes match the input shape descriptors
        expected_shapes = [(2, 3), (1, 4), (3, 2)]
        for i, expected_shape in enumerate(expected_shapes):
            actual_shape = tuple(result["data_to_be_parsed"][i].shape)
            assert actual_shape == expected_shape, (
                f"Shape mismatch at index {i}: expected {expected_shape}, got {actual_shape}"
            )
    finally:
        client.close()
491+
492+
493+
def test_storage_unit_data_parser_callable_types(storage_setup):
    """Test that various callable types (partial, callable class) work as data_parser."""
    from functools import partial

    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def _partial_parser(field_data, prefix):
        if "text" in field_data:
            field_data["text"] = [f"{prefix}{t}" for t in field_data["text"]]
        return field_data

    class CallableParser:
        def __call__(self, field_data):
            if "value" in field_data:
                field_data["value"] = [v * 2 for v in field_data["value"]]
            return field_data

    # try/finally so the client is closed even when an assertion fails.
    try:
        # 1. Test functools.partial
        partial_parser = partial(_partial_parser, prefix="parsed_")

        response = client.send_put(
            0,
            [0, 1],
            {"text": ["a", "b"]},
            data_parser=partial_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"partial parser failed: {response.body}"

        response = client.send_get(0, [0, 1], ["text"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert response.body["data"]["text"] == ["parsed_a", "parsed_b"]

        # 2. Test callable class instance
        callable_parser = CallableParser()
        response = client.send_put(
            0,
            [2, 3],
            {"value": [1, 2]},
            data_parser=callable_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, (
            f"callable class parser failed: {response.body}"
        )

        response = client.send_get(0, [2, 3], ["value"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert response.body["data"]["value"] == [2, 4]
    finally:
        client.close()
541+
542+
543+
def test_storage_unit_data_parser_validation(storage_setup):
    """Test that invalid data_parser inputs produce clear error messages.

    Each case sends one put request and expects a ``PUT_ERROR`` reply whose
    ``"message"`` contains a specific fragment.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def bad_parser(field_data):
        return "not_a_dict"

    def delete_key_parser(field_data):
        del field_data["data"]
        return field_data

    def add_key_parser(field_data):
        field_data["new_key"] = [999]
        return field_data

    def wrong_len_parser(field_data):
        field_data["data"] = field_data["data"][:-1]
        return field_data

    # (label, global_indexes, field_data, data_parser, expected message fragment)
    cases = [
        ("non-callable parser", [0], {"data": [1]}, "not_callable", "data_parser must be callable"),
        ("parser returning non-dict", [1], {"data": [1]}, bad_parser, "data_parser must return a dict"),
        (
            "parser deleting a key",
            [2],
            {"data": [1], "extra": [2]},
            delete_key_parser,
            "data_parser must not change dict keys",
        ),
        ("parser adding a key", [3], {"data": [1]}, add_key_parser, "data_parser must not change dict keys"),
        (
            "parser changing element count",
            [4, 5],
            {"data": [1, 2]},
            wrong_len_parser,
            "data_parser changed the number of elements",
        ),
    ]

    # try/finally so the client is closed even when an assertion fails.
    try:
        for label, global_indexes, field_data, parser, expected_fragment in cases:
            response = client.send_put(0, global_indexes, field_data, data_parser=parser)
            assert response.request_type == ZMQRequestType.PUT_ERROR, (
                f"{label}: expected PUT_ERROR, got {response.request_type}"
            )
            assert expected_fragment in response.body["message"], (
                f"{label}: unexpected error message {response.body['message']!r}"
            )
    finally:
        client.close()

0 commit comments

Comments
 (0)