Commit aec9192

[feat] Support user-defined data parser for SimpleStorage backend (Ascend#82)
## Background

Users often need to store lightweight **references** (e.g., URLs or file paths) rather than the full data in TransferQueue, so that the expensive loading and decoding steps do not have to run in user code. This also saves one data copy.

<img width="1032" height="745" alt="image" src="https://github.com/user-attachments/assets/4a62b350-b6f1-441d-8c05-b35f8cf2e7de" />

## Solution

This PR introduces support for a user-defined data parser in the `SimpleStorage` backend. The `kv_put` and `kv_batch_put` methods now accept an optional `data_parser` callable. This parser is executed **inside** each `SimpleStorageUnit` at put time. It receives the raw `field_data` dictionary from the put request and should return a dictionary with the same structure, replacing reference values with the actual parsed data.

## Limitations & Future Work

- **Synchronous Execution:** In the current design, the data parser runs synchronously as part of the `put` operation. The put request completes only after the data parser has finished.
- **Backend Support:** `data_parser` is currently only supported by the **SimpleStorage** backend.
- **Incorrect Metadata:** Allowing user-provided functions to modify data in `SimpleStorageUnit` may lead to incorrect `shape` & `dtype` metadata, because the metadata is captured while the data is still in `TransferQueueClient`. This can cause problems for RDMA transport, which relies on the metadata collected by TQ to restore tensors during `get`.

## Demo Script

```python3
"""Demo: concurrent data_parser with separated single-sample logic.

This demo shows how to structure a data_parser so that:
  1. The **core parser** only handles a **single sample**.
  2. The **batch wrapper** uses asyncio to process all samples concurrently.
  3. The wrapper is **synchronous to the outside**: it blocks until every
     sample finishes, so ``data_parser`` returning means data is ready.

Scenario:
  - Users pass URL-like strings in a column.
  - The parser sleeps 1 s per sample (simulating I/O / decode) and then
    creates a random tensor of the requested dtype & shape.
  - Because the sleeps run concurrently via asyncio, a batch of N samples
    finishes in ~1 s instead of ~N s.
"""

import asyncio
import time

import ray
import torch
from tensordict import TensorDict, NonTensorStack

import transfer_queue as tq


# ---------------------------------------------------------------------------
# Core single-sample parser
# ---------------------------------------------------------------------------
def parse_url(url: str) -> torch.Tensor:
    """Parse a URL-like descriptor 'dtype:HxW' into a random tensor."""
    dtype_str, shape_str = url.split(":")
    dtype = getattr(torch, dtype_str)
    shape = [int(dim) for dim in shape_str.split("x")]
    return torch.randn(shape, dtype=dtype)


# ---------------------------------------------------------------------------
# Batch-level parser
# ---------------------------------------------------------------------------
def concurrent_batch_url_parser(field_data: dict) -> dict:
    """Batch-level data_parser executed inside SimpleStorageUnit.

    It receives a ``dict`` (not a TensorDict) where each value is a batched
    column. For columns created from ``NonTensorStack`` the value is a plain
    ``list`` of Python objects.

    Workflow:
      1. Spawns one async task per list element.
      2. Waits until *all* tasks finish (``asyncio.gather``).
      3. Replaces the list with the list of results.

    Because ``asyncio.run`` blocks until the loop finishes, this function is
    **synchronous** to its caller: when it returns, every sample has been
    processed.

    Args:
        field_data: Mapping ``field_name -> batched_values``. The dict keys
            must stay exactly the same; only values may be transformed
            in-place.

    Returns:
        The same dict with parsed values substituted.
    """
    if "data_to_be_parsed" not in field_data:
        return field_data

    urls: list[str] = field_data["data_to_be_parsed"]

    async def _async_parse_single(url: str) -> torch.Tensor:
        await asyncio.sleep(1.0)  # Add fixed delay per sample
        return parse_url(url)

    async def _process_all():
        tasks = [asyncio.create_task(_async_parse_single(url)) for url in urls]
        return await asyncio.gather(*tasks)

    start = time.perf_counter()
    field_data["data_to_be_parsed"] = asyncio.run(_process_all())
    elapsed = time.perf_counter() - start
    print(
        f"[data_parser] Processed {len(urls)} samples in {elapsed:.2f}s "
        f"(serial would be ~{len(urls)}.0s)"
    )
    return field_data


# ---------------------------------------------------------------------------
# Main demo flow
# ---------------------------------------------------------------------------
def main():
    ray.init(ignore_reinit_error=True)

    try:
        tq.init()

        batch_size = 32

        # Column that stays untouched
        normal_data = torch.randn(batch_size, 2)

        # Column to be parsed: URL-like strings describing dtype & shape.
        shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]
        urls = [f"float32:{h}x{w}" for h, w in shapes]
        data_to_be_parsed = NonTensorStack(*urls)

        data = TensorDict({
            "normal_data": normal_data,
            "data_to_be_parsed": data_to_be_parsed,
        }, batch_size=batch_size)

        keys = [f"sample_{i}" for i in range(batch_size)]

        # -------------------------------------------------------------------
        # Put with data_parser
        # -------------------------------------------------------------------
        put_start_time = time.perf_counter()
        meta = tq.kv_batch_put(
            keys=keys,
            partition_id="train",
            fields=data,
            data_parser=concurrent_batch_url_parser,
        )
        put_elapsed = time.perf_counter() - put_start_time
        print(f"Put succeeded. Fields: {meta.fields}")
        print(
            f"Total kv_batch_put time: {put_elapsed:.2f}s "
            f"(concurrency keeps it ~1s, not {batch_size}s)\n"
        )

        # -------------------------------------------------------------------
        # Fetch back and verify
        # -------------------------------------------------------------------
        result = tq.kv_batch_get(keys=keys, partition_id="train")

        # 1) normal_data unchanged
        torch.testing.assert_close(result["normal_data"], normal_data)
        print("[PASS] normal_data is unchanged.")

        # 2) Parsed tensors have correct dtype & shape
        expected_shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]
        for i, exp_shape in enumerate(expected_shapes):
            tensor = result["data_to_be_parsed"][i]
            assert tensor.dtype == torch.float32, (
                f"dtype mismatch at index {i}: expected torch.float32, got {tensor.dtype}"
            )
            assert tuple(tensor.shape) == exp_shape, (
                f"shape mismatch at index {i}: expected {exp_shape}, got {tuple(tensor.shape)}"
            )
        print(f"[PASS] All {batch_size} parsed tensors have correct dtype & shape.")

        # 3) Timing sanity check
        #    Serial execution would be ~batch_size seconds.
        #    Because asyncio tasks run concurrently, it should be ~1 s.
        #    We allow generous headroom for TQ network / serialization overhead.
        assert put_elapsed < 2.0, (
            f"Expected concurrent execution (~1s), but took {put_elapsed:.2f}s. "
            "Are the asyncio tasks actually running concurrently?"
        )
        print(f"[PASS] Timing looks concurrent: {put_elapsed:.2f}s < 2.0s")

        print("\n=== All verifications passed! ===")

        # Wait for Ray to collect logs
        time.sleep(2)

    except Exception as e:
        print(f"Error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
    finally:
        tq.close()
        ray.shutdown()


if __name__ == "__main__":
    main()
```

---------

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
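The metadata limitation described in the commit message can be illustrated with a small, library-free sketch. The `describe` helper and `fake_parser` below are hypothetical stand-ins, not TransferQueue APIs; they only mimic metadata being captured on the client side before a parser rewrites the values in the storage unit.

```python
def describe(value):
    """Return (type name, length) as a crude stand-in for dtype/shape metadata."""
    return type(value).__name__, len(value)


def fake_parser(field_data: dict) -> dict:
    """Hypothetical parser: expands 'HxW' descriptors into nested lists of zeros."""
    parsed = []
    for ref in field_data["data_to_be_parsed"]:
        h, w = (int(x) for x in ref.split("x"))
        parsed.append([[0.0] * w for _ in range(h)])
    field_data["data_to_be_parsed"] = parsed
    return field_data


refs = ["2x3", "1x4"]

# Metadata captured client-side, while the column still holds references.
client_meta = [describe(r) for r in refs]

# Data as it would look after parsing inside the (simulated) storage unit.
stored = fake_parser({"data_to_be_parsed": list(refs)})["data_to_be_parsed"]
stored_meta = [describe(s) for s in stored]

print("client-side:", client_meta)
print("stored:     ", stored_meta)
```

The two metadata lists disagree, which is the kind of mismatch a transport that restores tensors from client-side metadata would have to account for.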
1 parent c0c0e19 commit aec9192

9 files changed

Lines changed: 570 additions & 132 deletions

tests/test_client.py

Lines changed: 2 additions & 2 deletions
@@ -418,7 +418,7 @@ def client_setup(mock_controller, mock_storage):
     client.initialize_storage_manager(manager_type="SimpleStorage", config=config)
 
     # Mock all storage manager methods to avoid real ZMQ operations
-    async def mock_put_data(data, metadata):
+    async def mock_put_data(data, metadata, data_parser=None):
         pass  # Just pretend to store the data
 
     async def mock_get_data(metadata):
@@ -511,7 +511,7 @@ def test_single_controller_multiple_storages():
     client.initialize_storage_manager(manager_type="SimpleStorage", config=config)
 
     # Mock all storage manager methods to avoid real ZMQ operations
-    async def mock_put_data(data, metadata):
+    async def mock_put_data(data, metadata, data_parser=None):
         pass  # Just pretend to store the data
 
     async def mock_get_data(metadata):
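The mocks above gain a `data_parser=None` keyword because the client now forwards that argument to the storage manager. As a sketch of an alternative that also verifies the forwarding, the standard library's `unittest.mock.AsyncMock` can be used. `fake_async_put` and `identity_parser` are hypothetical names for illustration and are not part of this commit.

```python
import asyncio
from unittest.mock import AsyncMock


async def fake_async_put(storage_manager, data, metadata, data_parser=None):
    """Hypothetical stand-in for a client call path that forwards data_parser."""
    await storage_manager.put_data(data, metadata, data_parser=data_parser)


def identity_parser(field_data):
    """Trivial parser that leaves the batch unchanged."""
    return field_data


# AsyncMock accepts any call signature, so it tolerates the new keyword argument.
storage_manager = AsyncMock()
asyncio.run(fake_async_put(storage_manager, {"f": [1]}, "meta", data_parser=identity_parser))

# Check that the keyword argument actually reached the mocked storage manager.
storage_manager.put_data.assert_awaited_once_with(
    {"f": [1]}, "meta", data_parser=identity_parser
)
```

Hand-written mock coroutines, as in the diff, fail loudly when the signature drifts; an `AsyncMock` does not, so the explicit `assert_awaited_once_with` check is what guards the forwarding here.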

tests/test_simple_storage_unit.py

Lines changed: 179 additions & 2 deletions
@@ -34,11 +34,14 @@ def __init__(self, storage_put_get_address):
         self.socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second timeout
         self.socket.connect(storage_put_get_address)
 
-    def send_put(self, client_id, global_indexes, field_data):
+    def send_put(self, client_id, global_indexes, field_data, data_parser=None):
+        body = {"global_indexes": global_indexes, "data": field_data}
+        if data_parser is not None:
+            body["data_parser"] = data_parser
         msg = ZMQMessage.create(
             request_type=ZMQRequestType.PUT_DATA,
             sender_id=f"mock_client_{client_id}",
-            body={"global_indexes": global_indexes, "data": field_data},
+            body=body,
         )
         self.socket.send_multipart(msg.serialize())
         return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))
@@ -434,3 +437,177 @@ def test_storage_unit_data_capacity_uses_active_keys():
     assert len(storage._active_keys) == 2
     storage.put_data({"f": [4]}, global_indexes=[3])
     assert storage._active_keys == {0, 1, 3}
+
+
+def test_storage_unit_data_parser(storage_setup):
+    """Test data_parser functionality in SimpleStorageUnit.
+
+    Writes two columns:
+    - normal_data: regular tensors, should remain unchanged
+    - data_to_be_parsed: list of shape descriptors (list of ints)
+
+    data_parser converts shape descriptors into random tensors of those shapes.
+    """
+    _, put_get_address = storage_setup
+    client = MockStorageClient(put_get_address)
+
+    def create_data_by_shape_parser(field_data):
+        if "data_to_be_parsed" in field_data:
+            shapes = field_data["data_to_be_parsed"]
+            field_data["data_to_be_parsed"] = [torch.randn(shape) for shape in shapes]
+        return field_data
+
+    # Prepare data: normal_data is a batch tensor, data_to_be_parsed is a list of shape lists
+    field_data = {
+        "normal_data": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
+        "data_to_be_parsed": [[2, 3], [1, 4], [3, 2]],
+    }
+    global_indexes = [0, 1, 2]
+
+    # Put with data_parser
+    response = client.send_put(0, global_indexes, field_data, data_parser=create_data_by_shape_parser)
+    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"Put failed: {response.body}"
+
+    # Get back
+    response = client.send_get(0, global_indexes, ["normal_data", "data_to_be_parsed"])
+    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
+
+    result = response.body["data"]
+
+    # Verify normal_data is unchanged
+    torch.testing.assert_close(result["normal_data"][0], torch.tensor([1.0, 2.0]))
+    torch.testing.assert_close(result["normal_data"][1], torch.tensor([3.0, 4.0]))
+    torch.testing.assert_close(result["normal_data"][2], torch.tensor([5.0, 6.0]))
+
+    # Verify data_to_be_parsed shapes match the input shape descriptors
+    expected_shapes = [(2, 3), (1, 4), (3, 2)]
+    for i, expected_shape in enumerate(expected_shapes):
+        actual_shape = tuple(result["data_to_be_parsed"][i].shape)
+        assert actual_shape == expected_shape, (
+            f"Shape mismatch at index {i}: expected {expected_shape}, got {actual_shape}"
+        )
+
+    client.close()
+
+
+def test_storage_unit_data_parser_callable_types(storage_setup):
+    """Test that various callable types (partial, callable class) work as data_parser."""
+    _, put_get_address = storage_setup
+    client = MockStorageClient(put_get_address)
+
+    from functools import partial
+
+    # 1. Test functools.partial
+    def _partial_parser(field_data, prefix):
+        if "text" in field_data:
+            field_data["text"] = [f"{prefix}{t}" for t in field_data["text"]]
+        return field_data
+
+    partial_parser = partial(_partial_parser, prefix="parsed_")
+
+    response = client.send_put(
+        0,
+        [0, 1],
+        {"text": ["a", "b"]},
+        data_parser=partial_parser,
+    )
+    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"partial parser failed: {response.body}"
+
+    response = client.send_get(0, [0, 1], ["text"])
+    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
+    assert response.body["data"]["text"] == ["parsed_a", "parsed_b"]
+
+    # 2. Test callable class instance
+    class CallableParser:
+        def __call__(self, field_data):
+            if "value" in field_data:
+                field_data["value"] = [v * 2 for v in field_data["value"]]
+            return field_data
+
+    callable_parser = CallableParser()
+    response = client.send_put(
+        0,
+        [2, 3],
+        {"value": [1, 2]},
+        data_parser=callable_parser,
+    )
+    assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"callable class parser failed: {response.body}"
+
+    response = client.send_get(0, [2, 3], ["value"])
+    assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
+    assert response.body["data"]["value"] == [2, 4]
+
+    client.close()
+
+
+def test_storage_unit_data_parser_validation(storage_setup):
+    """Test that invalid data_parser inputs produce clear error messages."""
+    _, put_get_address = storage_setup
+    client = MockStorageClient(put_get_address)
+
+    # 1. Non-callable data_parser should return a clear TypeError
+    response = client.send_put(
+        0,
+        [0],
+        {"data": [1]},
+        data_parser="not_callable",
+    )
+    assert response.request_type == ZMQRequestType.PUT_ERROR
+    assert "data_parser must be callable" in response.body["message"]
+
+    # 2. data_parser returning non-dict should return a clear TypeError
+    def bad_parser(field_data):
+        return "not_a_dict"
+
+    response = client.send_put(
+        0,
+        [1],
+        {"data": [1]},
+        data_parser=bad_parser,
+    )
+    assert response.request_type == ZMQRequestType.PUT_ERROR
+    assert "data_parser must return a dict" in response.body["message"]
+
+    # 3. data_parser deleting a key should return a clear ValueError
+    def delete_key_parser(field_data):
+        del field_data["data"]
+        return field_data
+
+    response = client.send_put(
+        0,
+        [2],
+        {"data": [1], "extra": [2]},
+        data_parser=delete_key_parser,
+    )
+    assert response.request_type == ZMQRequestType.PUT_ERROR
+    assert "data_parser must not change dict keys" in response.body["message"]
+
+    # 4. data_parser adding a key should return a clear ValueError
+    def add_key_parser(field_data):
+        field_data["new_key"] = [999]
+        return field_data
+
+    response = client.send_put(
+        0,
+        [3],
+        {"data": [1]},
+        data_parser=add_key_parser,
+    )
+    assert response.request_type == ZMQRequestType.PUT_ERROR
+    assert "data_parser must not change dict keys" in response.body["message"]
+
+    # 5. data_parser changing element count should return a clear ValueError
+    def wrong_len_parser(field_data):
+        field_data["data"] = field_data["data"][:-1]
+        return field_data
+
+    response = client.send_put(
+        0,
+        [4, 5],
+        {"data": [1, 2]},
+        data_parser=wrong_len_parser,
+    )
+    assert response.request_type == ZMQRequestType.PUT_ERROR
+    assert "data_parser changed the number of elements" in response.body["message"]
+
+    client.close()
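The validation tests above imply a contract: the parser must be callable, must return a dict, must keep the same keys, and must keep the same number of elements per column. A minimal sketch of checks that would satisfy those assertions follows. `apply_data_parser` is a hypothetical helper written for illustration; it is not the `SimpleStorageUnit` implementation from this commit, which is not shown here and which reports failures as `PUT_ERROR` responses rather than raising to the caller.

```python
from typing import Any, Callable, Optional


def apply_data_parser(field_data: dict, data_parser: Optional[Callable[[dict], Any]]) -> dict:
    """Hypothetical helper: run data_parser and enforce the contract the tests describe."""
    if data_parser is None:
        return field_data
    if not callable(data_parser):
        raise TypeError("data_parser must be callable")

    # Snapshot keys and per-column lengths before the parser mutates anything in place.
    original_lengths = {key: len(value) for key, value in field_data.items()}

    parsed = data_parser(field_data)
    if not isinstance(parsed, dict):
        raise TypeError("data_parser must return a dict")
    if set(parsed.keys()) != set(original_lengths):
        raise ValueError("data_parser must not change dict keys")
    for key, value in parsed.items():
        if len(value) != original_lengths[key]:
            raise ValueError(f"data_parser changed the number of elements in '{key}'")
    return parsed


def double(field_data):
    """Well-behaved parser: same keys, same element count."""
    field_data["value"] = [v * 2 for v in field_data["value"]]
    return field_data


print(apply_data_parser({"value": [1, 2]}, double))

try:
    # Misbehaving parser: drops the last element of the column.
    apply_data_parser({"data": [1, 2]}, lambda d: {"data": d["data"][:-1]})
except ValueError as exc:
    print(f"rejected: {exc}")
```

Taking the length snapshot before calling the parser matters because the tests' parsers mutate `field_data` in place; comparing against the dict after the call would hide added or deleted keys.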

transfer_queue/client.py

Lines changed: 28 additions & 3 deletions
@@ -324,6 +324,7 @@ async def async_put(
         data: TensorDict,
         metadata: Optional[BatchMeta] = None,
         partition_id: Optional[str] = None,
+        data_parser: Optional[Callable[[Any], Any]] = None,
     ) -> BatchMeta:
         """Asynchronously write data to storage units based on metadata.
 
@@ -342,6 +343,16 @@ async def async_put(
             metadata: Records the metadata of a batch of data samples, containing index and
                       storage unit information. If None, metadata will be auto-generated.
             partition_id: Target data partition id (required if metadata is not provided)
+            data_parser: Optional callable to parse reference data (e.g., URLs) into real
+                         content. The input is a slice of the `data` parameter, in plain
+                         dict format (not TensorDict), mapping field_name -> batched values.
+                         For a regular tensor column the value is a batched tensor; for
+                         nested tensors (jagged or strided) and NonTensorStack columns
+                         the values are extracted into a list. It must modify values
+                         in-place based on the original keys; do not add or remove keys.
+                         The number of elements per column must also remain unchanged.
+                         Do not change the inner order of values within each column.
+                         Only supported by SimpleStorage.
 
         Returns:
             BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
@@ -411,7 +422,7 @@ async def async_put(
             with limit_pytorch_auto_parallel_threads(
                 target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_put"
             ):
-                await self.storage_manager.put_data(data, metadata)
+                await self.storage_manager.put_data(data, metadata, data_parser=data_parser)
 
             await self.async_set_custom_meta(metadata)
 
@@ -1279,7 +1290,11 @@ def set_custom_meta(self, metadata: BatchMeta) -> None:
         return self._set_custom_meta(metadata=metadata)
 
     def put(
-        self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
+        self,
+        data: TensorDict,
+        metadata: Optional[BatchMeta] = None,
+        partition_id: Optional[str] = None,
+        data_parser: Optional[Callable[[Any], Any]] = None,
     ) -> BatchMeta:
         """Synchronously write data to storage units based on metadata.
 
@@ -1298,6 +1313,16 @@ def put(
             metadata: Records the metadata of a batch of data samples, containing index and
                       storage unit information. If None, metadata will be auto-generated.
             partition_id: Target data partition id (required if metadata is not provided)
+            data_parser: Optional callable to parse reference data (e.g., URLs) into real
+                         content. The input is a slice of the `data` parameter, in plain
+                         dict format (not TensorDict), mapping field_name -> batched values.
+                         For a regular tensor column the value is a batched tensor; for
+                         nested tensors (jagged or strided) and NonTensorStack columns
+                         the values are extracted into a list. It must modify values
+                         in-place based on the original keys; do not add or remove keys.
+                         The number of elements per column must also remain unchanged.
+                         Do not change the inner order of values within each column.
+                         Only supported by SimpleStorage.
 
         Returns:
             BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
@@ -1336,7 +1361,7 @@ def put(
             >>> # This will create metadata in "insert" mode internally.
             >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id)
         """
-        return self._put(data=data, metadata=metadata, partition_id=partition_id)
+        return self._put(data=data, metadata=metadata, partition_id=partition_id, data_parser=data_parser)
 
     def get_data(self, metadata: BatchMeta) -> TensorDict:
         """Synchronously fetch data from storage units and organize into TensorDict.
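The `data_parser` docstring asks the parser to keep the same keys, the same number of elements per column, and the same inner order. The demo in the commit message uses asyncio; for loaders that block (file reads, synchronous HTTP clients), a thread pool is one possible alternative. The sketch below is an assumption-laden illustration: `load_one`, `threaded_parser`, and the `"refs"` column name are hypothetical, and the sleep merely simulates I/O.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def load_one(ref: str) -> bytes:
    """Hypothetical single-sample loader; the sleep simulates blocking I/O."""
    time.sleep(0.05)
    return ref.encode("utf-8")


def threaded_parser(field_data: dict) -> dict:
    """Batch-level parser that keeps keys, element count, and order unchanged."""
    if "refs" not in field_data:
        return field_data
    refs = field_data["refs"]
    with ThreadPoolExecutor(max_workers=8) as pool:
        # Executor.map yields results in input order, regardless of completion order.
        field_data["refs"] = list(pool.map(load_one, refs))
    return field_data


batch = {"refs": [f"file_{i}" for i in range(16)], "labels": list(range(16))}
start = time.perf_counter()
out = threaded_parser(batch)
elapsed = time.perf_counter() - start
print(out["refs"][:2], f"{elapsed:.2f}s")
```

Using `pool.map` rather than collecting futures with `as_completed` is the design choice that preserves the per-column order the docstring requires. The function still blocks until every sample is loaded, so it stays synchronous from the caller's point of view.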
