Skip to content

Commit 270ea73

Browse files
authored
[fix] Fix BatchMeta.union semantics (#95)
## Problem

PR https://gitcode.com/Ascend/TransferQueue/pull/28 incorrectly rewrote `BatchMeta.union` to behave like `concat` with global_index deduplication.

- **Original semantics**: merge fields for samples with identical `global_indexes` (row-aligned, column-expanded).
- **Broken semantics**: append rows from `other` whose `global_indexes` are not present in `self`.

This broke the design boundary between `union` (same rows, merge columns) and `concat` (same columns, append rows).

## Changes

### 1. Restore `BatchMeta.union` (`transfer_queue/metadata.py`)

- Validate that both batches have the same `size`, `global_indexes`, and `partition_ids`.
- Merge `field_schema` with `other` overriding `self` on name conflicts.
- Merge `production_status` conservatively via `np.bitwise_and` (both sides must report ready).
- Merge `extra_info`, `custom_meta`, and `_custom_backend_meta` per sample.

### 2. Update tutorial (`tutorial/03_metadata_concepts.py`)

- Example now uses overlapping fields (`attention_mask` present in both batches) to demonstrate the override behavior.
- Corrected comments to clearly distinguish `concat` (append rows) vs `union` (merge columns).

### 3. Update unit tests (`tests/test_metadata.py`)

---------

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 6397458 commit 270ea73

3 files changed

Lines changed: 210 additions & 21 deletions

File tree

tests/test_metadata.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,135 @@ def test_chunk_concat_roundtrip_preserves_extra_info(self):
360360
assert len(restored) == 6
361361
assert restored.global_indexes == list(range(6))
362362

363+
def test_union_basic(self):
364+
"""union merges fields from two batches with identical global_indexes."""
365+
batch_a = BatchMeta(
366+
global_indexes=[0, 1, 2],
367+
partition_ids=["p0", "p0", "p0"],
368+
field_schema={
369+
"field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False},
370+
},
371+
production_status=np.ones(3, dtype=np.int8),
372+
custom_meta=[{"a": 1}, {"a": 2}, {"a": 3}],
373+
)
374+
batch_b = BatchMeta(
375+
global_indexes=[0, 1, 2],
376+
partition_ids=["p0", "p0", "p0"],
377+
field_schema={
378+
"field_b": {"dtype": torch.int64, "shape": (4,), "is_nested": False, "is_non_tensor": False},
379+
},
380+
production_status=np.ones(3, dtype=np.int8),
381+
custom_meta=[{"b": 10}, {"b": 20}, {"b": 30}],
382+
)
383+
result = batch_a.union(batch_b)
384+
assert result.global_indexes == [0, 1, 2]
385+
assert result.partition_ids == ["p0", "p0", "p0"]
386+
assert sorted(result.field_names) == ["field_a", "field_b"]
387+
assert result.is_ready
388+
assert result.custom_meta == [{"a": 1, "b": 10}, {"a": 2, "b": 20}, {"a": 3, "b": 30}]
389+
390+
def test_union_overlapping_fields(self):
391+
"""union replaces overlapping fields with other's definitions."""
392+
batch_a = BatchMeta(
393+
global_indexes=[0, 1],
394+
partition_ids=["p0", "p0"],
395+
field_schema={
396+
"field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False},
397+
},
398+
production_status=np.ones(2, dtype=np.int8),
399+
)
400+
batch_b = BatchMeta(
401+
global_indexes=[0, 1],
402+
partition_ids=["p0", "p0"],
403+
field_schema={
404+
"field_a": {"dtype": torch.int64, "shape": (8,), "is_nested": False, "is_non_tensor": False},
405+
},
406+
production_status=np.ones(2, dtype=np.int8),
407+
)
408+
result = batch_a.union(batch_b)
409+
assert result.field_schema["field_a"]["dtype"] == torch.int64
410+
assert result.field_schema["field_a"]["shape"] == (8,)
411+
412+
def test_union_production_status_and(self):
413+
"""union conservatively merges production_status via bitwise AND."""
414+
batch_a = BatchMeta(
415+
global_indexes=[0, 1],
416+
partition_ids=["p0", "p0"],
417+
field_schema={
418+
"field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False},
419+
},
420+
production_status=np.array([1, 0], dtype=np.int8),
421+
)
422+
batch_b = BatchMeta(
423+
global_indexes=[0, 1],
424+
partition_ids=["p0", "p0"],
425+
field_schema={
426+
"field_b": {"dtype": torch.int64, "shape": (4,), "is_nested": False, "is_non_tensor": False},
427+
},
428+
production_status=np.array([1, 1], dtype=np.int8),
429+
)
430+
result = batch_a.union(batch_b)
431+
assert list(result.production_status) == [1, 0]
432+
assert result.is_ready is False
433+
434+
def test_union_validation_global_index_mismatch(self):
435+
"""union raises ValueError when global_indexes do not match."""
436+
batch_a = BatchMeta(
437+
global_indexes=[0, 1],
438+
partition_ids=["p0", "p0"],
439+
field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}},
440+
production_status=np.ones(2, dtype=np.int8),
441+
)
442+
batch_b = BatchMeta(
443+
global_indexes=[1, 2],
444+
partition_ids=["p0", "p0"],
445+
field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}},
446+
production_status=np.ones(2, dtype=np.int8),
447+
)
448+
with pytest.raises(ValueError, match="global_indexes do not match"):
449+
batch_a.union(batch_b)
450+
451+
def test_union_validation_partition_id_mismatch(self):
452+
"""union raises ValueError when partition_ids do not match."""
453+
batch_a = BatchMeta(
454+
global_indexes=[0, 1],
455+
partition_ids=["p0", "p0"],
456+
field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}},
457+
production_status=np.ones(2, dtype=np.int8),
458+
)
459+
batch_b = BatchMeta(
460+
global_indexes=[0, 1],
461+
partition_ids=["p0", "p1"],
462+
field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}},
463+
production_status=np.ones(2, dtype=np.int8),
464+
)
465+
with pytest.raises(ValueError, match="partition_ids do not match"):
466+
batch_a.union(batch_b)
467+
468+
def test_union_empty_other_returns_copy(self):
469+
"""union with an empty batch returns a copy of self, not the same object."""
470+
batch = self._make_batch(batch_size=2)
471+
empty = BatchMeta.empty()
472+
result = batch.union(empty)
473+
assert result is not batch
474+
assert result.global_indexes == batch.global_indexes
475+
assert result.field_names == batch.field_names
476+
# Mutating the result must not affect the original
477+
result.extra_info["new_key"] = "new_value"
478+
assert "new_key" not in batch.extra_info
479+
480+
def test_union_empty_self_returns_copy(self):
481+
"""union when self is empty returns a copy of other, not the same object."""
482+
batch = self._make_batch(batch_size=2)
483+
empty = BatchMeta.empty()
484+
result = empty.union(batch)
485+
assert result is not batch
486+
assert result.global_indexes == batch.global_indexes
487+
assert result.field_names == batch.field_names
488+
# Mutating the result must not affect the original
489+
result.extra_info["new_key"] = "new_value"
490+
assert "new_key" not in batch.extra_info
491+
363492

