From 2e2b9af5ce8c286056d2a6e512b441ddf211a9ea Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 12 May 2026 16:52:12 +0800 Subject: [PATCH 1/2] fix union for batchmeta Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 119 +++++++++++++++++++++++++++++++ transfer_queue/metadata.py | 69 ++++++++++++++---- tutorial/03_metadata_concepts.py | 15 ++-- 3 files changed, 184 insertions(+), 19 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index b9e80dac..452d56bf 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -360,6 +360,125 @@ def test_chunk_concat_roundtrip_preserves_extra_info(self): assert len(restored) == 6 assert restored.global_indexes == list(range(6)) + def test_union_basic(self): + """union merges fields from two batches with identical global_indexes.""" + batch_a = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["p0", "p0", "p0"], + field_schema={ + "field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}, + }, + production_status=np.ones(3, dtype=np.int8), + custom_meta=[{"a": 1}, {"a": 2}, {"a": 3}], + ) + batch_b = BatchMeta( + global_indexes=[0, 1, 2], + partition_ids=["p0", "p0", "p0"], + field_schema={ + "field_b": {"dtype": torch.int64, "shape": (4,), "is_nested": False, "is_non_tensor": False}, + }, + production_status=np.ones(3, dtype=np.int8), + custom_meta=[{"b": 10}, {"b": 20}, {"b": 30}], + ) + result = batch_a.union(batch_b) + assert result.global_indexes == [0, 1, 2] + assert result.partition_ids == ["p0", "p0", "p0"] + assert sorted(result.field_names) == ["field_a", "field_b"] + assert result.is_ready + assert result.custom_meta == [{"a": 1, "b": 10}, {"a": 2, "b": 20}, {"a": 3, "b": 30}] + + def test_union_overlapping_fields(self): + """union replaces overlapping fields with other's definitions.""" + batch_a = BatchMeta( + global_indexes=[0, 1], + partition_ids=["p0", "p0"], + field_schema={ + "field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}, + }, + production_status=np.ones(2, dtype=np.int8), + ) + batch_b = BatchMeta( + global_indexes=[0, 1], + partition_ids=["p0", "p0"], + field_schema={ + "field_a": {"dtype": torch.int64, "shape": (8,), "is_nested": False, "is_non_tensor": False}, + }, + production_status=np.ones(2, dtype=np.int8), + ) + result = batch_a.union(batch_b) + assert result.field_schema["field_a"]["dtype"] == torch.int64 + assert result.field_schema["field_a"]["shape"] == (8,) + + def test_union_production_status_and(self): + """union conservatively merges production_status via bitwise AND.""" + batch_a = BatchMeta( + global_indexes=[0, 1], + partition_ids=["p0", "p0"], + field_schema={ + "field_a": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}, + }, + production_status=np.array([1, 0], dtype=np.int8), + ) + batch_b = BatchMeta( + global_indexes=[0, 1], + partition_ids=["p0", "p0"], + field_schema={ + "field_b": {"dtype": torch.int64, "shape": (4,), "is_nested": False, "is_non_tensor": False}, + }, + production_status=np.array([1, 1], dtype=np.int8), + ) + result = batch_a.union(batch_b) + assert list(result.production_status) == [1, 0] + assert result.is_ready is False + + def test_union_validation_global_index_mismatch(self): + """union raises ValueError when global_indexes do not match.""" + batch_a = BatchMeta( + global_indexes=[0, 1], + partition_ids=["p0", "p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(2, dtype=np.int8), + ) + batch_b = BatchMeta( + global_indexes=[1, 2], + partition_ids=["p0", "p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(2, dtype=np.int8), + ) + with pytest.raises(ValueError, match="global_indexes do not match"): + batch_a.union(batch_b) + + def test_union_validation_partition_id_mismatch(self): + """union raises ValueError when partition_ids do not match.""" + batch_a = BatchMeta( + global_indexes=[0, 1], + partition_ids=["p0", "p0"], + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(2, dtype=np.int8), + ) + batch_b = BatchMeta( + global_indexes=[0, 1], + partition_ids=["p0", "p1"], + field_schema={"f": {"dtype": torch.float32, "shape": (2,), "is_nested": False, "is_non_tensor": False}}, + production_status=np.ones(2, dtype=np.int8), + ) + with pytest.raises(ValueError, match="partition_ids do not match"): + batch_a.union(batch_b) + + def test_union_empty_other_returns_self(self): + """union with an empty batch returns self.""" + batch = self._make_batch(batch_size=2) + empty = BatchMeta.empty() + result = batch.union(empty) + assert result is batch + + def test_union_empty_self_returns_other(self): + """union when self is empty returns other.""" + batch = self._make_batch(batch_size=2) + empty = BatchMeta.empty() + result = empty.union(batch) + assert result is batch + # ============================================================================== # KVBatchMeta Tests (all migrated from main with no modification) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 05cf3241..520b0577 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -608,31 +608,76 @@ def chunk_by_partition(self) -> list["BatchMeta"]: return chunk_list def union(self, other: "BatchMeta") -> "BatchMeta": - """Return the union of this BatchMeta and another BatchMeta. - Samples with global_indexes already present in this batch are ignored from the other batch. + """Create a union of this batch's fields with another batch's fields. + + Both batches must have the same global indices and matching partition_ids + for all samples. If fields overlap, the fields in this batch will be + replaced by the other batch's fields. Args: - other: The other BatchMeta to merge with. + other: Another BatchMeta to union with. Returns: - BatchMeta: A new merged BatchMeta. + New BatchMeta with unioned fields. + + Raises: + ValueError: If global_indexes, or partition_ids do not match. """ if not other or other.size == 0: return self if self.size == 0: return other - self_indexes = set(self.global_indexes) - unique_indices_in_other = [i for i, idx in enumerate(other.global_indexes) if idx not in self_indexes] + if self.global_indexes != other.global_indexes: + raise ValueError( + f"BatchMeta.union: global_indexes do not match. " + f"self.global_indexes={self.global_indexes}, " + f"other.global_indexes={other.global_indexes}" + ) - if not unique_indices_in_other: - return self + if self.partition_ids != other.partition_ids: + raise ValueError( + f"BatchMeta.union: partition_ids do not match. " + f"self.partition_ids={self.partition_ids}, " + f"other.partition_ids={other.partition_ids}" + ) - if len(unique_indices_in_other) == other.size: - return BatchMeta.concat([self, other]) + # Merge field_schema: other overrides self on name conflicts + merged_field_schema = copy.deepcopy(self.field_schema) + for field_name, meta in other.field_schema.items(): + merged_field_schema[field_name] = copy.deepcopy(meta) + + # Merge production_status conservatively: both sides must report ready + # for the merged sample to be considered ready, since each side may + # cover a disjoint subset of fields. + merged_production_status = np.bitwise_and(self.production_status, other.production_status) + + # Merge extra_info: other overrides self on key conflicts + merged_extra_info = {**self.extra_info, **other.extra_info} + + # Merge custom_meta per sample + merged_custom_meta = [] + for i in range(self.size): + merged_cm = copy.deepcopy(self.custom_meta[i]) + merged_cm.update(copy.deepcopy(other.custom_meta[i])) + merged_custom_meta.append(merged_cm) + + # Merge _custom_backend_meta per sample + merged_custom_backend_meta = [] + for i in range(self.size): + merged_bm = copy.deepcopy(self._custom_backend_meta[i]) + merged_bm.update(copy.deepcopy(other._custom_backend_meta[i])) + merged_custom_backend_meta.append(merged_bm) - other_unique = other.select_samples(unique_indices_in_other) - return BatchMeta.concat([self, other_unique]) + return BatchMeta( + global_indexes=list(self.global_indexes), + partition_ids=list(self.partition_ids), + field_schema=merged_field_schema, + production_status=merged_production_status, + extra_info=merged_extra_info, + custom_meta=merged_custom_meta, + _custom_backend_meta=merged_custom_backend_meta, + ) @classmethod def concat(cls, data: list["BatchMeta"], validate: bool = True) -> "BatchMeta": diff --git a/tutorial/03_metadata_concepts.py b/tutorial/03_metadata_concepts.py index 24c0e4b7..a7482196 100644 --- a/tutorial/03_metadata_concepts.py +++ b/tutorial/03_metadata_concepts.py @@ -213,13 +213,14 @@ def make_batch(global_indexes, fields=None): print(f"✓ Concatenated {len(batch1)} + {len(batch2)} = {len(concatenated)} samples") print(f" Global indexes: {concatenated.global_indexes}") - # --- 9. union (dedup by global_index) --- - print("[Example 9] Unioning batches with overlapping global_indexes...") + # --- 9. union (merge fields for same samples) --- + print("[Example 9] Unioning batches with same global_indexes but different fields...") batch_a = make_batch(list(range(3)), fields=["input_ids", "attention_mask"]) - batch_b = make_batch(list(range(2, 5)), fields=["input_ids", "attention_mask"]) - print(f" BatchA: {batch_a.global_indexes}, BatchB: {batch_b.global_indexes}") + batch_b = make_batch(list(range(3)), fields=["attention_mask", "responses"]) + print(f" BatchA fields: {batch_a.field_names}, BatchB fields: {batch_b.field_names}") unioned = batch_a.union(batch_b) - print(f"✓ Unioned: {unioned.global_indexes} (global_index=2 deduplicated)") + print(f"✓ Unioned fields: {unioned.field_names} (same global_indexes={unioned.global_indexes})") + print(" Note: 'attention_mask' was present in both; other's definition is kept.") # --- 10. Empty BatchMeta --- print("[Example 10] Creating an empty BatchMeta...") @@ -228,8 +229,8 @@ def make_batch(global_indexes, fields=None): print("=" * 80) print("concat vs union:") - print(" - concat: Combines batches with SAME field structure") - print(" - union: Merges batches, deduplicating by global_index") + print(" - concat: Combines batches with SAME field structure (append rows)") + print(" - union: Merges batches with SAME global_indexes (append columns/fields)") print("=" * 80) From 0705688dd3ba8ea96722d17658d044e4e59bd5e3 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 12 May 2026 17:02:47 +0800 Subject: [PATCH 2/2] use deepcopy Signed-off-by: 0oshowero0 --- tests/test_metadata.py | 22 ++++++++++++++++------ transfer_queue/metadata.py | 20 +++++++++++++++++--- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 452d56bf..6fc49d2e 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -465,19 +465,29 @@ def test_union_validation_partition_id_mismatch(self): with pytest.raises(ValueError, match="partition_ids do not match"): batch_a.union(batch_b) - def test_union_empty_other_returns_self(self): - """union with an empty batch returns self.""" + def test_union_empty_other_returns_copy(self): + """union with an empty batch returns a copy, not the original identity.""" batch = self._make_batch(batch_size=2) empty = BatchMeta.empty() result = batch.union(empty) - assert result is batch + assert result is not batch + assert result.global_indexes == batch.global_indexes + assert result.field_names == batch.field_names + # Mutating the result must not affect the original + result.extra_info["new_key"] = "new_value" + assert "new_key" not in batch.extra_info - def test_union_empty_self_returns_other(self): - """union when self is empty returns other.""" + def test_union_empty_self_returns_copy(self): + """union when self is empty returns a copy, not the original identity.""" batch = self._make_batch(batch_size=2) empty = BatchMeta.empty() result = empty.union(batch) - assert result is batch + assert result is not batch + assert result.global_indexes == batch.global_indexes + assert result.field_names == batch.field_names + # Mutating the result must not affect the original + result.extra_info["new_key"] = "new_value" + assert "new_key" not in batch.extra_info # ============================================================================== diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 520b0577..e89c1531 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -544,6 +544,18 @@ def select_fields(self, field_names: list[str]) -> "BatchMeta": _custom_backend_meta=selected_custom_backend_meta, ) + def copy(self) -> "BatchMeta": + """Return a deep copy of this BatchMeta.""" + return BatchMeta( + global_indexes=list(self.global_indexes), + partition_ids=list(self.partition_ids), + field_schema=copy.deepcopy(self.field_schema), + production_status=self.production_status.copy(), + extra_info=copy.deepcopy(self.extra_info), + custom_meta=copy.deepcopy(self.custom_meta), + _custom_backend_meta=copy.deepcopy(self._custom_backend_meta), + ) + def __len__(self) -> int: """Return the number of samples in this batch.""" return self.size @@ -618,15 +630,17 @@ def union(self, other: "BatchMeta") -> "BatchMeta": other: Another BatchMeta to union with. Returns: - New BatchMeta with unioned fields. + A new BatchMeta instance with unioned fields. Even when one side is + empty, a copy is returned so callers can safely mutate the result + without affecting the original. Raises: ValueError: If global_indexes, or partition_ids do not match. """ if not other or other.size == 0: - return self + return self.copy() if self.size == 0: - return other + return other.copy() if self.global_indexes != other.global_indexes: raise ValueError(