
Commit 0945d28

mpb159753 authored and ascend-robot committed
[fix,refactor] Complete columnar metadata refactor for manager→controller path
Co-authored-by: 看我72遍 <m.pb@msn.com>

# message auto-generated for no-merge-commit merge:
!29 merge refactor/columnar-field-schema into main

[fix,refactor] Complete columnar metadata refactor for manager→controller path

Created-by: mpb159753
Commit-by: 看我72遍
Merged-by: ascend-robot

Description:

# Columnar FieldSchema + Unified Controller Metadata

## 1. Context & Motivation

Follows: [Ascend#28: Columnar BatchMeta + Zero-Copy Default](https://gitcode.com/Ascend/TransferQueue/pull/28)

PR Ascend#39 converted `BatchMeta` from row-oriented to columnar layout, but two O(B×F) bottlenecks remained on the **Manager → Controller** path:

1. **`notify_data_update` payload**: The Manager expanded columnar `field_schema` back into per-sample dicts (`dtypes: {global_index: {field: dtype}}`, `shapes: {global_index: {field: shape}}`), transmitting O(B×F) data over ZMQ for information that is inherently O(F).
2. **Controller metadata storage**: `DataPartitionStatus` maintained three separate stores (`field_dtypes`, `field_shapes`, `field_schema_cache`) with redundant per-sample indexing, requiring multi-pass reconciliation logic to detect nested tensors.

This PR completes the columnar refactoring by:

- Transmitting `field_schema` directly as O(F) columnar data (no per-sample expansion)
- Introducing `FieldColumnMeta` as the **single source of truth** for per-field metadata in the Controller
- Adding `RoutingGroup` to carry batch positions alongside global indexes, eliminating intermediate mapping
- Extracting `_pack_field_values` as a reusable static method with defensive checks

## 2. Key Changes

### 2.1 Columnar `notify_data_update` Protocol (`base.py`, `simple_backend_manager.py`)

**Before** (O(B×F) expansion in Manager):

```python
dtypes_for_notify = {
    global_index: {field_name: field_meta.get("dtype") for field_name, field_meta in field_schema.items()}
    for global_index in metadata.global_indexes
}
shapes_for_notify = { ... }  # same pattern
await self.notify_data_update(partition_id, field_names, global_indexes, dtypes_for_notify, shapes_for_notify)
```

**After** (O(F), passed through as-is):

```python
await self.notify_data_update(partition_id, global_indexes, field_schema)
```

- Removed `fields`, `dtypes`, `shapes` parameters
- `field_schema` is already columnar from `metadata.py`, so no expansion is needed
- KV path (`base.py`) similarly simplified, removing a 25-line per-sample expansion loop

### 2.2 `FieldColumnMeta` Dataclass (`controller.py`)

Replaces three separate stores (`field_dtypes`, `field_shapes`, `field_schema_cache`) with a single `@dataclass`:

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class FieldColumnMeta:
    dtype: Any = None
    shape: Optional[tuple] = None
    is_nested: bool = False
    is_non_tensor: bool = False
    per_sample_shapes: dict[int, tuple] = field(default_factory=dict)
```

- Field-level attributes are O(1): they are shared across all samples
- Sample-level shapes are stored only for nested tensors: O(B_nested), not O(B)
- `to_batch_schema()` generates `BatchMeta`-compatible dicts on demand
- `remove_samples()` cleans up released indexes

### 2.3 `RoutingGroup` NamedTuple (`simple_backend_manager.py`)

```python
class RoutingGroup(NamedTuple):
    global_indexes: list[int]
    batch_positions: list[int]
```

- `_group_by_hash` now returns `dict[str, RoutingGroup]` instead of `dict[str, list[int]]`
- Carries both global indexes and batch positions, eliminating the intermediate `global_idx → position` mapping in `get_data`
- GET merge logic simplified: scatter results directly to batch positions without building per-sample dicts

### 2.4 `_pack_field_values` Extraction (`simple_backend_manager.py`)

Extracted inline packing logic into a reusable `@staticmethod` with explicit error handling:

- Validates non-empty input and absence of `None` values
- Handles regular tensors (`torch.stack`), nested tensors (`torch.nested.as_nested_tensor`), and non-tensors (`NonTensorStack`)

### 2.5 Simplified Controller API

- `update_production_status`: Removed `field_names` and `dtypes`/`shapes` parameters; `field_names` derived from `field_schema.keys()`
- `get_field_schema`: Delegates to `FieldColumnMeta.to_batch_schema()` instead of building from cache
- Removed `get_field_dtype` and `get_field_shape` helper methods (no longer needed)

### 2.6 Test Suite

- All test files updated to match new `notify_data_update` and `update_production_status` signatures
- `test_controller_data_partitions.py`: Tests adapted for `FieldColumnMeta`-based schema storage

## 3. Benchmark Results

Tests conducted in Docker (single-node Ray) across 7 payload sizes (0.05 MB → 25.4 GB). Three configurations compared:

- **pre-refactor**: Baseline (row-oriented, before PR Ascend#39)
- **columnar-batch-meta**: After PR Ascend#39 (columnar BatchMeta + zero-copy)
- **columnar-field-schema**: This PR (columnar notify + FieldColumnMeta + RoutingGroup)

### Speedup (relative to pre-refactor baseline)

![image.png](https://raw.gitcode.com/user-images/assets/8886051/4c49b557-9d15-4298-9d5e-bd06e8ea05a6/image.png 'image.png')
![image.png](https://raw.gitcode.com/user-images/assets/8886051/8992bfb1-e5fc-4f06-9585-f72906c53863/image.png 'image.png')

| Data Scale | PUT Speedup (vs baseline) | PUT Speedup (vs PR Ascend#39) | GET Speedup (vs baseline) | GET Speedup (vs PR Ascend#39) |
|------------|:------------------------:|:-----------------------:|:------------------------:|:-----------------------:|
| debug (0.05 MB) | **1.4×** | +12% | **1.5×** | +16% |
| tiny (1.5 MB) | **1.8×** | +19% | **2.1×** | +13% |
| small (0.15 GB) | **5.1×** | +20% | **3.4×** | ≈0% |
| medium (1.5 GB) | **5.8×** | +7% | **2.2×** | −1% |
| large (6.3 GB) | **5.6×** | +8% | **2.0×** | −4% |
| xlarge (12.7 GB) | **5.5×** | +8% | **2.2×** | +1% |
| huge (25.4 GB) | **5.4×** | +6% | **2.2×** | +1% |

### Absolute Bandwidth

![image.png](https://raw.gitcode.com/user-images/assets/8886051/05b789cc-f4aa-4a5a-833b-55617cd3a673/image.png 'image.png')
![image.png](https://raw.gitcode.com/user-images/assets/8886051/e2f927cd-5556-46af-bf7b-71e451752c11/image.png 'image.png')

| Data Scale | Pre-Refactor | Columnar BatchMeta (PR Ascend#39) | Columnar FieldSchema (This PR) |
|------------|:-----------:|:---------------------------:|:------------------------------:|
| **PUT** medium | 3.95 Gbps | 21.29 Gbps | **22.84 Gbps** |
| **PUT** large | 5.04 Gbps | 26.14 Gbps | **28.18 Gbps** |
| **PUT** huge | 5.09 Gbps | 26.05 Gbps | **27.49 Gbps** |
| **GET** medium | 4.24 Gbps | 9.50 Gbps | **9.39 Gbps** |
| **GET** large | 4.98 Gbps | 10.51 Gbps | **10.14 Gbps** |
| **GET** huge | 4.86 Gbps | 10.46 Gbps | **10.53 Gbps** |

### Summary

- **PUT path** benefits most: +6% to +20% over PR Ascend#39 across all scales, and a consistent 5×+ improvement over the pre-refactor baseline at medium and larger scales
- **GET path** maintains parity with PR Ascend#39: differences are within the noise margin, because the GET bottleneck is in ZMQ transport, not metadata
- Small payloads see the largest relative improvement, confirming the metadata overhead reduction

### Resource Usage

Memory usage is comparable or slightly reduced (eliminated per-sample `field_dtypes`/`field_shapes` dicts in Controller).

## 4. API Breaking Changes

- `notify_data_update()`: Removed `fields`, `dtypes`, `shapes` parameters; replaced with single `field_schema` dict
- `update_production_status()`: Removed `field_names`, `dtypes`, `shapes` parameters; replaced with single `field_schema` dict; `field_names` derived from `field_schema.keys()`
- `get_field_dtype()` / `get_field_shape()`: Removed (replaced by `FieldColumnMeta`)
- `_group_by_hash()`: Now returns `dict[str, RoutingGroup]` instead of `dict[str, list[int]]`

## 5. Files Changed

```
7 files changed, 451 insertions(+), 440 deletions(-)
```

| File | Description |
|------|-------------|
| `controller.py` | `FieldColumnMeta` dataclass; simplified `update_production_status` / `get_field_schema`; removed `get_field_dtype`/`get_field_shape` |
| `simple_backend_manager.py` | `RoutingGroup`; `_pack_field_values`; position-based GET merge; columnar `notify_data_update` |
| `base.py` | Columnar `notify_data_update` protocol; simplified KV path |
| `test_controller.py` | Adapted to new API signatures |
| `test_controller_data_partitions.py` | Adapted to `FieldColumnMeta`-based schema |
| `test_async_simple_storage_manager.py` | Adapted to `RoutingGroup` and new notify protocol |
| `test_kv_storage_manager.py` | Minor signature update |

## 6. Conclusion

This PR completes the second phase of columnar refactoring by eliminating the remaining O(B×F) metadata expansion in the Manager→Controller path and unifying metadata storage in the Controller:

- **PUT throughput**: Up to 5.8× over pre-refactor baseline, +6–20% over PR Ascend#39
- **GET throughput**: Up to 3.4× over pre-refactor baseline, parity with PR Ascend#39
- **Code clarity**: Three separate metadata stores → one `FieldColumnMeta` dataclass; per-sample expansion loops eliminated
- **Net change**: +451 / −440 lines across 7 files

> **Note on GET path**: The GET path performance improvement from metadata-level refactoring has reached diminishing returns; the minor fluctuations (±1–4%) observed in benchmarks are within normal measurement noise. Further GET throughput gains would likely require a deeper architectural change: fully columnarizing the GET data flow itself (e.g., columnar storage layout in StorageUnit, field-level parallel retrieval), rather than continuing to optimize the metadata layer.

See merge request: Ascend/TransferQueue!29
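The storage change described in section 2.2 can be illustrated with a small sketch. The dataclass attributes are taken from the description above; the bodies of `to_batch_schema()` and `remove_samples()`, the helper `expand_per_sample`, and the example field names are assumptions made for illustration, since the commit message names these methods without showing them.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class FieldColumnMeta:
    # Attribute names follow the dataclass quoted in section 2.2.
    dtype: Any = None
    shape: Optional[tuple] = None
    is_nested: bool = False
    is_non_tensor: bool = False
    per_sample_shapes: dict[int, tuple] = field(default_factory=dict)

    def to_batch_schema(self, global_indexes: list[int]) -> dict[str, Any]:
        # Assumed output layout; the real method body is not shown in the commit.
        schema: dict[str, Any] = {
            "dtype": self.dtype,
            "shape": self.shape,
            "is_nested": self.is_nested,
            "is_non_tensor": self.is_non_tensor,
        }
        if self.is_nested:
            # Per-sample shapes are materialized only for nested fields.
            schema["per_sample_shapes"] = [self.per_sample_shapes.get(gi) for gi in global_indexes]
        return schema

    def remove_samples(self, global_indexes: list[int]) -> None:
        # Drop per-sample entries for released indexes; missing keys are ignored.
        for gi in global_indexes:
            self.per_sample_shapes.pop(gi, None)


def expand_per_sample(field_schema: dict[str, dict], global_indexes: list[int]) -> dict[int, dict]:
    # Hypothetical stand-in for the old per-sample expansion, used only to count entries.
    return {
        gi: {name: meta.get("dtype") for name, meta in field_schema.items()}
        for gi in global_indexes
    }


field_schema = {
    "tokens": {"dtype": "int64", "shape": (128,)},
    "reward": {"dtype": "float32", "shape": ()},
}
global_indexes = list(range(1000))

expanded = expand_per_sample(field_schema, global_indexes)
old_entries = sum(len(per_sample) for per_sample in expanded.values())  # B × F entries
new_entries = len(field_schema)                                         # F entries
print(old_entries, new_entries)  # 2000 2

nested = FieldColumnMeta(dtype="float32", is_nested=True,
                         per_sample_shapes={0: (3,), 1: (5,), 2: (2,)})
nested.remove_samples([1])
print(nested.to_batch_schema([0, 2])["per_sample_shapes"])  # [(3,), (2,)]
```

With B = 1000 samples and F = 2 fields, the expanded form carries 2000 dtype entries while the columnar form carries 2, which is the O(B×F) versus O(F) difference the description refers to.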
1 parent 6ad4d07 commit 0945d28

9 files changed

Lines changed: 688 additions & 497 deletions

tests/e2e/test_e2e_lifecycle_consistency.py

Lines changed: 123 additions & 1 deletion
@@ -35,12 +35,14 @@
     "tensor_f32",
     "tensor_i64",
     "tensor_bf16",
+    "tensor_f16",
     "nested_jagged",
     "nested_strided",
     "list_int",
     "list_str",
     "list_obj",
     "np_array",
+    "np_bytes_str",
     "np_obj",
     "special_val",
     "non_tensor_stack",
@@ -112,6 +114,7 @@ def generate_complex_data(indices: list[int]) -> TensorDict:
 
     # NumPy Arrays
     np_array = np.array([np.arange(i, i + 3) for i in indices], dtype=np.float64)
+    np_bytes_str = np.array([f"bs_{i}".encode() for i in indices], dtype="|S10")
     np_obj = np.array([f"obj_{i}" for i in indices], dtype=object)
 
     # Special Values (NaN and Inf)
@@ -127,19 +130,24 @@ def generate_complex_data(indices: list[int]) -> TensorDict:
     # BFloat16 Tensor
     tensor_bf16 = torch.stack([torch.arange(i, i + 5, dtype=torch.bfloat16) for i in indices])
 
+    # Float16 Tensor
+    tensor_f16 = torch.stack([torch.arange(i, i + 5, dtype=torch.float16) for i in indices])
+
     # List of objects (dicts)
     list_obj = [{"key": f"value_{i}", "num": i} for i in indices]
 
     field_values = {
         "tensor_f32": tensor_f32,
         "tensor_i64": tensor_i64,
         "tensor_bf16": tensor_bf16,
+        "tensor_f16": tensor_f16,
         "nested_jagged": nested_jagged,
         "nested_strided": nested_strided,
         "list_int": list_int,
         "list_str": list_str,
         "list_obj": list_obj,
         "np_array": np_array,
+        "np_bytes_str": np_bytes_str,
         "np_obj": np_obj,
         "special_val": special_val,
         "non_tensor_stack": non_tensor_stack,
@@ -300,6 +308,7 @@ def test_core_consistency(e2e_client):
     assert torch.allclose(retrieved_data["tensor_f32"], original_data["tensor_f32"]), "tensor_f32 mismatch"
     assert torch.equal(retrieved_data["tensor_i64"], original_data["tensor_i64"]), "tensor_i64 mismatch"
     assert torch.equal(retrieved_data["tensor_bf16"], original_data["tensor_bf16"]), "tensor_bf16 mismatch"
+    assert torch.equal(retrieved_data["tensor_f16"], original_data["tensor_f16"]), "tensor_f16 mismatch"
 
     # 4. Verify Nested Tensors (Jagged)
     assert verify_nested_tensor_equal(retrieved_data["nested_jagged"], original_data["nested_jagged"]), (
@@ -318,6 +327,16 @@ def test_core_consistency(e2e_client):
 
     # 7. Verify NumPy Arrays
     assert np.allclose(retrieved_data["np_array"], original_data["np_array"]), "np_array mismatch"
+
+    # np_bytes_str: bytes string numpy via CUSTOM_TYPE_NUMPY path
+    retrieved_bs = retrieved_data["np_bytes_str"]
+    if hasattr(retrieved_bs, "tolist"):
+        retrieved_bs = retrieved_bs.tolist()
+    expected_bs = original_data["np_bytes_str"]
+    if hasattr(expected_bs, "tolist") and not isinstance(expected_bs, np.ndarray):
+        expected_bs = expected_bs.tolist()
+    assert list(retrieved_bs) == list(expected_bs), "np_bytes_str mismatch"
+
     # np_obj may be returned as NonTensorStack; normalize to list before comparing
     retrieved_np_obj = retrieved_data["np_obj"]
     if hasattr(retrieved_np_obj, "tolist"):
@@ -430,7 +449,12 @@ def test_cross_shard_complex_update(e2e_client):
         i for i, global_index in enumerate(full_meta.global_indexes) if global_index in update_gis
     ]
     update_meta_with_backend = full_meta.select_samples(update_positions_in_full)
-    extended_meta = update_meta_with_backend.with_data_fields(
+    # Populate empty schema for fields not yet in field_schema so select_fields can include them
+    for f in ["new_extra_tensor", "new_extra_non_tensor"]:
+        if f not in update_meta_with_backend.field_schema:
+            update_meta_with_backend.field_schema[f] = {}
+    update_meta_with_backend._field_names = sorted(update_meta_with_backend.field_schema.keys())
+    extended_meta = update_meta_with_backend.select_fields(
         base_fields + ["new_extra_tensor", "new_extra_non_tensor"]
     )
     update_region_data = client.get_data(extended_meta)
@@ -702,5 +726,103 @@ def test_dynamic_tensor_shape_nested_transition(e2e_client):
         client.clear_partition(partition_id)
 
 
+# Scenario Seven: Retrieved Data Writability and Memory Safety
+def test_retrieved_data_writability_and_memory_safety(e2e_client):
+    """Verify that all data types retrieved via GET are writable and memory-independent.
+
+    This test validates the ZMQ copy=False GET path (Plan 1):
+    - Tensors (f32, i64, bf16, f16): writable after torch.stack detaches from frame
+    - Nested tensors (jagged, strided): writable after as_nested_tensor
+    - Numpy arrays (float64, bytes string): writable after .copy() in _pack_field_values
+    - Modifications to retrieved data do not affect stored data (memory independence)
+    """
+    client = e2e_client
+    partition_id = "test_writability"
+    batch_size = 8
+    task_name = "writability_task"
+    fields = DEFAULT_FIELDS
+
+    indices = list(range(batch_size))
+    original_data = generate_complex_data(indices)
+    client.put(data=original_data, partition_id=partition_id)
+
+    try:
+        # === Phase 1: Retrieve and verify writability ===
+        meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
+        assert meta is not None and meta.size == batch_size
+        retrieved = client.get_data(meta)
+
+        # 1. tensor_f32: writable
+        retrieved["tensor_f32"][0, 0] = 99999.0
+        assert retrieved["tensor_f32"][0, 0].item() == 99999.0, "tensor_f32 should be writable"
+
+        # 2. tensor_i64: writable
+        retrieved["tensor_i64"][0, 0] = 88888
+        assert retrieved["tensor_i64"][0, 0].item() == 88888, "tensor_i64 should be writable"
+
+        # 3. tensor_bf16: writable
+        retrieved["tensor_bf16"][0, 0] = 77.0
+        assert retrieved["tensor_bf16"][0, 0].item() == 77.0, "tensor_bf16 should be writable"
+
+        # 4. tensor_f16: writable
+        retrieved["tensor_f16"][0, 0] = 66.0
+        assert retrieved["tensor_f16"][0, 0].item() == 66.0, "tensor_f16 should be writable"
+
+        # 5. nested_jagged: writable via values()
+        jagged_vals = retrieved["nested_jagged"].values()
+        jagged_vals[0] = 55555.0
+        assert jagged_vals[0].item() == 55555.0, "nested_jagged should be writable"
+
+        # 6. nested_strided: writable via unbind
+        strided_subs = list(retrieved["nested_strided"].unbind())
+        strided_subs[0][0, 0] = 44444.0
+        assert strided_subs[0][0, 0].item() == 44444.0, "nested_strided should be writable"
+
+        # 7. special_val (tensor with NaN/Inf): writable
+        retrieved["special_val"][0, 2] = 33333.0
+        assert retrieved["special_val"][0, 2].item() == 33333.0, "special_val should be writable"
+
+        # 8. np_array: verify it's a tensor now (TensorDict auto-converts numeric numpy)
+        # If it's a tensor, writability is guaranteed by torch.stack
+        np_arr_retrieved = retrieved["np_array"]
+        if isinstance(np_arr_retrieved, torch.Tensor):
+            np_arr_retrieved[0, 0] = 22222.0
+            assert np_arr_retrieved[0, 0].item() == 22222.0, "np_array (as tensor) should be writable"
+
+        # === Phase 2: Verify memory independence ===
+        # Re-retrieve the same data — modifications above should NOT have affected storage
+        meta2 = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
+        assert meta2 is not None and meta2.size == batch_size
+        retrieved2 = client.get_data(meta2)
+
+        # tensor_f32[0,0] should be the original value, not 99999.0
+        assert torch.allclose(retrieved2["tensor_f32"], original_data["tensor_f32"]), (
+            "Modifying retrieved tensor_f32 should not affect stored data"
+        )
+
+        # tensor_i64[0,0] should be the original value, not 88888
+        assert torch.equal(retrieved2["tensor_i64"], original_data["tensor_i64"]), (
+            "Modifying retrieved tensor_i64 should not affect stored data"
+        )
+
+        # tensor_bf16 should match original
+        assert torch.equal(retrieved2["tensor_bf16"], original_data["tensor_bf16"]), (
+            "Modifying retrieved tensor_bf16 should not affect stored data"
+        )
+
+        # tensor_f16 should match original
+        assert torch.equal(retrieved2["tensor_f16"], original_data["tensor_f16"]), (
+            "Modifying retrieved tensor_f16 should not affect stored data"
+        )
+
+        # nested_jagged should match original
+        assert verify_nested_tensor_equal(retrieved2["nested_jagged"], original_data["nested_jagged"]), (
+            "Modifying retrieved nested_jagged should not affect stored data"
+        )
+
+    finally:
+        client.clear_partition(partition_id)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))
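The writability test above relies on retrieved arrays owning their memory rather than aliasing a received frame. A minimal NumPy-only sketch of that idea follows, assuming an immutable `bytes` object as a stand-in for a zero-copy ZMQ frame; the variable names are illustrative and this is not the code path used in `_pack_field_values`.

```python
import numpy as np

# Immutable bytes stand in for a frame received with copy=False (an assumption).
frame = np.arange(6, dtype=np.float64).tobytes()

# Zero-copy view over the frame; it is read-only because bytes are immutable.
view = np.frombuffer(frame, dtype=np.float64)
print(view.flags.writeable)  # False

# An explicit copy owns its memory, so it is writable and independent of the frame.
owned = view.copy()
owned[0] = 99999.0
print(owned.flags.writeable)  # True
print(view[0])                # 0.0

# Writing through the view is rejected by NumPy.
try:
    view[0] = 1.0
    write_rejected = False
except ValueError:
    write_rejected = True
print(write_rejected)  # True
```

The copy costs one extra pass over the data, which is the trade-off the test docstring accepts in exchange for writable, memory-independent results.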

tests/test_async_simple_storage_manager.py

Lines changed: 77 additions & 9 deletions
@@ -22,7 +22,7 @@
 import pytest_asyncio
 import torch
 import zmq
-from tensordict import TensorDict
+from tensordict import NonTensorStack, TensorDict
 
 # Setup path
 parent_dir = Path(__file__).resolve().parent.parent
@@ -380,26 +380,32 @@ async def test_hash_routing_stable_across_batch_sizes():
 
     # Build per-index mapping from the full-batch result
     idx_to_su_full: dict[int, str] = {}
-    for su_id, gi_list in full_routing.items():
-        for gi in gi_list:
+    for su_id, group in full_routing.items():
+        for gi in group.global_indexes:
             idx_to_su_full[gi] = su_id
 
     # Route as two batches of 5
     batch_a_routing = manager._group_by_hash(all_indexes[:5])
     batch_b_routing = manager._group_by_hash(all_indexes[5:])
 
     idx_to_su_split: dict[int, str] = {}
-    for su_id, gi_list in batch_a_routing.items():
-        for gi in gi_list:
+    for su_id, group in batch_a_routing.items():
+        for gi in group.global_indexes:
             idx_to_su_split[gi] = su_id
-    for su_id, gi_list in batch_b_routing.items():
-        for gi in gi_list:
+    for su_id, group in batch_b_routing.items():
+        for gi in group.global_indexes:
             idx_to_su_split[gi] = su_id
 
     assert idx_to_su_full == idx_to_su_split, (
         f"Routing differs between full batch and split batches:\n full: {idx_to_su_full}\n split: {idx_to_su_split}"
     )
 
+    # Verify RoutingGroup carries correct batch_positions alongside global_indexes
+    for su_id, group in full_routing.items():
+        assert len(group.global_indexes) == len(group.batch_positions)
+        for gi, pos in zip(group.global_indexes, group.batch_positions, strict=False):
+            assert all_indexes[pos] == gi
+
 
 @pytest.mark.asyncio
 async def test_hash_routing_stable_reversed_order():
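The invariant checked in this hunk (`all_indexes[pos] == gi` for every routed sample) can be sketched in isolation. The `RoutingGroup` definition matches the one in the commit description; the SHA-256 based routing, the function names `group_by_hash` and `scatter`, and the storage unit IDs are assumptions for illustration, since the body of the real `_group_by_hash` is not part of this diff.

```python
import hashlib
from typing import Any, NamedTuple


class RoutingGroup(NamedTuple):
    global_indexes: list[int]
    batch_positions: list[int]


def group_by_hash(global_indexes: list[int], storage_unit_ids: list[str]) -> dict[str, RoutingGroup]:
    # Hypothetical routing rule: hash each global index on its own, so the target
    # storage unit does not depend on batch size or ordering.
    groups: dict[str, RoutingGroup] = {}
    for pos, gi in enumerate(global_indexes):
        digest = hashlib.sha256(str(gi).encode()).digest()
        su_id = storage_unit_ids[int.from_bytes(digest[:8], "big") % len(storage_unit_ids)]
        group = groups.setdefault(su_id, RoutingGroup([], []))
        group.global_indexes.append(gi)
        group.batch_positions.append(pos)
    return groups


def scatter(groups: dict[str, RoutingGroup], fetched: dict[str, list[Any]], batch_size: int) -> list[Any]:
    # Place each storage unit's results straight into their batch positions,
    # with no intermediate global_index -> position mapping.
    out: list[Any] = [None] * batch_size
    for su_id, group in groups.items():
        for pos, value in zip(group.batch_positions, fetched[su_id]):
            out[pos] = value
    return out


all_indexes = [7, 3, 11, 5, 2]
groups = group_by_hash(all_indexes, ["su0", "su1"])

# Simulate each storage unit returning values in the order of its global indexes.
fetched = {su_id: [f"v{gi}" for gi in g.global_indexes] for su_id, g in groups.items()}
merged = scatter(groups, fetched, len(all_indexes))
print(merged)  # ['v7', 'v3', 'v11', 'v5', 'v2']
```

Because each group already records where its samples sit in the batch, the merge step is a single pass over the fetched values.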
@@ -439,9 +445,71 @@ async def test_hash_routing_stable_reversed_order():
     # Build per-index mapping
     def _to_idx_map(routing):
         m = {}
-        for su_id, gi_list in routing.items():
-            for gi in gi_list:
+        for su_id, group in routing.items():
+            for gi in group.global_indexes:
                 m[gi] = su_id
         return m
 
     assert _to_idx_map(routing_fwd) == _to_idx_map(routing_rev), "Hash routing should be order-independent"
+
+
+class TestSelectByPositions:
+    """Test _select_by_positions static method for all field types."""
+
+    def test_regular_tensor(self):
+        t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
+        result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2])
+        assert torch.equal(result, torch.tensor([[1.0, 2.0], [5.0, 6.0]]))
+
+    def test_nested_tensor(self):
+        t = torch.nested.as_nested_tensor(
+            [torch.tensor([1.0]), torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])],
+            layout=torch.jagged,
+        )
+        result = AsyncSimpleStorageManager._select_by_positions(t, [0, 2])
+        assert isinstance(result, list)
+        assert len(result) == 2
+        assert torch.equal(result[0], torch.tensor([1.0]))
+        assert torch.equal(result[1], torch.tensor([4.0, 5.0, 6.0]))
+
+    def test_non_tensor_stack(self):
+        nts = NonTensorStack("a", "b", "c")
+        result = AsyncSimpleStorageManager._select_by_positions(nts, [1, 2])
+        assert isinstance(result, NonTensorStack)
+        assert result.tolist() == ["b", "c"]
+
+    def test_list(self):
+        data = [{"x": 1}, {"x": 2}, {"x": 3}]
+        result = AsyncSimpleStorageManager._select_by_positions(data, [0, 2])
+        assert result == [{"x": 1}, {"x": 3}]
+
+    def test_numpy_array(self):
+        arr = np.array([10, 20, 30])
+        result = AsyncSimpleStorageManager._select_by_positions(arr, [0, 2])
+        np.testing.assert_array_equal(result, np.array([10, 30]))
+
+
+class TestPackFieldValues:
+    """Test _pack_field_values static method packing logic."""
+
+    def test_uniform_tensors_to_stack(self):
+        """Same-shape tensors → torch.stack."""
+        values = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
+        result = AsyncSimpleStorageManager._pack_field_values(values)
+        assert isinstance(result, torch.Tensor)
+        assert not result.is_nested
+        assert result.shape == (2, 2)
+
+    def test_variable_length_tensors_to_nested(self):
+        """Different-shape tensors → nested tensor."""
+        values = [torch.tensor([1.0]), torch.tensor([2.0, 3.0])]
+        result = AsyncSimpleStorageManager._pack_field_values(values)
+        assert isinstance(result, torch.Tensor)
+        assert result.is_nested
+
+    def test_non_tensors_to_nontensorstack(self):
+        """Non-tensor values → NonTensorStack."""
+        values = ["hello", "world"]
+        result = AsyncSimpleStorageManager._pack_field_values(values)
+        assert isinstance(result, NonTensorStack)
+        assert result.tolist() == ["hello", "world"]
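The three cases exercised by `TestPackFieldValues` suggest a dispatch along the following lines. This is a hedged sketch, not the body of `_pack_field_values`, which this diff does not show: the function name `pack_field_values` is hypothetical, the nested tensor uses the default layout rather than whichever layout the real method picks, and non-tensor values are returned as a plain list as a stand-in for `NonTensorStack` so the example depends only on `torch`.

```python
import torch


def pack_field_values(values: list):
    # Defensive checks mirroring the ones named in the commit description.
    if not values:
        raise ValueError("values must be non-empty")
    if any(v is None for v in values):
        raise ValueError("values must not contain None")

    if all(isinstance(v, torch.Tensor) for v in values):
        first_shape = values[0].shape
        if all(v.shape == first_shape for v in values):
            # Uniform shapes: stack into one regular tensor (this allocates new memory).
            return torch.stack(values)
        # Ragged shapes: fall back to a nested tensor.
        return torch.nested.as_nested_tensor(values)

    # Stand-in for NonTensorStack (assumption made to avoid a tensordict dependency).
    return list(values)


stacked = pack_field_values([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])])
nested = pack_field_values([torch.tensor([1.0]), torch.tensor([2.0, 3.0])])
plain = pack_field_values(["hello", "world"])

print(tuple(stacked.shape), stacked.is_nested)  # (2, 2) False
print(nested.is_nested)                         # True
print(plain)                                    # ['hello', 'world']
```

Checking for empty input and `None` entries before dispatching keeps the failure message close to its cause instead of surfacing later as an opaque `torch.stack` error.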
