Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,125 @@ def test_chunk_concat_roundtrip_preserves_extra_info(self):
assert len(restored) == 6
assert restored.global_indexes == list(range(6))

def test_union_basic(self):
    """Two batches over identical samples merge into one carrying both field sets.

    Also checks that the per-sample custom_meta dicts are merged key-wise and
    that the result reports ready when both inputs are fully produced.
    """

    def spec(dtype, shape):
        # Field-schema entry for a plain (non-nested) tensor field.
        return {"dtype": dtype, "shape": shape, "is_nested": False, "is_non_tensor": False}

    left = BatchMeta(
        global_indexes=[0, 1, 2],
        partition_ids=["p0", "p0", "p0"],
        field_schema={"field_a": spec(torch.float32, (2,))},
        production_status=np.ones(3, dtype=np.int8),
        custom_meta=[{"a": value} for value in (1, 2, 3)],
    )
    right = BatchMeta(
        global_indexes=[0, 1, 2],
        partition_ids=["p0", "p0", "p0"],
        field_schema={"field_b": spec(torch.int64, (4,))},
        production_status=np.ones(3, dtype=np.int8),
        custom_meta=[{"b": value} for value in (10, 20, 30)],
    )

    merged = left.union(right)

    # Sample identity is untouched by a field-wise union.
    assert merged.global_indexes == [0, 1, 2]
    assert merged.partition_ids == ["p0", "p0", "p0"]
    # Field order is not part of the contract being tested, hence the sort.
    assert sorted(merged.field_names) == ["field_a", "field_b"]
    assert merged.is_ready
    expected_meta = [{"a": a, "b": b} for a, b in zip((1, 2, 3), (10, 20, 30))]
    assert merged.custom_meta == expected_meta

def test_union_overlapping_fields(self):
    """A field defined on both sides ends up with the definition from ``other``."""

    def build(dtype, shape):
        # Two-sample batch whose only field, "field_a", has the given dtype/shape.
        return BatchMeta(
            global_indexes=[0, 1],
            partition_ids=["p0", "p0"],
            field_schema={
                "field_a": {"dtype": dtype, "shape": shape, "is_nested": False, "is_non_tensor": False},
            },
            production_status=np.ones(2, dtype=np.int8),
        )

    ours = build(torch.float32, (2,))
    theirs = build(torch.int64, (8,))

    merged_spec = ours.union(theirs).field_schema["field_a"]

    # The argument of union() wins on a field-name conflict.
    assert merged_spec["dtype"] == torch.int64
    assert merged_spec["shape"] == (8,)

def test_union_production_status_and(self):
    """union conservatively merges production_status via bitwise AND.

    Sample 1 is marked not-produced in ``batch_a`` and produced in
    ``batch_b``; the merged status for that sample must stay 0, and the
    merged batch must therefore not report itself as ready.
    """
    batch_a = BatchMeta(
        global_indexes=[0, 1],
        partition_ids=["p0", "p0"],
        field_schema={
            "field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False},
        },
        production_status=np.array([1, 0], dtype=np.int8),
    )
    batch_b = BatchMeta(
        global_indexes=[0, 1],
        partition_ids=["p0", "p0"],
        field_schema={
            "field_b": {"dtype": torch.int64, "shape": (4,), "is_nested": False, "is_non_tensor": False},
        },
        production_status=np.array([1, 1], dtype=np.int8),
    )
    result = batch_a.union(batch_b)
    assert list(result.production_status) == [1, 0]
    # Check truthiness rather than `is False`: an identity comparison would
    # fail spuriously if is_ready ever yields a NumPy boolean instead of the
    # built-in bool singleton. This also mirrors the `assert result.is_ready`
    # form used by test_union_basic.
    assert not result.is_ready

def test_union_validation_global_index_mismatch(self):
    """Batches whose global_indexes differ are rejected with ValueError."""

    def make(global_indexes):
        # Two-sample batch in partition "p0"; only the sample ids vary.
        return BatchMeta(
            global_indexes=global_indexes,
            partition_ids=["p0", "p0"],
            field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}},
            production_status=np.ones(2, dtype=np.int8),
        )

    lhs = make([0, 1])
    rhs = make([1, 2])

    with pytest.raises(ValueError, match="global_indexes do not match"):
        lhs.union(rhs)

def test_union_validation_partition_id_mismatch(self):
    """Batches whose partition_ids differ are rejected with ValueError."""

    def make(partition_ids):
        # Two-sample batch over global indexes 0 and 1; only the partitions vary.
        return BatchMeta(
            global_indexes=[0, 1],
            partition_ids=partition_ids,
            field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}},
            production_status=np.ones(2, dtype=np.int8),
        )

    lhs = make(["p0", "p0"])
    rhs = make(["p0", "p1"])

    with pytest.raises(ValueError, match="partition_ids do not match"):
        lhs.union(rhs)

def test_union_empty_other_returns_self(self):
    """Unioning with an empty batch hands back the very same (non-empty) object."""
    populated = self._make_batch(batch_size=2)
    # Identity, not equality: no copy is expected on this path.
    assert populated.union(BatchMeta.empty()) is populated

