Skip to content

Commit 4df106b

Browse files
authored
[feat] Provide user-defined custom_meta methods (Ascend#21)
## Background When integrating with upstream RL frameworks, it is often necessary to maintain some sample-level metadata alongside the batch data for computation. Typical examples include: - uid - session_id - trajectory_id - start_model_version/end_model_version - current_status ... This PR introduces a flexible **User-Defined Metadata** mechanism. Users can now attach arbitrary key-value pairs (`custom_meta`) to `BatchMeta` and synchronize them with the `TransferQueueController`. <img width="1774" height="465" alt="image" src="https://github.com/user-attachments/assets/c6a3971d-23d4-427f-b732-972a03d4324c" /> > Note: Different from `custom_meta` that maintains **sample-level** info, the `extra_info` stores **batch-level** info and will not be stored into the controller. ## Key Changes - **`BatchMeta` Enhancement**: Added interfaces to `BatchMeta` for setting and retrieving user-defined `custom_meta`. - Put `custom_meta`: - Explicit Sync: Metadata can be sent to the controller via `TransferQueueClient.set_custom_meta`. - Implicit Sync: Metadata is automatically stored when calling `TransferQueueClient.put` with given batch_meta input. - Get `custom_meta`: Users can query the stored `custom_meta` from the controller using `get_meta` with `mode="force_fetch"`. We refactor the previous meta for storage backend as `_custom_backend_meta`, which is a field-level information. ## Related API ### BatchMeta - `set_custom_meta(global_index:int, value:Any)` Insert or update a single key-value pair for a specific index. - `update_custom_meta(new_meta:dict[int, Any])` Batch update `custom_meta` using a dictionary. - `get_all_custom_meta() -> dict[int, Any]` Retrieve all stored `custom_meta`. ### TransferQueueClient - `(async_)set_custom_meta(metadata: BatchMeta)` Explicitly send `custom_meta` to the `TransferQueueController` - `(async_)put(data:TensorDict, metadata: BatchMeta)` Implicitly send `custom_meta` to the controller while putting data. - `(async_)get_meta(partition_id:str, mode="force_fetch")` Query and retrieve all `BatchMeta` alone with `custom_meta` stored in the controller (requires `mode="force_fetch"`). ## Usage Example Refer to `tutorial/02_metadata_concepts.py`. ## Other Changes 1. Add CI for `tutorial` scripts. 2. Improve docstring for `TransferQueueClient` ## TODO - [ ] Move `custom_meta` -> `SampleMeta`, `_custom_backend_meta` -> `FieldMeta` --- CC: @wuxibin89 @tianyi-ge --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent cd12c1e commit 4df106b

18 files changed

Lines changed: 1493 additions & 403 deletions
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3+
4+
name: Tutorial check
5+
6+
on:
7+
push:
8+
branches:
9+
- main
10+
- v0.*
11+
pull_request:
12+
branches:
13+
- main
14+
- v0.*
15+
16+
jobs:
17+
build:
18+
runs-on: ubuntu-latest
19+
timeout-minutes: 10
20+
strategy:
21+
fail-fast: false
22+
matrix:
23+
python-version: ["3.11"]
24+
25+
steps:
26+
- uses: actions/checkout@v4
27+
- name: Set up Python ${{ matrix.python-version }}
28+
uses: actions/setup-python@v3
29+
with:
30+
python-version: ${{ matrix.python-version }}
31+
- name: Install dependencies
32+
run: |
33+
python -m pip install --upgrade pip
34+
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
35+
pip install -e ".[yuanrong]"
36+
- name: Run tutorials
37+
run: |
38+
export TQ_NUM_THREADS=2
39+
for file in tutorial/*.py; do python3 "$file"; done

tests/test_client.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ def _handle_requests(self):
134134
"partition_ids": ["partition_0", "partition_1", "test_partition"],
135135
}
136136
response_type = ZMQRequestType.LIST_PARTITIONS_RESPONSE
137+
elif request_msg.request_type == ZMQRequestType.SET_CUSTOM_META:
138+
response_body = {"message": "success"}
139+
response_type = ZMQRequestType.SET_CUSTOM_META_RESPONSE
137140
else:
138141
response_body = {"error": f"Unknown request type: {request_msg.request_type}"}
139142
response_type = ZMQRequestType.CLEAR_META_RESPONSE
@@ -774,3 +777,51 @@ async def test_sync_and_async_methods_mixed_usage(client_setup):
774777
assert async_data is not None
775778

776779
print("✓ Mixed async and sync method calls work correctly")
780+
781+
782+
# =====================================================
783+
# Custom Meta Interface Tests
784+
# =====================================================
785+
786+
787+
class TestClientCustomMetaInterface:
788+
"""Tests for client custom_meta interface methods."""
789+
790+
def test_set_custom_meta_sync(self, client_setup):
791+
"""Test synchronous set_custom_meta method."""
792+
client, _, _ = client_setup
793+
794+
# Test synchronous set_custom_meta
795+
796+
# First get metadata
797+
metadata = client.get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
798+
# Set custom_meta on the metadata
799+
metadata.update_custom_meta(
800+
{
801+
0: {"input_ids": {"token_count": 100}},
802+
1: {"input_ids": {"token_count": 120}},
803+
}
804+
)
805+
806+
# Call set_custom_meta with metadata (BatchMeta)
807+
client.set_custom_meta(metadata)
808+
print("✓ set_custom_meta sync method works")
809+
810+
@pytest.mark.asyncio
811+
async def test_set_custom_meta_async(self, client_setup):
812+
"""Test asynchronous async_set_custom_meta method."""
813+
client, _, _ = client_setup
814+
815+
# First get metadata
816+
metadata = await client.async_get_meta(data_fields=["input_ids"], batch_size=2, partition_id="0")
817+
# Set custom_meta on the metadata
818+
metadata.update_custom_meta(
819+
{
820+
0: {"input_ids": {"token_count": 100}},
821+
1: {"input_ids": {"token_count": 120}},
822+
}
823+
)
824+
825+
# Call async_set_custom_meta with metadata (BatchMeta)
826+
await client.async_set_custom_meta(metadata)
827+
print("✓ async_set_custom_meta async method works")

tests/test_controller.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_controller_with_single_partition(self, ray_setup):
8989
field_names=metadata.field_names,
9090
dtypes=dtypes,
9191
shapes=shapes,
92-
custom_meta=None,
92+
custom_backend_meta=None,
9393
)
9494
)
9595
assert success
@@ -450,3 +450,156 @@ def test_controller_clear_meta(self, ray_setup):
450450
assert set(partition_after.global_indexes) == set([4, 5, 7])
451451

452452
print("✓ Clear meta correct")
453+
454+
455+
class TestTransferQueueControllerCustomMeta:
456+
"""Integration tests for TransferQueueController custom_meta and custom_backend_meta methods.
457+
458+
Note: In this codebase:
459+
- custom_meta: per-sample metadata (simple key-value pairs per sample)
460+
- custom_backend_meta: per-sample per-field metadata (stored via update_production_status)
461+
"""
462+
463+
def test_controller_with_custom_meta(self, ray_setup):
464+
"""Test TransferQueueController with custom_backend_meta and custom_meta functionality"""
465+
466+
batch_size = 3
467+
partition_id = "custom_meta_test"
468+
469+
tq_controller = TransferQueueController.remote()
470+
471+
# Create metadata in insert mode
472+
data_fields = ["prompt_ids", "attention_mask"]
473+
metadata = ray.get(
474+
tq_controller.get_metadata.remote(
475+
data_fields=data_fields,
476+
batch_size=batch_size,
477+
partition_id=partition_id,
478+
mode="insert",
479+
)
480+
)
481+
482+
assert metadata.global_indexes == list(range(batch_size))
483+
484+
# Build custom_backend_meta (per-sample per-field metadata)
485+
custom_backend_meta = {
486+
0: {"prompt_ids": {"token_count": 100}, "attention_mask": {"mask_ratio": 0.1}},
487+
1: {"prompt_ids": {"token_count": 120}, "attention_mask": {"mask_ratio": 0.15}},
488+
2: {"prompt_ids": {"token_count": 90}, "attention_mask": {"mask_ratio": 0.12}},
489+
}
490+
491+
# Update production status with custom_backend_meta
492+
dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes}
493+
shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes}
494+
success = ray.get(
495+
tq_controller.update_production_status.remote(
496+
partition_id=partition_id,
497+
global_indexes=metadata.global_indexes,
498+
field_names=metadata.field_names,
499+
dtypes=dtypes,
500+
shapes=shapes,
501+
custom_backend_meta=custom_backend_meta,
502+
)
503+
)
504+
assert success
505+
506+
# Get partition snapshot and verify custom_backend_meta is stored
507+
partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
508+
assert partition is not None
509+
510+
# Verify custom_backend_meta via get_field_custom_backend_meta
511+
result = partition.get_field_custom_backend_meta(list(range(batch_size)), ["prompt_ids", "attention_mask"])
512+
assert len(result) == batch_size
513+
assert result[0]["prompt_ids"]["token_count"] == 100
514+
assert result[2]["attention_mask"]["mask_ratio"] == 0.12
515+
516+
print("✓ Controller set custom_backend_meta via update_production_status correct")
517+
518+
# Now set custom_meta (per-sample metadata)
519+
# Format: {partition_id: {global_index: custom_meta_dict}}
520+
custom_meta = {
521+
partition_id: {
522+
0: {"sample_score": 0.9, "quality": "high"},
523+
1: {"sample_score": 0.8, "quality": "medium"},
524+
# You can set partial samples with custom_meta.
525+
}
526+
}
527+
528+
# Verify set_custom_meta method exists and can be called
529+
ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=custom_meta))
530+
531+
# Verify via partition snapshot
532+
partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
533+
result = partition.get_custom_meta([0, 1])
534+
assert 0 in result
535+
assert result[0]["sample_score"] == 0.9
536+
assert result[0]["quality"] == "high"
537+
assert 1 in result
538+
assert result[1]["sample_score"] == 0.8
539+
assert 2 not in result
540+
541+
# Init another partition
542+
new_partition_id = "custom_meta_test2"
543+
# Create metadata in insert mode
544+
data_fields = ["prompt_ids", "attention_mask"]
545+
new_metadata = ray.get(
546+
tq_controller.get_metadata.remote(
547+
data_fields=data_fields,
548+
batch_size=batch_size,
549+
partition_id=new_partition_id,
550+
mode="insert",
551+
)
552+
)
553+
554+
# Update production status
555+
dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in new_metadata.global_indexes}
556+
shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in new_metadata.global_indexes}
557+
success = ray.get(
558+
tq_controller.update_production_status.remote(
559+
partition_id=new_partition_id,
560+
global_indexes=new_metadata.global_indexes,
561+
field_names=new_metadata.field_names,
562+
dtypes=dtypes,
563+
shapes=shapes,
564+
custom_backend_meta=None,
565+
)
566+
)
567+
assert success
568+
569+
# Provide complicated case: update custom_meta with mixed partitions, and update previous custom_meta
570+
new_custom_meta = {
571+
new_partition_id: {
572+
3: {"sample_score": 1, "quality": "high"},
573+
4: {"sample_score": 0, "quality": "low"},
574+
},
575+
partition_id: {
576+
2: {"sample_score": 0.7, "quality": "high"},
577+
0: {"sample_score": 0.001, "quality": "low"},
578+
},
579+
}
580+
581+
# update with new_custom_meta
582+
ray.get(tq_controller.set_custom_meta.remote(partition_custom_meta=new_custom_meta))
583+
584+
# Verify via partition snapshot
585+
partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id))
586+
result = partition.get_custom_meta([0, 1, 2])
587+
assert 0 in result
588+
assert result[0]["sample_score"] == 0.001 # updated!
589+
assert result[0]["quality"] == "low" # updated!
590+
assert 1 in result # unchanged
591+
assert result[1]["sample_score"] == 0.8 # unchanged
592+
assert 2 in result # unchanged
593+
assert result[2]["sample_score"] == 0.7 # new
594+
595+
new_partition = ray.get(tq_controller.get_partition_snapshot.remote(new_partition_id))
596+
result = new_partition.get_custom_meta([3, 4, 5])
597+
assert 3 in result
598+
assert result[3]["sample_score"] == 1
599+
assert result[3]["quality"] == "high"
600+
assert 4 in result
601+
assert result[4]["sample_score"] == 0
602+
assert 5 not in result # 5 has no custom_meta, it will not return even we retrieve for 5
603+
604+
# Clean up
605+
ray.get(tq_controller.clear_partition.remote(partition_id))

0 commit comments

Comments
 (0)