@@ -474,37 +474,32 @@ def kv_batch_put(
474474 tags : list [dict [str , Any ]] | None = None ,
475475 data_parser : Callable [[Any ], Any ] | None = None ,
476476) -> KVBatchMeta :
477- """Put multiple key-value pairs to TransferQueue in batch .
477+ """Batch put multiple key-value pairs into the TransferQueue .
478478
479- This method stores multiple key-value pairs in a single operation, which is more
480- efficient than calling kv_put multiple times .
479+ This method stores multiple key-value entries in a single operation,
480+ which is significantly more efficient than repeated calls to ``kv_put`` .
481481
482482 Args:
483- keys: List of user-specified keys for the data
484- partition_id: Logical partition to store the data in
485- fields: TensorDict containing data for all keys. Must have batch_size == len(keys).
486- If not provided, will only update the newly given tags to the keys.
487- tags: List of metadata tags, one for each key
488- data_parser: Optional callable to parse reference data (e.g., URLs) into real
489- content. The input is a slice of the `fields` parameter passed to
490- kv_put / kv_batch_put, in plain dict format (not TensorDict),
491- mapping field_name -> batched values. For a regular tensor column
492- the value is a batched tensor; for nested tensors (jagged or
493- strided) and NonTensorStack columns the values are extracted into
494- a list. It must modify values in-place based on the original keys;
495- do not add or remove keys. The number of elements per column must
496- also remain unchanged. Do not change the inner order of values
497- within each column. Only supported by SimpleStorage.
483+ keys: List of user-defined unique keys for the data entries.
484+ partition_id: Logical partition where the data will be stored.
485+ fields: TensorDict containing batched data for all keys. Must have ``batch_size == len(keys)``.
486+ If not provided, only the associated tags will be updated.
487+ tags: List of metadata dictionaries, one per key. Length must match the number of keys.
488+ data_parser: Optional callable to parse raw reference data (e.g., URLs) into real content
489+ before storage. The input is a plain dict (not TensorDict) mapping field names to
490+ batched values. The parser **must modify data in-place** without adding/removing
491+ keys or changing element counts/order. Only supported by ``SimpleStorage`` backend.
498492
499493 Returns:
500- KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields.
501- The `fields` attribute includes all fields stored for these samples,
502- including any new fields written by this put operation .
494+ KVBatchMeta: Metadata object containing stored keys, tags, partition ID,
495+ and field information. The `` fields`` attribute includes all
496+ persisted fields for the written samples .
503497
504498 Raises:
505- ValueError: If neither `fields` nor `tags` is provided
506- ValueError: If length of `keys` doesn't match length of `tags` or the batch_size of `fields` TensorDict
507- RuntimeError: If retrieved BatchMeta size doesn't match length of `keys`
499+ ValueError: When both ``fields`` and ``tags`` are empty.
500+ ValueError: When ``fields`` batch size mismatches key count.
501+ ValueError: When ``tags`` length mismatches key count.
502+ RuntimeError: When retrieved metadata size mismatches input key count.
508503
509504 Example:
510505 >>> import transfer_queue as tq
@@ -517,49 +512,37 @@ def kv_batch_put(
517512 ... }, batch_size=3)
518513 >>> tags = [{"score": 0.9}, {"score": 0.85}, {"score": 0.95}]
519514 >>> meta = tq.kv_batch_put(keys=keys, partition_id="train", fields=fields, tags=tags)
520- >>> print(meta.fields) # ['input_ids', 'attention_mask']
515+ >>> print(meta.fields)
521516 """
517+ num_keys = len (keys )
522518
523519 if fields is None and tags is None :
524520 raise ValueError ("Please provide at least one parameter of fields or tag." )
525521
526- if fields is not None and fields .batch_size [0 ] != len (keys ):
527- raise ValueError (
528- f"`keys` with length { len (keys )} does not match the `fields` TensorDict with "
529- f"batch_size { fields .batch_size [0 ]} "
530- )
522+ if fields is not None and fields .batch_size [0 ] != num_keys :
523+ raise ValueError (f"Length of `keys` ({ num_keys } ) does not match `fields` batch size ({ fields .batch_size [0 ]} )." )
531524
532525 tq_client = _maybe_create_tq_client ()
533-
534- # 1. translate user-specified key to BatchMeta
535526 batch_meta = tq_client .kv_retrieve_meta (keys = keys , partition_id = partition_id , create = True )
536527
537- if batch_meta .size != len (keys ):
538- raise RuntimeError (
539- f"Retrieved BatchMeta size { batch_meta .size } does not match with input `keys` size { len (keys )} !"
540- )
528+ if batch_meta .size != num_keys :
529+ raise RuntimeError (f"Retrieved BatchMeta size { batch_meta .size } does not match input `keys` size { num_keys } ." )
541530
542- # 2. register the user-specified tags to BatchMeta
543531 if tags is not None :
544- if len (tags ) != len ( keys ) :
545- raise ValueError (f"keys with length { len ( keys ) } does not match length of tags { len (tags )} " )
532+ if len (tags ) != num_keys :
533+ raise ValueError (f"Length of `keys` ( { num_keys } ) does not match length of ` tags` ( { len (tags )} ). " )
546534 batch_meta .update_custom_meta (tags )
547535
548- # 3. put data
549536 if fields is not None :
550- # After put, batch_meta.field_names will include the new fields written by user
551537 batch_meta = tq_client .put (fields , batch_meta , data_parser = data_parser )
552- else :
553- # Directly update custom_meta (tags) to controller
538+ else : # tags is not None
554539 tq_client .set_custom_meta (batch_meta )
555540
556- fields_to_return = batch_meta .field_names
557-
558541 return KVBatchMeta (
559542 keys = keys ,
560543 tags = batch_meta .custom_meta ,
561544 partition_id = partition_id ,
562- fields = fields_to_return ,
545+ fields = batch_meta . field_names ,
563546 extra_info = batch_meta .extra_info ,
564547 )
565548
0 commit comments