Skip to content

Commit d147a33

Browse files
authored
[BREAKING][fix] Use jagged tensor as default tensor type (#92)
## Background Previously, TransferQueue would try `torch.stack()` first when merging per-sample tensors into a batched tensordict for user retrieval. As a result, tensors with uniform shapes were returned as regular dense tensors, while jagged data fell back to nested tensors. This inconsistency forced downstream code to handle two distinct data types (torch.Tensor vs. nested tensor), adding unnecessary branching logic. ## Changes This PR changes the default aggregation strategy so that all tensor fields are returned as nested tensors by default, eliminating the `torch.stack()` fast-path. Specifically: 1. `KVStorageManager._merge_tensors_to_tensordict`: Removed the `torch.stack(chunk)` fallback. The new chain is as_nested_tensor(jagged) → nested_tensor(strided) → NonTensorStack. 2. `AsyncSimpleStorageManager._pack_field_values`: Removed the `torch.stack(values)` fast-path for uniform-shape tensors. The new chain is as_nested_tensor(jagged) → as_nested_tensor(strided) → NonTensorStack, consistent with the KV backend. 3. Unified strided fallback: Added the missing strided layout fallback to `KVStorageManager`, ensuring both backends behave identically when jagged layout fails (e.g., for zero-dim tensors). 4. Docstring & comment cleanup: Updated all outdated docstrings and comments that referenced the old `torch.stack`-first behavior. ## Test updates - Adapted test_async_simple_storage_manager.py, test_kv_storage_manager.py, and e2e tests to accept nested tensors as the default return type. - Reworked the test_kv_storage_manager.py fixture to use realistic variable-length fields (input_ids, prompt_ids, response_ids, response_mask) aligned with the single_controller_demo.py schema, replacing the oversimplified text/label/mask example. - Replaced all torch.equal(dense, nested) assertions with safe per-component comparisons (unbind(0) + torch.equal) to accommodate the new nested-tensor contract. --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 05eb2aa commit d147a33

11 files changed

Lines changed: 669 additions & 444 deletions

tests/e2e/test_e2e_lifecycle_consistency.py

Lines changed: 81 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,18 @@ def poll_for_meta(client, partition_id, data_fields, batch_size, task_name, mode
234234
# Helper Functions for Data Verification
235235
def verify_special_values(retrieved: torch.Tensor, expected: torch.Tensor) -> bool:
236236
"""Verify special values (NaN, Inf) are preserved."""
237-
# Check Inf column
238-
if not torch.all(torch.isinf(retrieved[:, 0]) & (retrieved[:, 0] > 0)):
239-
return False
240-
# Check NaN column
241-
if not torch.all(torch.isnan(retrieved[:, 1])):
242-
return False
243-
# Check regular values column
244-
if not torch.allclose(retrieved[:, 2], expected[:, 2]):
237+
if len(retrieved) != len(expected):
245238
return False
239+
for r, e in zip(retrieved, expected, strict=True):
240+
# Check Inf column
241+
if not (torch.isinf(r[0]) and r[0] > 0):
242+
return False
243+
# Check NaN column
244+
if not torch.isnan(r[1]):
245+
return False
246+
# Check regular values column
247+
if not torch.allclose(r[2], e[2]):
248+
return False
246249
return True
247250

248251

@@ -293,11 +296,17 @@ def verify_list_equal(retrieved, expected) -> bool:
293296
if isinstance(retrieved, NonTensorStack):
294297
retrieved = retrieved.tolist()
295298
elif isinstance(retrieved, torch.Tensor):
296-
retrieved = retrieved.reshape(-1).tolist() # may get 2D tensor back using key-value based backend
299+
if retrieved.is_nested:
300+
retrieved = [t.item() for t in retrieved]
301+
else:
302+
retrieved = retrieved.reshape(-1).tolist() # may get 2D tensor back using key-value based backend
297303
if isinstance(expected, NonTensorStack):
298304
expected = expected.tolist()
299305
elif isinstance(expected, torch.Tensor):
300-
expected = expected.tolist()
306+
if expected.is_nested:
307+
expected = [t.item() for t in expected]
308+
else:
309+
expected = expected.tolist()
301310
return retrieved == expected
302311

303312

@@ -317,14 +326,10 @@ def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict:
317326
items = field.tolist()
318327
reordered_items = [items[i] for i in order]
319328
reordered[key] = NonTensorStack(*reordered_items, batch_size=[len(order)])
320-
elif hasattr(field, "unbind"):
321-
items = field.unbind(0)
329+
elif isinstance(field, torch.Tensor) and field.is_nested:
330+
items = list(field)
322331
reordered_items = [items[i] for i in order]
323-
try:
324-
reordered[key] = torch.stack(reordered_items)
325-
except (RuntimeError, TypeError):
326-
# RuntimeError: shape mismatch (jagged); TypeError: non-Tensor items
327-
reordered[key] = torch.nested.as_nested_tensor(reordered_items, layout=field.layout)
332+
reordered[key] = torch.nested.as_nested_tensor(reordered_items, layout=field.layout)
328333
elif isinstance(field, list):
329334
reordered[key] = [field[i] for i in order]
330335
else:
@@ -365,11 +370,20 @@ def test_core_consistency(e2e_client):
365370
assert retrieved_meta is not None and retrieved_meta.size == batch_size, "Failed to retrieve metadata"
366371
retrieved_data = client.get_data(retrieved_meta)
367372

368-
# 3. Verify Standard Tensors
369-
assert torch.allclose(retrieved_data["tensor_f32"], original_data["tensor_f32"]), "tensor_f32 mismatch"
370-
assert torch.equal(retrieved_data["tensor_i64"], original_data["tensor_i64"]), "tensor_i64 mismatch"
371-
assert torch.equal(retrieved_data["tensor_bf16"], original_data["tensor_bf16"]), "tensor_bf16 mismatch"
372-
assert torch.equal(retrieved_data["tensor_f16"], original_data["tensor_f16"]), "tensor_f16 mismatch"
373+
# 3. Verify Standard Tensors (may be returned as nested tensors)
374+
for i in range(batch_size):
375+
assert torch.allclose(retrieved_data["tensor_f32"][i], original_data["tensor_f32"][i]), (
376+
f"tensor_f32 mismatch at index {i}"
377+
)
378+
assert torch.equal(retrieved_data["tensor_i64"][i], original_data["tensor_i64"][i]), (
379+
f"tensor_i64 mismatch at index {i}"
380+
)
381+
assert torch.equal(retrieved_data["tensor_bf16"][i], original_data["tensor_bf16"][i]), (
382+
f"tensor_bf16 mismatch at index {i}"
383+
)
384+
assert torch.equal(retrieved_data["tensor_f16"][i], original_data["tensor_f16"][i]), (
385+
f"tensor_f16 mismatch at index {i}"
386+
)
373387

374388
# 4. Verify Nested Tensors (Jagged)
375389
assert verify_nested_tensor_equal(retrieved_data["nested_jagged"], original_data["nested_jagged"]), (
@@ -386,12 +400,17 @@ def test_core_consistency(e2e_client):
386400
assert verify_list_equal(retrieved_data["list_str"], original_data["list_str"]), "list_str mismatch"
387401
assert verify_list_equal(retrieved_data["list_obj"], original_data["list_obj"]), "list_obj mismatch"
388402

389-
# 7. Verify NumPy Arrays
390-
assert np.allclose(retrieved_data["np_array"], original_data["np_array"]), "np_array mismatch"
403+
# 7. Verify NumPy Arrays (may be returned as nested tensors)
404+
for i in range(batch_size):
405+
assert np.allclose(retrieved_data["np_array"][i].numpy(), original_data["np_array"][i]), (
406+
f"np_array mismatch at index {i}"
407+
)
391408

392409
# np_bytes_str: bytes string numpy via CUSTOM_TYPE_NUMPY path
393410
retrieved_bs = retrieved_data["np_bytes_str"]
394-
if hasattr(retrieved_bs, "tolist"):
411+
if isinstance(retrieved_bs, torch.Tensor) and retrieved_bs.is_nested:
412+
retrieved_bs = [t.item() for t in retrieved_bs]
413+
elif hasattr(retrieved_bs, "tolist"):
395414
retrieved_bs = retrieved_bs.tolist()
396415
expected_bs = original_data["np_bytes_str"]
397416
if hasattr(expected_bs, "tolist") and not isinstance(expected_bs, np.ndarray):
@@ -400,7 +419,9 @@ def test_core_consistency(e2e_client):
400419

401420
# np_obj may be returned as NonTensorStack; normalize to list before comparing
402421
retrieved_np_obj = retrieved_data["np_obj"]
403-
if hasattr(retrieved_np_obj, "tolist"):
422+
if isinstance(retrieved_np_obj, torch.Tensor) and retrieved_np_obj.is_nested:
423+
retrieved_np_obj = [t.item() for t in retrieved_np_obj]
424+
elif hasattr(retrieved_np_obj, "tolist"):
404425
retrieved_np_obj = retrieved_np_obj.tolist()
405426
expected_np_obj = original_data["np_obj"]
406427
if hasattr(expected_np_obj, "tolist") and not isinstance(expected_np_obj, np.ndarray):
@@ -490,21 +511,24 @@ def test_cross_shard_complex_update(e2e_client):
490511

491512
# 6. Verify region 0-9: original Put A values
492513
original_data_0_9 = generate_complex_data(list(range(0, 10)))
493-
assert torch.allclose(full_data["tensor_f32"][:10], original_data_0_9["tensor_f32"]), (
494-
"Region 0-9 tensor_f32 should match original Put A"
495-
)
514+
for i in range(10):
515+
assert torch.allclose(full_data["tensor_f32"][i], original_data_0_9["tensor_f32"][i]), (
516+
f"Region 0-9 tensor_f32 mismatch at index {i}"
517+
)
496518

497519
# 7. Verify region 10-29: updated values (using offset indices 1010-1029)
498520
updated_expected = generate_complex_data([i + 1000 for i in range(10, 30)])
499-
assert torch.allclose(full_data["tensor_f32"][10:30], updated_expected["tensor_f32"]), (
500-
"Region 10-29 tensor_f32 should match updated values"
501-
)
521+
for i in range(20):
522+
assert torch.allclose(full_data["tensor_f32"][10 + i], updated_expected["tensor_f32"][i]), (
523+
f"Region 10-29 tensor_f32 mismatch at index {10 + i}"
524+
)
502525

503526
# 8. Verify region 30-39: original Put B values
504527
original_data_30_39 = generate_complex_data(list(range(30, 40)))
505-
assert torch.allclose(full_data["tensor_f32"][30:40], original_data_30_39["tensor_f32"]), (
506-
"Region 30-39 tensor_f32 should match original Put B"
507-
)
528+
for i in range(10):
529+
assert torch.allclose(full_data["tensor_f32"][30 + i], original_data_30_39["tensor_f32"][i]), (
530+
f"Region 30-39 tensor_f32 mismatch at index {30 + i}"
531+
)
508532

509533
# 9. Verify new fields exist in update region (indices 10-29 only have new fields).
510534
# Build extended_meta from full_meta (which has valid _custom_backend_meta)
@@ -760,12 +784,13 @@ def test_dynamic_tensor_shape_nested_transition(e2e_client):
760784
meta1_put = client.put(data=data1, partition_id=partition_id)
761785
assert meta1_put.size == 2
762786

763-
# Poll and verify first batch is regular tensor
787+
# Poll and verify first batch (now returned as nested tensor by default)
764788
meta1 = poll_for_meta(client, partition_id, ["dynamic_feature"], 2, task_name, mode="force_fetch")
765789
assert not meta1.field_schema["dynamic_feature"]["is_nested"]
766790
retrieved_1 = client.get_data(meta1)
767-
assert not retrieved_1["dynamic_feature"].is_nested
768-
assert retrieved_1["dynamic_feature"].shape == (2, 4)
791+
assert retrieved_1["dynamic_feature"].is_nested
792+
assert len(retrieved_1["dynamic_feature"]) == 2
793+
assert retrieved_1["dynamic_feature"][0].shape == (4,)
769794

770795
# 2. Allocate 2 more slots via insert mode, put different-shape tensor (shape: (2, 6))
771796
alloc_meta2 = client.get_meta(
@@ -802,7 +827,7 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
802827
"""Verify that all data types retrieved via GET are writable and memory-independent.
803828
804829
This test validates the ZMQ copy=False GET path (Plan 1):
805-
- Tensors (f32, i64, bf16, f16): writable after torch.stack detaches from frame
830+
- Tensors (f32, i64, bf16, f16): writable after nested tensor creation
806831
- Nested tensors (jagged, strided): writable after as_nested_tensor
807832
- Numpy arrays (float64, bytes string): writable after .copy() in _pack_field_values
808833
- Modifications to retrieved data do not affect stored data (memory independence)
@@ -861,7 +886,7 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
861886
assert retrieved["special_val"][0, 2].item() == 33333.0, "special_val should be writable"
862887

863888
# 8. np_array: verify it's a tensor now (TensorDict auto-converts numeric numpy)
864-
# If it's a tensor, writability is guaranteed by torch.stack
889+
# If it's a tensor, writability is guaranteed by nested tensor creation
865890
np_arr_retrieved = retrieved["np_array"]
866891
if isinstance(np_arr_retrieved, torch.Tensor):
867892
np_arr_retrieved[0, 0] = 22222.0
@@ -880,24 +905,28 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
880905
retrieved2 = client.get_data(meta2)
881906

882907
# tensor_f32[0,0] should be the original value, not 99999.0
883-
assert torch.allclose(retrieved2["tensor_f32"], original_data["tensor_f32"]), (
884-
"Modifying retrieved tensor_f32 should not affect stored data"
885-
)
908+
for i in range(batch_size):
909+
assert torch.allclose(retrieved2["tensor_f32"][i], original_data["tensor_f32"][i]), (
910+
"Modifying retrieved tensor_f32 should not affect stored data"
911+
)
886912

887913
# tensor_i64[0,0] should be the original value, not 88888
888-
assert torch.equal(retrieved2["tensor_i64"], original_data["tensor_i64"]), (
889-
"Modifying retrieved tensor_i64 should not affect stored data"
890-
)
914+
for i in range(batch_size):
915+
assert torch.equal(retrieved2["tensor_i64"][i], original_data["tensor_i64"][i]), (
916+
"Modifying retrieved tensor_i64 should not affect stored data"
917+
)
891918

892919
# tensor_bf16 should match original
893-
assert torch.equal(retrieved2["tensor_bf16"], original_data["tensor_bf16"]), (
894-
"Modifying retrieved tensor_bf16 should not affect stored data"
895-
)
920+
for i in range(batch_size):
921+
assert torch.equal(retrieved2["tensor_bf16"][i], original_data["tensor_bf16"][i]), (
922+
"Modifying retrieved tensor_bf16 should not affect stored data"
923+
)
896924

897925
# tensor_f16 should match original
898-
assert torch.equal(retrieved2["tensor_f16"], original_data["tensor_f16"]), (
899-
"Modifying retrieved tensor_f16 should not affect stored data"
900-
)
926+
for i in range(batch_size):
927+
assert torch.equal(retrieved2["tensor_f16"][i], original_data["tensor_f16"][i]), (
928+
"Modifying retrieved tensor_f16 should not affect stored data"
929+
)
901930

902931
# nested_jagged should match original
903932
assert verify_nested_tensor_equal(retrieved2["nested_jagged"], original_data["nested_jagged"]), (

tests/e2e/test_kv_interface_e2e.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,31 @@ def get_controller_partition(controller, partition_id: str):
176176

177177

178178
def assert_tensor_equal(tensor_a, tensor_b, msg=""):
179-
"""Assert two tensors are equal."""
180-
assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}"
179+
"""Assert two tensors are equal, handling nested vs dense comparisons."""
180+
if (isinstance(tensor_a, torch.Tensor) and tensor_a.is_nested) or (
181+
isinstance(tensor_b, torch.Tensor) and tensor_b.is_nested
182+
):
183+
seq_a = list(tensor_a)
184+
seq_b = list(tensor_b)
185+
assert len(seq_a) == len(seq_b), f"{msg} Length mismatch: {len(seq_a)} vs {len(seq_b)}"
186+
for t1, t2 in zip(seq_a, seq_b, strict=True):
187+
assert torch.equal(t1, t2), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}"
188+
else:
189+
assert torch.equal(tensor_a, tensor_b), f"{msg} Tensors are not equal: {tensor_a} vs {tensor_b}"
181190

182191

183192
def assert_tensor_close(tensor_a, tensor_b, rtol=1e-5, atol=1e-8, msg=""):
184-
"""Assert two tensors are close."""
185-
assert torch.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol), f"{msg} Tensors are not close"
193+
"""Assert two tensors are close, handling nested vs dense comparisons."""
194+
if (isinstance(tensor_a, torch.Tensor) and tensor_a.is_nested) or (
195+
isinstance(tensor_b, torch.Tensor) and tensor_b.is_nested
196+
):
197+
seq_a = list(tensor_a)
198+
seq_b = list(tensor_b)
199+
assert len(seq_a) == len(seq_b), f"{msg} Length mismatch: {len(seq_a)} vs {len(seq_b)}"
200+
for t1, t2 in zip(seq_a, seq_b, strict=True):
201+
assert torch.allclose(t1, t2, rtol=rtol, atol=atol), f"{msg} Tensors are not close"
202+
else:
203+
assert torch.allclose(tensor_a, tensor_b, rtol=rtol, atol=atol), f"{msg} Tensors are not close"
186204

187205

188206
def assert_nested_tensor_equal(nested_a, nested_b, msg=""):

tests/test_async_simple_storage_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -537,13 +537,12 @@ def test_regular_tensor_negative_stride_rejected(self):
537537
class TestPackFieldValues:
538538
"""Test _pack_field_values static method packing logic."""
539539

540-
def test_uniform_tensors_to_stack(self):
541-
"""Same-shape tensors → torch.stack."""
540+
def test_uniform_tensors_to_nested(self):
541+
"""Same-shape tensors → nested tensor (default)."""
542542
values = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
543543
result = AsyncSimpleStorageManager._pack_field_values(values) # type: ignore[attr-defined]
544544
assert isinstance(result, torch.Tensor)
545-
assert not result.is_nested
546-
assert result.shape == (2, 2)
545+
assert result.is_nested
547546

548547
def test_variable_length_tensors_to_nested(self):
549548
"""Different-shape tensors → nested tensor."""
@@ -560,7 +559,7 @@ def test_non_tensors_to_nontensorstack(self):
560559
assert result.tolist() == ["hello", "world"]
561560

562561
def test_mixed_tensors_and_none_to_nontensorstack(self):
563-
"""Mixed tensor + None values should stay as NonTensorStack (no stacking)."""
562+
"""Mixed tensor + None values should stay as NonTensorStack (no nested tensor)."""
564563
t0 = torch.tensor([1.0, 2.0])
565564
t2 = torch.tensor([3.0, 4.0])
566565
values = [t0, None, t2]

0 commit comments

Comments
 (0)