364493
# ==============================================================================
365494
# KVBatchMeta Tests (all migrated from main with no modification)

transfer_queue/metadata.py

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,18 @@ def select_fields(self, field_names: list[str]) -> "BatchMeta":
544544
_custom_backend_meta=selected_custom_backend_meta,
545545
)
546546

547+
def copy(self) -> "BatchMeta":
548+
"""Return a deep copy of this BatchMeta."""
549+
return BatchMeta(
550+
global_indexes=list(self.global_indexes),
551+
partition_ids=list(self.partition_ids),
552+
field_schema=copy.deepcopy(self.field_schema),
553+
production_status=self.production_status.copy(),
554+
extra_info=copy.deepcopy(self.extra_info),
555+
custom_meta=copy.deepcopy(self.custom_meta),
556+
_custom_backend_meta=copy.deepcopy(self._custom_backend_meta),
557+
)
558+
547559
def __len__(self) -> int:
548560
"""Return the number of samples in this batch."""
549561
return self.size
@@ -608,31 +620,78 @@ def chunk_by_partition(self) -> list["BatchMeta"]:
608620
return chunk_list
609621

610622
def union(self, other: "BatchMeta") -> "BatchMeta":
611-
"""Return the union of this BatchMeta and another BatchMeta.
612-
Samples with global_indexes already present in this batch are ignored from the other batch.
623+
"""Create a union of this batch's fields with another batch's fields.
624+
625+
Both batches must have the same global indices and matching partition_ids
626+
for all samples. If fields overlap, the fields in this batch will be
627+
replaced by the other batch's fields.
613628
614629
Args:
615-
other: The other BatchMeta to merge with.
630+
other: Another BatchMeta to union with.
616631
617632
Returns:
618-
BatchMeta: A new merged BatchMeta.
633+
A new BatchMeta instance with unioned fields. Even when one side is
634+
empty, a copy is returned so callers can safely mutate the result
635+
without affecting the original.
636+
637+
Raises:
638+
ValueError: If global_indexes or partition_ids do not match.
619639
"""
620640
if not other or other.size == 0:
621-
return self
641+
return self.copy()
622642
if self.size == 0:
623-
return other
643+
return other.copy()
624644

625-
self_indexes = set(self.global_indexes)
626-
unique_indices_in_other = [i for i, idx in enumerate(other.global_indexes) if idx not in self_indexes]
645+
if self.global_indexes != other.global_indexes:
646+
raise ValueError(
647+
f"BatchMeta.union: global_indexes do not match. "
648+
f"self.global_indexes={self.global_indexes}, "
649+
f"other.global_indexes={other.global_indexes}"
650+
)
627651

628-
if not unique_indices_in_other:
629-
return self
652+
if self.partition_ids != other.partition_ids:
653+
raise ValueError(
654+
f"BatchMeta.union: partition_ids do not match. "
655+
f"self.partition_ids={self.partition_ids}, "
656+
f"other.partition_ids={other.partition_ids}"
657+
)
630658

631-
if len(unique_indices_in_other) == other.size:
632-
return BatchMeta.concat([self, other])
659+
# Merge field_schema: other overrides self on name conflicts
660+
merged_field_schema = copy.deepcopy(self.field_schema)
661+
for field_name, meta in other.field_schema.items():
662+
merged_field_schema[field_name] = copy.deepcopy(meta)
663+
664+
# Merge production_status conservatively: both sides must report ready
665+
# for the merged sample to be considered ready, since each side may
666+
# cover a disjoint subset of fields.
667+
merged_production_status = np.bitwise_and(self.production_status, other.production_status)
668+
669+
# Merge extra_info: other overrides self on key conflicts
670+
merged_extra_info = {**self.extra_info, **other.extra_info}
671+
672+
# Merge custom_meta per sample
673+
merged_custom_meta = []
674+
for i in range(self.size):
675+
merged_cm = copy.deepcopy(self.custom_meta[i])
676+
merged_cm.update(copy.deepcopy(other.custom_meta[i]))
677+
merged_custom_meta.append(merged_cm)
678+
679+
# Merge _custom_backend_meta per sample
680+
merged_custom_backend_meta = []
681+
for i in range(self.size):
682+
merged_bm = copy.deepcopy(self._custom_backend_meta[i])
683+
merged_bm.update(copy.deepcopy(other._custom_backend_meta[i]))
684+
merged_custom_backend_meta.append(merged_bm)
633685

634-
other_unique = other.select_samples(unique_indices_in_other)
635-
return BatchMeta.concat([self, other_unique])
686+
return BatchMeta(
687+
global_indexes=list(self.global_indexes),
688+
partition_ids=list(self.partition_ids),
689+
field_schema=merged_field_schema,
690+
production_status=merged_production_status,
691+
extra_info=merged_extra_info,
692+
custom_meta=merged_custom_meta,
693+
_custom_backend_meta=merged_custom_backend_meta,
694+
)
636695

637696
@classmethod
638697
def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta":

tutorial/03_metadata_concepts.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,14 @@ def make_batch(global_indexes, fields=None):
213213
print(f"✓ Concatenated {len(batch1)} + {len(batch2)} = {len(concatenated)} samples")
214214
print(f" Global indexes: {concatenated.global_indexes}")
215215

216-
# --- 9. union (dedup by global_index) ---
217-
print("[Example 9] Unioning batches with overlapping global_indexes...")
216+
# --- 9. union (merge fields for same samples) ---
217+
print("[Example 9] Unioning batches with same global_indexes but different fields...")
218218
batch_a = make_batch(list(range(3)), fields=["input_ids", "attention_mask"])
219-
batch_b = make_batch(list(range(2, 5)), fields=["input_ids", "attention_mask"])
220-
print(f" BatchA: {batch_a.global_indexes}, BatchB: {batch_b.global_indexes}")
219+
batch_b = make_batch(list(range(3)), fields=["attention_mask", "responses"])
220+
print(f" BatchA fields: {batch_a.field_names}, BatchB fields: {batch_b.field_names}")
221221
unioned = batch_a.union(batch_b)
222-
print(f"✓ Unioned: {unioned.global_indexes} (global_index=2 deduplicated)")
222+
print(f"✓ Unioned fields: {unioned.field_names} (same global_indexes={unioned.global_indexes})")
223+
print(" Note: 'attention_mask' was present in both; other's definition is kept.")
223224

224225
# --- 10. Empty BatchMeta ---
225226
print("[Example 10] Creating an empty BatchMeta...")
@@ -228,8 +229,8 @@ def make_batch(global_indexes, fields=None):
228229

229230
print("=" * 80)
230231
print("concat vs union:")
231-
print(" - concat: Combines batches with SAME field structure")
232-
print(" - union: Merges batches, deduplicating by global_index")
232+
print(" - concat: Combines batches with SAME field structure (append rows)")
233+
print(" - union: Merges batches with SAME global_indexes (append columns/fields)")
233234
print("=" * 80)
234235

235236

0 commit comments

Comments
 (0)