1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- import struct
1716from abc import ABC , abstractmethod
1817from concurrent .futures import ThreadPoolExecutor
1918from typing import Any , Callable , Optional
2322
2423from transfer_queue .storage .clients .base import StorageClientFactory , StorageKVClient
2524from transfer_queue .utils .logging_utils import get_logger
26- from transfer_queue .utils .serial_utils import _decoder , _encoder
25+ from transfer_queue .utils .serial_utils import batch_decode_from , batch_encode_into
2726from transfer_queue .utils .yuanrong_utils import find_reachable_host
2827
2928logger = get_logger (__name__ )
@@ -193,19 +192,11 @@ def _create_empty_npu_tensorlist(self, shapes: list[Any], dtypes: list[Any]) ->
193192class GeneralKVClientAdapter (StorageStrategy ):
194193 """Adapter for general-purpose KV storage with serialization.
195194 Using yr.datasystem.KVClient to connect datasystem backends.
196- The serialization method uses '_decoder ' and '_encoder ' from 'transfer_queue.utils.serial_utils'.
195+ The serialization method uses 'batch_encode_into ' and 'batch_decode_from ' from 'transfer_queue.utils.serial_utils'.
197196 """
198197
199198 PUT_KEYS_LIMIT : int = 10_000
200199 GET_CLEAR_KEYS_LIMIT : int = 10_000
201-
202- # Header: number of entries (uint32, little-endian)
203- HEADER_FMT = "<I"
204- HEADER_SIZE = struct .calcsize (HEADER_FMT )
205- # Entry: (payload_offset: uint32, payload_size: uint32)
206- ENTRY_FMT = "<II"
207- ENTRY_SIZE = struct .calcsize (ENTRY_FMT )
208-
209200 DS_MAX_WORKERS : int = 16
210201
211202 def __init__ (self , config : dict ):
@@ -270,84 +261,23 @@ def clear(self, keys: list[str]) -> None:
270261 batch_keys = keys [i : i + self .GET_CLEAR_KEYS_LIMIT ]
271262 self ._ds_client .delete (batch_keys )
272263
273- @classmethod
274- def calc_packed_size (cls , items : list [memoryview ]) -> int :
275- """
276- Calculate the total size (in bytes) required to pack a list of memoryview items
277- into the structured binary format used by pack_into.
278-
279- Args:
280- items: List of memoryview objects to be packed.
281-
282- Returns:
283- Total buffer size in bytes.
284- """
285- return cls .HEADER_SIZE + len (items ) * cls .ENTRY_SIZE + sum (item .nbytes for item in items )
286-
287- @classmethod
288- def pack_into (cls , target : memoryview , items : list [memoryview ]):
289- """
290- Pack multiple contiguous buffers into a single buffer.
291- ┌───────────────┐
292- │ item_count │ uint32
293- ├───────────────┤
294- │ entries │ N * item entries
295- ├───────────────┤
296- │ payload blob │ N * concatenated buffers
297- └───────────────┘
298-
299- Args:
300- target (memoryview): A writable memoryview returned by StateValueBuffer.MutableData().
301- It must be large enough to accommodate the total number of bytes of HEADER + ENTRY_TABLE + all items.
302- This buffer is usually mapped to shared memory or Zero-Copy memory area.
303- items (List[memoryview]): List of read-only memory views (e.g., from serialized objects).
304- Each item must support the buffer protocol and be readable as raw bytes.
305-
306- """
307- struct .pack_into (cls .HEADER_FMT , target , 0 , len (items ))
308-
309- entry_offset = cls .HEADER_SIZE
310- payload_offset = cls .HEADER_SIZE + len (items ) * cls .ENTRY_SIZE
311-
312- target_tensor = torch .frombuffer (target , dtype = torch .uint8 )
313-
314- for item in items :
315- struct .pack_into (cls .ENTRY_FMT , target , entry_offset , payload_offset , item .nbytes )
316- src_tensor = torch .frombuffer (item , dtype = torch .uint8 )
317- target_tensor [payload_offset : payload_offset + item .nbytes ].copy_ (src_tensor )
318- entry_offset += cls .ENTRY_SIZE
319- payload_offset += item .nbytes
320-
321- @classmethod
322- def unpack_from (cls , source : memoryview ) -> list [memoryview ]:
323- """
324- Unpack multiple contiguous buffers from a single packed buffer.
325- Args:
326- source (memoryview): The packed source buffer.
327- Returns:
328- list[memoryview]: List of unpacked contiguous buffers.
329- """
330- mv = memoryview (source )
331- item_count = struct .unpack_from (cls .HEADER_FMT , mv , 0 )[0 ]
332- offsets = []
333- for i in range (item_count ):
334- offset , length = struct .unpack_from (cls .ENTRY_FMT , mv , cls .HEADER_SIZE + i * cls .ENTRY_SIZE )
335- offsets .append ((offset , length ))
336- return [mv [offset : offset + length ] for offset , length in offsets ]
337-
def mset_zero_copy(self, keys: list[str], objs: list[Any]):
    """Serialize ``objs`` into datasystem buffers and publish them under ``keys``.

    Buffers are created through ``self._ds_client.mcreate``, filled by
    ``batch_encode_into`` (using ``self.DS_MAX_WORKERS`` workers), and then
    committed with a single ``self._ds_client.mset_buffer`` call.

    Args:
        keys (list[str]): Keys under which the objects will be stored.
        objs (list[Any]): Python objects to store (e.g., tensors, strings).
    """
    created_buffers: list = []

    def allocate_targets(sizes):
        # mcreate() hands back datasystem buffer objects; the encoder needs
        # something memoryview-compatible to write into, which is what
        # MutableData() exposes for each buffer.
        # NOTE(review): presumably `sizes` holds one byte size per key --
        # confirm against batch_encode_into.
        fresh = self._ds_client.mcreate(keys, sizes)
        # Keep the original buffer objects: they are what mset_buffer()
        # expects once the payloads have been written.
        created_buffers.extend(fresh)
        return [entry.MutableData() for entry in fresh]

    batch_encode_into(objs, allocate_targets, num_workers=self.DS_MAX_WORKERS)
    self._ds_client.mset_buffer(created_buffers)
352282
def mget_zero_copy(self, keys: list[str]) -> list[Any]:
    """Fetch the buffers stored under ``keys`` and deserialize them.

    Args:
        keys (list[str]): Keys to look up via ``self._ds_client.get_buffers``.

    Returns:
        list[Any]: One entry per input key, in the same order. Positions
            whose buffer came back as ``None`` stay ``None``; every other
            position holds the object produced by ``batch_decode_from``.
    """
    fetched = self._ds_client.get_buffers(keys)

    # Remember where each non-missing buffer came from so the decoded
    # objects can be put back at the matching position afterwards.
    present = [(position, buf) for position, buf in enumerate(fetched) if buf is not None]
    decoded = batch_decode_from([buf for _, buf in present])

    results = [None] * len(keys)
    # strict=True: a decoder returning a different number of objects than
    # buffers passed in raises instead of silently misaligning results.
    for (position, _), obj in zip(present, decoded, strict=True):
        results[position] = obj
    return results
364300
365301
366302@StorageClientFactory .register ("YuanrongStorageClient" )
0 commit comments