@@ -34,11 +34,14 @@ def __init__(self, storage_put_get_address):
3434 self .socket .setsockopt (zmq .RCVTIMEO , 5000 ) # 5 second timeout
3535 self .socket .connect (storage_put_get_address )
3636
def send_put(self, client_id, global_indexes, field_data, data_parser=None):
    """Send a PUT_DATA request over the client socket and return the reply.

    Args:
        client_id: Identifier interpolated into the ``mock_client_<id>``
            sender id of the request.
        global_indexes: Value sent under the ``"global_indexes"`` body key.
        field_data: Value sent under the ``"data"`` body key.
        data_parser: Optional object sent under the ``"data_parser"`` body
            key. The key is omitted entirely when this is ``None``.

    Returns:
        The ``ZMQMessage`` deserialized from the multipart reply.
    """
    payload = {"global_indexes": global_indexes, "data": field_data}
    # Only attach the parser when one was supplied, so requests without a
    # parser carry exactly the two base keys.
    if data_parser is not None:
        payload["data_parser"] = data_parser

    request = ZMQMessage.create(
        request_type=ZMQRequestType.PUT_DATA,
        sender_id=f"mock_client_{client_id}",
        body=payload,
    )
    outgoing_frames = request.serialize()
    self.socket.send_multipart(outgoing_frames)

    # copy=False asks the socket for the received frames without copying them.
    reply_frames = self.socket.recv_multipart(copy=False)
    return ZMQMessage.deserialize(reply_frames)
@@ -434,3 +437,177 @@ def test_storage_unit_data_capacity_uses_active_keys():
434437 assert len (storage ._active_keys ) == 2
435438 storage .put_data ({"f" : [4 ]}, global_indexes = [3 ])
436439 assert storage ._active_keys == {0 , 1 , 3 }
440+
441+
def test_storage_unit_data_parser(storage_setup):
    """Check that a data_parser sent with a put is applied to the stored data.

    Two columns are written:
    - normal_data: a regular tensor batch, expected to come back unchanged
    - data_to_be_parsed: a list of shape descriptors (lists of ints)

    The data_parser replaces each shape descriptor with a random tensor of
    that shape, so the test verifies the shapes of what is read back.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def create_data_by_shape_parser(field_data):
        # Only the shape-descriptor column is rewritten; other columns pass through.
        if "data_to_be_parsed" in field_data:
            shapes = field_data["data_to_be_parsed"]
            field_data["data_to_be_parsed"] = [torch.randn(shape) for shape in shapes]
        return field_data

    # try/finally so the client is closed even when an assertion fails.
    try:
        # normal_data is a batch tensor, data_to_be_parsed is a list of shape lists.
        normal_rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
        shape_descriptors = [[2, 3], [1, 4], [3, 2]]
        field_data = {
            "normal_data": torch.tensor(normal_rows),
            "data_to_be_parsed": shape_descriptors,
        }
        global_indexes = [0, 1, 2]

        # Put with data_parser
        response = client.send_put(0, global_indexes, field_data, data_parser=create_data_by_shape_parser)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"Put failed: {response.body}"

        # Get back
        response = client.send_get(0, global_indexes, ["normal_data", "data_to_be_parsed"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        result = response.body["data"]

        # Verify normal_data is unchanged, row by row.
        for i, expected_row in enumerate(normal_rows):
            torch.testing.assert_close(result["normal_data"][i], torch.tensor(expected_row))

        # Verify data_to_be_parsed shapes match the input shape descriptors.
        expected_shapes = [(2, 3), (1, 4), (3, 2)]
        for i, expected_shape in enumerate(expected_shapes):
            actual_shape = tuple(result["data_to_be_parsed"][i].shape)
            assert actual_shape == expected_shape, (
                f"Shape mismatch at index {i}: expected {expected_shape}, got {actual_shape}"
            )
    finally:
        client.close()
491+
492+
def test_storage_unit_data_parser_callable_types(storage_setup):
    """Check that a functools.partial and a callable class instance both work as data_parser."""
    from functools import partial

    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def _partial_parser(field_data, prefix):
        if "text" in field_data:
            field_data["text"] = [f"{prefix}{t}" for t in field_data["text"]]
        return field_data

    class CallableParser:
        def __call__(self, field_data):
            if "value" in field_data:
                field_data["value"] = [v * 2 for v in field_data["value"]]
            return field_data

    # try/finally so the client is closed even when an assertion fails.
    try:
        # 1. functools.partial wrapping a two-argument parser
        partial_parser = partial(_partial_parser, prefix="parsed_")

        response = client.send_put(
            0,
            [0, 1],
            {"text": ["a", "b"]},
            data_parser=partial_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"partial parser failed: {response.body}"

        response = client.send_get(0, [0, 1], ["text"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert response.body["data"]["text"] == ["parsed_a", "parsed_b"]

        # 2. Instance of a class defining __call__
        callable_parser = CallableParser()
        response = client.send_put(
            0,
            [2, 3],
            {"value": [1, 2]},
            data_parser=callable_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, (
            f"callable class parser failed: {response.body}"
        )

        response = client.send_get(0, [2, 3], ["value"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert response.body["data"]["value"] == [2, 4]
    finally:
        client.close()
541+
542+
def test_storage_unit_data_parser_validation(storage_setup):
    """Check that invalid data_parser inputs are rejected with a clear error message.

    Each case sends one put and expects a PUT_ERROR response whose
    ``message`` contains a specific explanatory substring.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def assert_put_rejected(global_indexes, field_data, data_parser, expected_message):
        # Send one put and require a PUT_ERROR reply mentioning expected_message.
        response = client.send_put(0, global_indexes, field_data, data_parser=data_parser)
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert expected_message in response.body["message"]

    def bad_parser(field_data):
        return "not_a_dict"

    def delete_key_parser(field_data):
        del field_data["data"]
        return field_data

    def add_key_parser(field_data):
        field_data["new_key"] = [999]
        return field_data

    def wrong_len_parser(field_data):
        field_data["data"] = field_data["data"][:-1]
        return field_data

    # try/finally so the client is closed even when an assertion fails.
    try:
        # 1. Non-callable data_parser
        assert_put_rejected([0], {"data": [1]}, "not_callable", "data_parser must be callable")

        # 2. data_parser returning something other than a dict
        assert_put_rejected([1], {"data": [1]}, bad_parser, "data_parser must return a dict")

        # 3. data_parser deleting a key
        assert_put_rejected(
            [2],
            {"data": [1], "extra": [2]},
            delete_key_parser,
            "data_parser must not change dict keys",
        )

        # 4. data_parser adding a key
        assert_put_rejected([3], {"data": [1]}, add_key_parser, "data_parser must not change dict keys")

        # 5. data_parser changing the number of elements in a column
        assert_put_rejected(
            [4, 5],
            {"data": [1, 2]},
            wrong_len_parser,
            "data_parser changed the number of elements",
        )
    finally:
        client.close()
0 commit comments