@@ -234,15 +234,18 @@ def poll_for_meta(client, partition_id, data_fields, batch_size, task_name, mode
234234# Helper Functions for Data Verification
235235def verify_special_values (retrieved : torch .Tensor , expected : torch .Tensor ) -> bool :
236236 """Verify special values (NaN, Inf) are preserved."""
237- # Check Inf column
238- if not torch .all (torch .isinf (retrieved [:, 0 ]) & (retrieved [:, 0 ] > 0 )):
239- return False
240- # Check NaN column
241- if not torch .all (torch .isnan (retrieved [:, 1 ])):
242- return False
243- # Check regular values column
244- if not torch .allclose (retrieved [:, 2 ], expected [:, 2 ]):
237+ if len (retrieved ) != len (expected ):
245238 return False
239+ for r , e in zip (retrieved , expected , strict = True ):
240+ # Check Inf column
241+ if not (torch .isinf (r [0 ]) and r [0 ] > 0 ):
242+ return False
243+ # Check NaN column
244+ if not torch .isnan (r [1 ]):
245+ return False
246+ # Check regular values column
247+ if not torch .allclose (r [2 ], e [2 ]):
248+ return False
246249 return True
247250
248251
@@ -293,11 +296,17 @@ def verify_list_equal(retrieved, expected) -> bool:
293296 if isinstance (retrieved , NonTensorStack ):
294297 retrieved = retrieved .tolist ()
295298 elif isinstance (retrieved , torch .Tensor ):
296- retrieved = retrieved .reshape (- 1 ).tolist () # may get 2D tensor back using key-value based backend
299+ if retrieved .is_nested :
300+ retrieved = [t .item () for t in retrieved ]
301+ else :
302+ retrieved = retrieved .reshape (- 1 ).tolist () # may get 2D tensor back using key-value based backend
297303 if isinstance (expected , NonTensorStack ):
298304 expected = expected .tolist ()
299305 elif isinstance (expected , torch .Tensor ):
300- expected = expected .tolist ()
306+ if expected .is_nested :
307+ expected = [t .item () for t in expected ]
308+ else :
309+ expected = expected .tolist ()
301310 return retrieved == expected
302311
303312
@@ -317,14 +326,10 @@ def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict:
317326 items = field .tolist ()
318327 reordered_items = [items [i ] for i in order ]
319328 reordered [key ] = NonTensorStack (* reordered_items , batch_size = [len (order )])
320- elif hasattr (field , "unbind" ) :
321- items = field . unbind ( 0 )
329+ elif isinstance (field , torch . Tensor ) and field . is_nested :
330+ items = list ( field )
322331 reordered_items = [items [i ] for i in order ]
323- try :
324- reordered [key ] = torch .stack (reordered_items )
325- except (RuntimeError , TypeError ):
326- # RuntimeError: shape mismatch (jagged); TypeError: non-Tensor items
327- reordered [key ] = torch .nested .as_nested_tensor (reordered_items , layout = field .layout )
332+ reordered [key ] = torch .nested .as_nested_tensor (reordered_items , layout = field .layout )
328333 elif isinstance (field , list ):
329334 reordered [key ] = [field [i ] for i in order ]
330335 else :
@@ -365,11 +370,20 @@ def test_core_consistency(e2e_client):
365370 assert retrieved_meta is not None and retrieved_meta .size == batch_size , "Failed to retrieve metadata"
366371 retrieved_data = client .get_data (retrieved_meta )
367372
368- # 3. Verify Standard Tensors
369- assert torch .allclose (retrieved_data ["tensor_f32" ], original_data ["tensor_f32" ]), "tensor_f32 mismatch"
370- assert torch .equal (retrieved_data ["tensor_i64" ], original_data ["tensor_i64" ]), "tensor_i64 mismatch"
371- assert torch .equal (retrieved_data ["tensor_bf16" ], original_data ["tensor_bf16" ]), "tensor_bf16 mismatch"
372- assert torch .equal (retrieved_data ["tensor_f16" ], original_data ["tensor_f16" ]), "tensor_f16 mismatch"
373+ # 3. Verify Standard Tensors (may be returned as nested tensors)
374+ for i in range (batch_size ):
375+ assert torch .allclose (retrieved_data ["tensor_f32" ][i ], original_data ["tensor_f32" ][i ]), (
376+ f"tensor_f32 mismatch at index { i } "
377+ )
378+ assert torch .equal (retrieved_data ["tensor_i64" ][i ], original_data ["tensor_i64" ][i ]), (
379+ f"tensor_i64 mismatch at index { i } "
380+ )
381+ assert torch .equal (retrieved_data ["tensor_bf16" ][i ], original_data ["tensor_bf16" ][i ]), (
382+ f"tensor_bf16 mismatch at index { i } "
383+ )
384+ assert torch .equal (retrieved_data ["tensor_f16" ][i ], original_data ["tensor_f16" ][i ]), (
385+ f"tensor_f16 mismatch at index { i } "
386+ )
373387
374388 # 4. Verify Nested Tensors (Jagged)
375389 assert verify_nested_tensor_equal (retrieved_data ["nested_jagged" ], original_data ["nested_jagged" ]), (
@@ -386,12 +400,17 @@ def test_core_consistency(e2e_client):
386400 assert verify_list_equal (retrieved_data ["list_str" ], original_data ["list_str" ]), "list_str mismatch"
387401 assert verify_list_equal (retrieved_data ["list_obj" ], original_data ["list_obj" ]), "list_obj mismatch"
388402
389- # 7. Verify NumPy Arrays
390- assert np .allclose (retrieved_data ["np_array" ], original_data ["np_array" ]), "np_array mismatch"
403+ # 7. Verify NumPy Arrays (may be returned as nested tensors)
404+ for i in range (batch_size ):
405+ assert np .allclose (retrieved_data ["np_array" ][i ].numpy (), original_data ["np_array" ][i ]), (
406+ f"np_array mismatch at index { i } "
407+ )
391408
392409 # np_bytes_str: bytes string numpy via CUSTOM_TYPE_NUMPY path
393410 retrieved_bs = retrieved_data ["np_bytes_str" ]
394- if hasattr (retrieved_bs , "tolist" ):
411+ if isinstance (retrieved_bs , torch .Tensor ) and retrieved_bs .is_nested :
412+ retrieved_bs = [t .item () for t in retrieved_bs ]
413+ elif hasattr (retrieved_bs , "tolist" ):
395414 retrieved_bs = retrieved_bs .tolist ()
396415 expected_bs = original_data ["np_bytes_str" ]
397416 if hasattr (expected_bs , "tolist" ) and not isinstance (expected_bs , np .ndarray ):
@@ -400,7 +419,9 @@ def test_core_consistency(e2e_client):
400419
401420 # np_obj may be returned as NonTensorStack; normalize to list before comparing
402421 retrieved_np_obj = retrieved_data ["np_obj" ]
403- if hasattr (retrieved_np_obj , "tolist" ):
422+ if isinstance (retrieved_np_obj , torch .Tensor ) and retrieved_np_obj .is_nested :
423+ retrieved_np_obj = [t .item () for t in retrieved_np_obj ]
424+ elif hasattr (retrieved_np_obj , "tolist" ):
404425 retrieved_np_obj = retrieved_np_obj .tolist ()
405426 expected_np_obj = original_data ["np_obj" ]
406427 if hasattr (expected_np_obj , "tolist" ) and not isinstance (expected_np_obj , np .ndarray ):
@@ -490,21 +511,24 @@ def test_cross_shard_complex_update(e2e_client):
490511
491512 # 6. Verify region 0-9: original Put A values
492513 original_data_0_9 = generate_complex_data (list (range (0 , 10 )))
493- assert torch .allclose (full_data ["tensor_f32" ][:10 ], original_data_0_9 ["tensor_f32" ]), (
494- "Region 0-9 tensor_f32 should match original Put A"
495- )
514+ for i in range (10 ):
515+ assert torch .allclose (full_data ["tensor_f32" ][i ], original_data_0_9 ["tensor_f32" ][i ]), (
516+ f"Region 0-9 tensor_f32 mismatch at index { i } "
517+ )
496518
497519 # 7. Verify region 10-29: updated values (using offset indices 1010-1029)
498520 updated_expected = generate_complex_data ([i + 1000 for i in range (10 , 30 )])
499- assert torch .allclose (full_data ["tensor_f32" ][10 :30 ], updated_expected ["tensor_f32" ]), (
500- "Region 10-29 tensor_f32 should match updated values"
501- )
521+ for i in range (20 ):
522+ assert torch .allclose (full_data ["tensor_f32" ][10 + i ], updated_expected ["tensor_f32" ][i ]), (
523+ f"Region 10-29 tensor_f32 mismatch at index { 10 + i } "
524+ )
502525
503526 # 8. Verify region 30-39: original Put B values
504527 original_data_30_39 = generate_complex_data (list (range (30 , 40 )))
505- assert torch .allclose (full_data ["tensor_f32" ][30 :40 ], original_data_30_39 ["tensor_f32" ]), (
506- "Region 30-39 tensor_f32 should match original Put B"
507- )
528+ for i in range (10 ):
529+ assert torch .allclose (full_data ["tensor_f32" ][30 + i ], original_data_30_39 ["tensor_f32" ][i ]), (
530+ f"Region 30-39 tensor_f32 mismatch at index { 30 + i } "
531+ )
508532
509533 # 9. Verify new fields exist in update region (indices 10-29 only have new fields).
510534 # Build extended_meta from full_meta (which has valid _custom_backend_meta)
@@ -760,12 +784,13 @@ def test_dynamic_tensor_shape_nested_transition(e2e_client):
760784 meta1_put = client .put (data = data1 , partition_id = partition_id )
761785 assert meta1_put .size == 2
762786
763- # Poll and verify first batch is regular tensor
787+ # Poll and verify first batch (now returned as nested tensor by default)
764788 meta1 = poll_for_meta (client , partition_id , ["dynamic_feature" ], 2 , task_name , mode = "force_fetch" )
765789 assert not meta1 .field_schema ["dynamic_feature" ]["is_nested" ]
766790 retrieved_1 = client .get_data (meta1 )
767- assert not retrieved_1 ["dynamic_feature" ].is_nested
768- assert retrieved_1 ["dynamic_feature" ].shape == (2 , 4 )
791+ assert retrieved_1 ["dynamic_feature" ].is_nested
792+ assert len (retrieved_1 ["dynamic_feature" ]) == 2
793+ assert retrieved_1 ["dynamic_feature" ][0 ].shape == (4 ,)
769794
770795 # 2. Allocate 2 more slots via insert mode, put different-shape tensor (shape: (2, 6))
771796 alloc_meta2 = client .get_meta (
@@ -802,7 +827,7 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
802827 """Verify that all data types retrieved via GET are writable and memory-independent.
803828
804829 This test validates the ZMQ copy=False GET path (Plan 1):
805- - Tensors (f32, i64, bf16, f16): writable after torch.stack detaches from frame
830+ - Tensors (f32, i64, bf16, f16): writable after nested tensor creation
806831 - Nested tensors (jagged, strided): writable after as_nested_tensor
807832 - Numpy arrays (float64, bytes string): writable after .copy() in _pack_field_values
808833 - Modifications to retrieved data do not affect stored data (memory independence)
@@ -861,7 +886,7 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
861886 assert retrieved ["special_val" ][0 , 2 ].item () == 33333.0 , "special_val should be writable"
862887
863888 # 8. np_array: verify it's a tensor now (TensorDict auto-converts numeric numpy)
864- # If it's a tensor, writability is guaranteed by torch.stack
889+ # If it's a tensor, writability is guaranteed by nested tensor creation
865890 np_arr_retrieved = retrieved ["np_array" ]
866891 if isinstance (np_arr_retrieved , torch .Tensor ):
867892 np_arr_retrieved [0 , 0 ] = 22222.0
@@ -880,24 +905,28 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
880905 retrieved2 = client .get_data (meta2 )
881906
882907 # tensor_f32[0,0] should be the original value, not 99999.0
883- assert torch .allclose (retrieved2 ["tensor_f32" ], original_data ["tensor_f32" ]), (
884- "Modifying retrieved tensor_f32 should not affect stored data"
885- )
908+ for i in range (batch_size ):
909+ assert torch .allclose (retrieved2 ["tensor_f32" ][i ], original_data ["tensor_f32" ][i ]), (
910+ "Modifying retrieved tensor_f32 should not affect stored data"
911+ )
886912
887913 # tensor_i64[0,0] should be the original value, not 88888
888- assert torch .equal (retrieved2 ["tensor_i64" ], original_data ["tensor_i64" ]), (
889- "Modifying retrieved tensor_i64 should not affect stored data"
890- )
914+ for i in range (batch_size ):
915+ assert torch .equal (retrieved2 ["tensor_i64" ][i ], original_data ["tensor_i64" ][i ]), (
916+ "Modifying retrieved tensor_i64 should not affect stored data"
917+ )
891918
892919 # tensor_bf16 should match original
893- assert torch .equal (retrieved2 ["tensor_bf16" ], original_data ["tensor_bf16" ]), (
894- "Modifying retrieved tensor_bf16 should not affect stored data"
895- )
920+ for i in range (batch_size ):
921+ assert torch .equal (retrieved2 ["tensor_bf16" ][i ], original_data ["tensor_bf16" ][i ]), (
922+ "Modifying retrieved tensor_bf16 should not affect stored data"
923+ )
896924
897925 # tensor_f16 should match original
898- assert torch .equal (retrieved2 ["tensor_f16" ], original_data ["tensor_f16" ]), (
899- "Modifying retrieved tensor_f16 should not affect stored data"
900- )
926+ for i in range (batch_size ):
927+ assert torch .equal (retrieved2 ["tensor_f16" ][i ], original_data ["tensor_f16" ][i ]), (
928+ "Modifying retrieved tensor_f16 should not affect stored data"
929+ )
901930
902931 # nested_jagged should match original
903932 assert verify_nested_tensor_equal (retrieved2 ["nested_jagged" ], original_data ["nested_jagged" ]), (
0 commit comments