Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,135 @@ def test_chunk_concat_roundtrip_preserves_extra_info(self):
assert len(restored) == 6
assert restored.global_indexes == list(range(6))

def test_union_basic(self):
    """union merges fields from two batches with identical global_indexes."""
    schema_a = {"field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}
    schema_b = {"field_b": {"dtype": torch.int64, "shape": (4,), "is_nested": False, "is_non_tensor": False}}
    left = BatchMeta(
        global_indexes=[0, 1, 2],
        partition_ids=["p0"] * 3,
        field_schema=schema_a,
        production_status=np.ones(3, dtype=np.int8),
        custom_meta=[{"a": value} for value in (1, 2, 3)],
    )
    right = BatchMeta(
        global_indexes=[0, 1, 2],
        partition_ids=["p0"] * 3,
        field_schema=schema_b,
        production_status=np.ones(3, dtype=np.int8),
        custom_meta=[{"b": value} for value in (10, 20, 30)],
    )
    merged = left.union(right)
    assert merged.global_indexes == [0, 1, 2]
    assert merged.partition_ids == ["p0", "p0", "p0"]
    assert sorted(merged.field_names) == ["field_a", "field_b"]
    assert merged.is_ready
    assert merged.custom_meta == [{"a": 1, "b": 10}, {"a": 2, "b": 20}, {"a": 3, "b": 30}]

def test_union_overlapping_fields(self):
    """union replaces overlapping fields with other's definitions."""

    # Local factory: both batches describe the same field name but with
    # different dtype/shape, so the merge must pick other's definition.
    def make(dtype, shape):
        return BatchMeta(
            global_indexes=[0, 1],
            partition_ids=["p0", "p0"],
            field_schema={"field_a": {"dtype": dtype, "shape": shape, "is_nested": False, "is_non_tensor": False}},
            production_status=np.ones(2, dtype=np.int8),
        )

    merged = make(torch.float32, (2,)).union(make(torch.int64, (8,)))
    assert merged.field_schema["field_a"]["dtype"] == torch.int64
    assert merged.field_schema["field_a"]["shape"] == (8,)

def test_union_production_status_and(self):
    """union conservatively merges production_status via bitwise AND."""
    partly_ready = BatchMeta(
        global_indexes=[0, 1],
        partition_ids=["p0", "p0"],
        field_schema={"field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}},
        production_status=np.array([1, 0], dtype=np.int8),
    )
    fully_ready = BatchMeta(
        global_indexes=[0, 1],
        partition_ids=["p0", "p0"],
        field_schema={"field_b": {"dtype": torch.int64, "shape": (4,), "is_nested": False, "is_non_tensor": False}},
        production_status=np.array([1, 1], dtype=np.int8),
    )
    merged = partly_ready.union(fully_ready)
    # Sample 1 was not ready on one side, so the merged batch is not ready.
    assert list(merged.production_status) == [1, 0]
    assert merged.is_ready is False

def test_union_validation_global_index_mismatch(self):
    """union raises ValueError when global_indexes do not match."""

    def schema():
        return {"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}

    first = BatchMeta(
        global_indexes=[0, 1],
        partition_ids=["p0", "p0"],
        field_schema=schema(),
        production_status=np.ones(2, dtype=np.int8),
    )
    second = BatchMeta(
        global_indexes=[1, 2],
        partition_ids=["p0", "p0"],
        field_schema=schema(),
        production_status=np.ones(2, dtype=np.int8),
    )
    with pytest.raises(ValueError, match="global_indexes do not match"):
        first.union(second)

def test_union_validation_partition_id_mismatch(self):
    """union raises ValueError when partition_ids do not match."""

    def schema():
        return {"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}

    first = BatchMeta(
        global_indexes=[0, 1],
        partition_ids=["p0", "p0"],
        field_schema=schema(),
        production_status=np.ones(2, dtype=np.int8),
    )
    second = BatchMeta(
        global_indexes=[0, 1],
        partition_ids=["p0", "p1"],
        field_schema=schema(),
        production_status=np.ones(2, dtype=np.int8),
    )
    with pytest.raises(ValueError, match="partition_ids do not match"):
        first.union(second)

def test_union_empty_other_returns_copy(self):
    """union with an empty batch returns a copy, not the original identity."""
    base = self._make_batch(batch_size=2)
    merged = base.union(BatchMeta.empty())
    assert merged is not base
    assert merged.global_indexes == base.global_indexes
    assert merged.field_names == base.field_names
    # Mutating the copy must leave the original untouched.
    merged.extra_info["new_key"] = "new_value"
    assert "new_key" not in base.extra_info

def test_union_empty_self_returns_copy(self):
    """union when self is empty returns a copy, not the original identity."""
    base = self._make_batch(batch_size=2)
    merged = BatchMeta.empty().union(base)
    assert merged is not base
    assert merged.global_indexes == base.global_indexes
    assert merged.field_names == base.field_names
    # Mutating the copy must leave the original untouched.
    merged.extra_info["new_key"] = "new_value"
    assert "new_key" not in base.extra_info


# ==============================================================================
# KVBatchMeta Tests (all migrated from main with no modification)
Expand Down
87 changes: 73 additions & 14 deletions transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,18 @@ def select_fields(self, field_names: list[str]) -> "BatchMeta":
_custom_backend_meta=selected_custom_backend_meta,
)

def copy(self) -> "BatchMeta":
    """Return a deep copy of this BatchMeta.

    Lists are re-created, the production_status array is copied, and all
    nested structures (schema, extra_info, per-sample metadata) are
    deep-copied, so mutating the result never affects this instance.
    """
    deep = copy.deepcopy
    return BatchMeta(
        global_indexes=list(self.global_indexes),
        partition_ids=list(self.partition_ids),
        field_schema=deep(self.field_schema),
        production_status=self.production_status.copy(),
        extra_info=deep(self.extra_info),
        custom_meta=deep(self.custom_meta),
        _custom_backend_meta=deep(self._custom_backend_meta),
    )

def __len__(self) -> int:
"""Return the number of samples in this batch."""
return self.size
Expand Down Expand Up @@ -608,31 +620,78 @@ def chunk_by_partition(self) -> list["BatchMeta"]:
return chunk_list

def union(self, other: "BatchMeta") -> "BatchMeta":
    """Create a union of this batch's fields with another batch's fields.

    Both batches must cover the same samples: identical global_indexes and
    identical partition_ids, position by position. On field-name conflicts
    the other batch's definition wins.

    Args:
        other: Another BatchMeta to union with.

    Returns:
        A new BatchMeta instance with unioned fields. Even when one side is
        empty, a copy is returned so callers can safely mutate the result
        without affecting the original.

    Raises:
        ValueError: If global_indexes, or partition_ids do not match.
    """
    # Degenerate cases: an empty side contributes nothing, but we still
    # hand back a copy so the caller never aliases an existing instance.
    if not other or other.size == 0:
        return self.copy()
    if self.size == 0:
        return other.copy()

    if self.global_indexes != other.global_indexes:
        raise ValueError(
            f"BatchMeta.union: global_indexes do not match. "
            f"self.global_indexes={self.global_indexes}, "
            f"other.global_indexes={other.global_indexes}"
        )
    if self.partition_ids != other.partition_ids:
        raise ValueError(
            f"BatchMeta.union: partition_ids do not match. "
            f"self.partition_ids={self.partition_ids}, "
            f"other.partition_ids={other.partition_ids}"
        )

    # Field schema: start from ours, then let other's entries win on
    # name conflicts. Entries are deep-copied so neither input is aliased.
    merged_schema = {name: copy.deepcopy(meta) for name, meta in self.field_schema.items()}
    for name, meta in other.field_schema.items():
        merged_schema[name] = copy.deepcopy(meta)

    # A sample is ready only when BOTH sides report it ready, since each
    # side may cover a disjoint subset of fields.
    merged_status = self.production_status & other.production_status

    # Per-sample dict merges; other's keys win on conflicts.
    merged_custom_meta = [
        {**copy.deepcopy(mine), **copy.deepcopy(theirs)}
        for mine, theirs in zip(self.custom_meta, other.custom_meta)
    ]
    merged_backend_meta = [
        {**copy.deepcopy(mine), **copy.deepcopy(theirs)}
        for mine, theirs in zip(self._custom_backend_meta, other._custom_backend_meta)
    ]

    return BatchMeta(
        global_indexes=list(self.global_indexes),
        partition_ids=list(self.partition_ids),
        field_schema=merged_schema,
        production_status=merged_status,
        extra_info={**self.extra_info, **other.extra_info},
        custom_meta=merged_custom_meta,
        _custom_backend_meta=merged_backend_meta,
    )

@classmethod
def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta":
Expand Down
15 changes: 8 additions & 7 deletions tutorial/03_metadata_concepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,14 @@ def make_batch(global_indexes, fields=None):
print(f"✓ Concatenated {len(batch1)} + {len(batch2)} = {len(concatenated)} samples")
print(f" Global indexes: {concatenated.global_indexes}")

# --- 9. union (dedup by global_index) ---
print("[Example 9] Unioning batches with overlapping global_indexes...")
# --- 9. union (merge fields for same samples) ---
print("[Example 9] Unioning batches with same global_indexes but different fields...")
batch_a = make_batch(list(range(3)), fields=["input_ids", "attention_mask"])
batch_b = make_batch(list(range(2, 5)), fields=["input_ids", "attention_mask"])
print(f" BatchA: {batch_a.global_indexes}, BatchB: {batch_b.global_indexes}")
batch_b = make_batch(list(range(3)), fields=["attention_mask", "responses"])
print(f" BatchA fields: {batch_a.field_names}, BatchB fields: {batch_b.field_names}")
unioned = batch_a.union(batch_b)
print(f"✓ Unioned: {unioned.global_indexes} (global_index=2 deduplicated)")
print(f"✓ Unioned fields: {unioned.field_names} (same global_indexes={unioned.global_indexes})")
print(" Note: 'attention_mask' was present in both; other's definition is kept.")

# --- 10. Empty BatchMeta ---
print("[Example 10] Creating an empty BatchMeta...")
Expand All @@ -228,8 +229,8 @@ def make_batch(global_indexes, fields=None):

print("=" * 80)
print("concat vs union:")
print(" - concat: Combines batches with SAME field structure")
print(" - union: Merges batches, deduplicating by global_index")
print(" - concat: Combines batches with SAME field structure (append rows)")
print(" - union: Merges batches with SAME global_indexes (append columns/fields)")
print("=" * 80)


Expand Down
Loading