def test_union_empty_self_returns_other(self):
    """Calling union on an empty batch hands back the argument object itself."""
    populated = self._make_batch(batch_size=2)
    # Identity, not equality: no copy is expected on this path.
    assert BatchMeta.empty().union(populated) is populated


# ==============================================================================
# KVBatchMeta Tests (all migrated from main with no modification)
Expand Down
69 changes: 57 additions & 12 deletions transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,31 +608,76 @@ def chunk_by_partition(self) -> list["BatchMeta"]:
return chunk_list

def union(self, other: "BatchMeta") -> "BatchMeta":
    """Create a union of this batch's fields with another batch's fields.

    Both batches must describe the same samples: identical global_indexes
    and identical partition_ids, in the same order. The result carries the
    fields of both; when a field name exists on both sides, the definition
    from ``other`` replaces the one from this batch.

    Args:
        other: Another BatchMeta to union with.

    Returns:
        A new BatchMeta with the unioned fields. If ``other`` is empty (or
        falsy), ``self`` is returned as-is; if ``self`` is empty, ``other``
        is returned as-is. No copy is made on those two paths.

    Raises:
        ValueError: If global_indexes or partition_ids do not match.
    """
    # Nothing to merge: hand back the non-empty operand unchanged.
    if not other or other.size == 0:
        return self
    if self.size == 0:
        return other

    if self.global_indexes != other.global_indexes:
        raise ValueError(
            f"BatchMeta.union: global_indexes do not match. "
            f"self.global_indexes={self.global_indexes}, "
            f"other.global_indexes={other.global_indexes}"
        )

    if self.partition_ids != other.partition_ids:
        raise ValueError(
            f"BatchMeta.union: partition_ids do not match. "
            f"self.partition_ids={self.partition_ids}, "
            f"other.partition_ids={other.partition_ids}"
        )

    def _overlay(base, override):
        # Deep-copied ``base`` updated with a deep copy of ``override``, so the
        # result shares no mutable state with either input.
        combined = copy.deepcopy(base)
        combined.update(copy.deepcopy(override))
        return combined

    # Merge field_schema: other overrides self on name conflicts.
    merged_field_schema = _overlay(self.field_schema, other.field_schema)

    # Merge production_status conservatively: both sides must report ready
    # for the merged sample to be considered ready, since each side may
    # cover a disjoint subset of fields.
    merged_production_status = np.bitwise_and(self.production_status, other.production_status)

    # Merge extra_info: other overrides self on key conflicts.
    # NOTE(review): this is a shallow merge, unlike the per-sample metadata
    # below, so nested values stay shared with the inputs. Confirm intended.
    merged_extra_info = {**self.extra_info, **other.extra_info}

    # Merge per-sample metadata; other overrides self on key conflicts.
    # Indexing by position (rather than zip) keeps a length mismatch loud.
    sample_range = range(self.size)
    merged_custom_meta = [_overlay(self.custom_meta[i], other.custom_meta[i]) for i in sample_range]
    merged_custom_backend_meta = [
        _overlay(self._custom_backend_meta[i], other._custom_backend_meta[i]) for i in sample_range
    ]

    return BatchMeta(
        global_indexes=list(self.global_indexes),
        partition_ids=list(self.partition_ids),
        field_schema=merged_field_schema,
        production_status=merged_production_status,
        extra_info=merged_extra_info,
        custom_meta=merged_custom_meta,
        _custom_backend_meta=merged_custom_backend_meta,
    )

@classmethod
def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta":
Expand Down
15 changes: 8 additions & 7 deletions tutorial/03_metadata_concepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,14 @@ def make_batch(global_indexes, fields=None):
print(f"✓ Concatenated {len(batch1)} + {len(batch2)} = {len(concatenated)} samples")
print(f" Global indexes: {concatenated.global_indexes}")

# --- 9. union (dedup by global_index) ---
print("[Example 9] Unioning batches with overlapping global_indexes...")
# --- 9. union (merge fields for same samples) ---
print("[Example 9] Unioning batches with same global_indexes but different fields...")
batch_a = make_batch(list(range(3)), fields=["input_ids", "attention_mask"])
batch_b = make_batch(list(range(2, 5)), fields=["input_ids", "attention_mask"])
print(f" BatchA: {batch_a.global_indexes}, BatchB: {batch_b.global_indexes}")
batch_b = make_batch(list(range(3)), fields=["attention_mask", "responses"])
print(f" BatchA fields: {batch_a.field_names}, BatchB fields: {batch_b.field_names}")
unioned = batch_a.union(batch_b)
print(f"✓ Unioned: {unioned.global_indexes} (global_index=2 deduplicated)")
print(f"✓ Unioned fields: {unioned.field_names} (same global_indexes={unioned.global_indexes})")
print(" Note: 'attention_mask' was present in both; other's definition is kept.")

# --- 10. Empty BatchMeta ---
print("[Example 10] Creating an empty BatchMeta...")
Expand All @@ -228,8 +229,8 @@ def make_batch(global_indexes, fields=None):

print("=" * 80)
print("concat vs union:")
print(" - concat: Combines batches with SAME field structure")
print(" - union: Merges batches, deduplicating by global_index")
print(" - concat: Combines batches with SAME field structure (append rows)")
print(" - union: Merges batches with SAME global_indexes (append columns/fields)")
print("=" * 80)


Expand Down
Loading