Commit aec9192
authored
[feat] Support user-defined data parser for
## Background
Users often need to store lightweight **references** (e.g., URLs, file
paths, etc.) rather than the full data into TransferQueue to avoid the
expensive loading and decoding processes happening within user code,
which reduces a data copy.
<img width="1032" height="745" alt="image"
src="https://github.com/user-attachments/assets/4a62b350-b6f1-441d-8c05-b35f8cf2e7de"
/>
## Solution
This PR introduces support for a user-defined data parser in the
`SimpleStorage` backend.
The `kv_put` and `kv_batch_put` methods now accept an optional
`data_parser` callable. This parser is executed **inside** each
`SimpleStorageUnit` at put time. It receives the raw `field_data`
dictionary during the put request and should return a dictionary with
the same structure, replacing reference values with the actual parsed
data.
## Limitations & Future Work
- **Synchronous Execution:** In the current design, the data parser
execution is synchronous and part of the `put` operation. This means the
put request is only completed when the data parser finishes execution.
- **Backend Support:** `data_parser` is currently only supported by the
**SimpleStorage** backend.
- **Incorrect Metadata:** Allowing user-provided functions to modify
data in `SimpleStorageUnit` may lead to incorrect `shape` & `dtype`
metadata, which is captured when the data is still in
`TransferQueueClient`. This can lead to problems for RDMA transport,
which leverages these metadata collected by TQ to restore tensor during
`get`.
## Demo Script
```python3
"""Demo: concurrent data_parser with separated single-sample logic.
This demo shows how to structure a data_parser so that:
1. The **core parser** only handles a **single sample**.
2. The **batch wrapper** uses asyncio to process all samples in parallel.
3. The wrapper is **synchronous to the outside**: it blocks until every
sample finishes, so ``data_parser`` returning means data is ready.
Scenario:
- Users pass URL-like strings in a column.
- The parser sleeps 1 s per sample (simulating I/O / decode) and then
creates a random tensor of the requested dtype & shape.
- Because the sleeps run concurrently via asyncio, a batch of N samples
finishes in ~1 s instead of ~N s.
"""
import asyncio
import time
import ray
import torch
from tensordict import TensorDict, NonTensorStack
import transfer_queue as tq
# ---------------------------------------------------------------------------
# Core single-sample parser
# ---------------------------------------------------------------------------
def parse_url(url: str) -> torch.Tensor:
"""Parse a URL-like descriptor 'dtype:HxW' into a random tensor."""
dtype_str, shape_str = url.split(":")
dtype = getattr(torch, dtype_str)
shape = [int(dim) for dim in shape_str.split("x")]
return torch.randn(shape, dtype=dtype)
# ---------------------------------------------------------------------------
# Batch-level parser
# ---------------------------------------------------------------------------
def concurrent_batch_url_parser(field_data: dict) -> dict:
"""Batch-level data_parser executed inside SimpleStorageUnit.
It receives a ``dict`` (not a TensorDict) where each value is a
batched column. For columns created from ``NonTensorStack`` the
value is a plain ``list`` of Python objects.
Workflow:
1. Spawns one async task per list element.
2. Waits until *all* tasks finish (``asyncio.gather``).
3. Replaces the list with the list of results.
Because ``asyncio.run`` blocks until the loop finishes, this function
is **synchronous** to its caller: when it returns, every sample has
been processed.
Args:
field_data: Mapping ``field_name -> batched_values``. The dict
keys must stay exactly the same; only values may be
transformed in-place.
Returns:
The same dict with parsed values substituted.
"""
if "data_to_be_parsed" not in field_data:
return field_data
urls:list[str] = field_data["data_to_be_parsed"]
async def _async_parse_single(url: str) -> torch.Tensor:
await asyncio.sleep(1.0) # Add fixed delay per sample
return parse_url(url)
async def _process_all():
tasks = [asyncio.create_task(_async_parse_single(url)) for url in urls]
return await asyncio.gather(*tasks)
start = time.perf_counter()
field_data["data_to_be_parsed"] = asyncio.run(_process_all())
elapsed = time.perf_counter() - start
print(
f"[data_parser] Processed {len(urls)} samples in {elapsed:.2f}s "
f"(serial would be ~{len(urls)}.0s)"
)
return field_data
# ---------------------------------------------------------------------------
# Main demo flow
# ---------------------------------------------------------------------------
def main():
ray.init(ignore_reinit_error=True)
try:
tq.init()
batch_size = 32
# Column that stays untouched
normal_data = torch.randn(batch_size, 2)
# Column to be parsed: URL-like strings describing dtype & shape.
shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]
urls = [f"float32:{h}x{w}" for h, w in shapes]
data_to_be_parsed = NonTensorStack(*urls)
data = TensorDict({
"normal_data": normal_data,
"data_to_be_parsed": data_to_be_parsed,
}, batch_size=batch_size)
keys = [f"sample_{i}" for i in range(batch_size)]
# -------------------------------------------------------------------
# Put with data_parser
# -------------------------------------------------------------------
put_start_time = time.perf_counter()
meta = tq.kv_batch_put(
keys=keys,
partition_id="train",
fields=data,
data_parser=concurrent_batch_url_parser,
)
put_elapsed = time.perf_counter() - put_start_time
print(f"Put succeeded. Fields: {meta.fields}")
print(
f"Total kv_batch_put time: {put_elapsed:.2f}s "
f"(concurrency keeps it ~1s, not {batch_size}s)\n"
)
# -------------------------------------------------------------------
# Fetch back and verify
# -------------------------------------------------------------------
result = tq.kv_batch_get(keys=keys, partition_id="train")
# 1) normal_data unchanged
torch.testing.assert_close(result["normal_data"], normal_data)
print("[PASS] normal_data is unchanged.")
# 2) Parsed tensors have correct dtype & shape
expected_shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]
for i, exp_shape in enumerate(expected_shapes):
tensor = result["data_to_be_parsed"][i]
assert tensor.dtype == torch.float32, (
f"dtype mismatch at index {i}: expected torch.float32, got {tensor.dtype}"
)
assert tuple(tensor.shape) == exp_shape, (
f"shape mismatch at index {i}: expected {exp_shape}, got {tuple(tensor.shape)}"
)
print(f"[PASS] All {batch_size} parsed tensors have correct dtype & shape.")
# 3) Timing sanity check
# Serial execution would be ~batch_size seconds.
# Because asyncio tasks run in parallel, it should be ~1 s.
# We allow generous headroom for TQ network / serialization overhead.
assert put_elapsed < 2.0, (
f"Expected concurrent execution (~1s), but took {put_elapsed:.2f}s. "
"Are the asyncio tasks actually running in parallel?"
)
print(f"[PASS] Timing looks concurrent: {put_elapsed:.2f}s < 2.0s")
print("\n=== All verifications passed! ===")
# wait for Ray log collect
time.sleep(2)
except Exception as e:
print(f"Error: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
finally:
tq.close()
ray.shutdown()
if __name__ == "__main__":
main()
```
---------
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>SimpleStorage backend (Ascend#82)1 parent c0c0e19 commit aec9192
9 files changed
Lines changed: 570 additions & 132 deletions
File tree
- tests
- transfer_queue
- storage
- managers
- utils
- tutorial
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
418 | 418 | | |
419 | 419 | | |
420 | 420 | | |
421 | | - | |
| 421 | + | |
422 | 422 | | |
423 | 423 | | |
424 | 424 | | |
| |||
511 | 511 | | |
512 | 512 | | |
513 | 513 | | |
514 | | - | |
| 514 | + | |
515 | 515 | | |
516 | 516 | | |
517 | 517 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
37 | | - | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
38 | 41 | | |
39 | 42 | | |
40 | 43 | | |
41 | | - | |
| 44 | + | |
42 | 45 | | |
43 | 46 | | |
44 | 47 | | |
| |||
434 | 437 | | |
435 | 438 | | |
436 | 439 | | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
| 457 | + | |
| 458 | + | |
| 459 | + | |
| 460 | + | |
| 461 | + | |
| 462 | + | |
| 463 | + | |
| 464 | + | |
| 465 | + | |
| 466 | + | |
| 467 | + | |
| 468 | + | |
| 469 | + | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
| 478 | + | |
| 479 | + | |
| 480 | + | |
| 481 | + | |
| 482 | + | |
| 483 | + | |
| 484 | + | |
| 485 | + | |
| 486 | + | |
| 487 | + | |
| 488 | + | |
| 489 | + | |
| 490 | + | |
| 491 | + | |
| 492 | + | |
| 493 | + | |
| 494 | + | |
| 495 | + | |
| 496 | + | |
| 497 | + | |
| 498 | + | |
| 499 | + | |
| 500 | + | |
| 501 | + | |
| 502 | + | |
| 503 | + | |
| 504 | + | |
| 505 | + | |
| 506 | + | |
| 507 | + | |
| 508 | + | |
| 509 | + | |
| 510 | + | |
| 511 | + | |
| 512 | + | |
| 513 | + | |
| 514 | + | |
| 515 | + | |
| 516 | + | |
| 517 | + | |
| 518 | + | |
| 519 | + | |
| 520 | + | |
| 521 | + | |
| 522 | + | |
| 523 | + | |
| 524 | + | |
| 525 | + | |
| 526 | + | |
| 527 | + | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
| 532 | + | |
| 533 | + | |
| 534 | + | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
| 539 | + | |
| 540 | + | |
| 541 | + | |
| 542 | + | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
| 546 | + | |
| 547 | + | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
| 572 | + | |
| 573 | + | |
| 574 | + | |
| 575 | + | |
| 576 | + | |
| 577 | + | |
| 578 | + | |
| 579 | + | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
| 583 | + | |
| 584 | + | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
| 588 | + | |
| 589 | + | |
| 590 | + | |
| 591 | + | |
| 592 | + | |
| 593 | + | |
| 594 | + | |
| 595 | + | |
| 596 | + | |
| 597 | + | |
| 598 | + | |
| 599 | + | |
| 600 | + | |
| 601 | + | |
| 602 | + | |
| 603 | + | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
| 609 | + | |
| 610 | + | |
| 611 | + | |
| 612 | + | |
| 613 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
324 | 324 | | |
325 | 325 | | |
326 | 326 | | |
| 327 | + | |
327 | 328 | | |
328 | 329 | | |
329 | 330 | | |
| |||
342 | 343 | | |
343 | 344 | | |
344 | 345 | | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
345 | 356 | | |
346 | 357 | | |
347 | 358 | | |
| |||
411 | 422 | | |
412 | 423 | | |
413 | 424 | | |
414 | | - | |
| 425 | + | |
415 | 426 | | |
416 | 427 | | |
417 | 428 | | |
| |||
1279 | 1290 | | |
1280 | 1291 | | |
1281 | 1292 | | |
1282 | | - | |
| 1293 | + | |
| 1294 | + | |
| 1295 | + | |
| 1296 | + | |
| 1297 | + | |
1283 | 1298 | | |
1284 | 1299 | | |
1285 | 1300 | | |
| |||
1298 | 1313 | | |
1299 | 1314 | | |
1300 | 1315 | | |
| 1316 | + | |
| 1317 | + | |
| 1318 | + | |
| 1319 | + | |
| 1320 | + | |
| 1321 | + | |
| 1322 | + | |
| 1323 | + | |
| 1324 | + | |
| 1325 | + | |
1301 | 1326 | | |
1302 | 1327 | | |
1303 | 1328 | | |
| |||
1336 | 1361 | | |
1337 | 1362 | | |
1338 | 1363 | | |
1339 | | - | |
| 1364 | + | |
1340 | 1365 | | |
1341 | 1366 | | |
1342 | 1367 | | |
| |||
0 commit comments