diff --git a/tests/test_graph_io.py b/tests/test_graph_io.py index 6158223..ab377e3 100644 --- a/tests/test_graph_io.py +++ b/tests/test_graph_io.py @@ -1,4 +1,4 @@ -"""Tests for graph save/load round-trip, including per-node graph_id.""" +"""Tests for graph save/load round-trip, including per-node graph_id and multi-label names.""" import json from pathlib import Path @@ -31,9 +31,9 @@ def test_save_load_roundtrip_unified_graph(tmp_path: Path): """Unified graph with nodes from different source graphs preserves graph_id.""" g = Graph(id="unified") # Manually add nodes with different source graph_ids - g.nodes["n1"] = Node(id="n1", graph_id="article-1", name="Alice") - g.nodes["n2"] = Node(id="n2", graph_id="article-2", name="Bob") - g.nodes["n3"] = Node(id="n3", graph_id="unified", name="Carol") + g.nodes["n1"] = Node(id="n1", graph_id="article-1", names=["Alice"]) + g.nodes["n2"] = Node(id="n2", graph_id="article-2", names=["Bob"]) + g.nodes["n3"] = Node(id="n3", graph_id="unified", names=["Carol"]) g.edges.append(Edge(source="n1", target="n2", relation="knows")) path = tmp_path / "unified.json" @@ -52,3 +52,24 @@ def test_save_load_roundtrip_unified_graph(tmp_path: Path): assert loaded.nodes["n1"].graph_id == "article-1" assert loaded.nodes["n2"].graph_id == "article-2" assert loaded.nodes["n3"].graph_id == "unified" + + +def test_save_load_roundtrip_multi_label_names(tmp_path: Path): + """Entities with multiple names survive save/load round-trip.""" + g = Graph(id="article-1") + n1 = g.add_entity(["Meridian Technologies", "Meridian Tech"]) + n2 = g.add_entity("DataVault") + g.add_edge(n1, n2, "acquired") + + path = tmp_path / "g.json" + save_graph(g, path) + + with open(path) as f: + data = json.load(f) + node_by_id = {n["id"]: n for n in data["nodes"]} + assert node_by_id[n1.id]["names"] == ["Meridian Technologies", "Meridian Tech"] + assert node_by_id[n2.id]["names"] == ["DataVault"] + + loaded = load_graph(path) + assert loaded.nodes[n1.id].names == ["Meridian Technologies", "Meridian Tech"] + assert loaded.nodes[n2.id].names == ["DataVault"] diff --git a/tests/test_propagation.py b/tests/test_propagation.py index 89ae71e..75c31f5 100644 --- a/tests/test_propagation.py +++ b/tests/test_propagation.py @@ -13,6 +13,7 @@ - Name variation with structural reinforcement (the core use case) - Dangling entities get no structural evidence - Exponential sum accumulates evidence from multiple paths (bidirectional > unidirectional) +- Multi-label entities use max similarity across all names during seeding """ from worldgraph.graph import Graph @@ -427,7 +428,9 @@ def test_shared_anchor_does_not_override_name_dissimilarity(embedder): # Premise: name similarity alone is below threshold from worldgraph.names import build_idf, soft_tfidf - names = [n.name for g in [g1, g2, *bg_graphs] for n in g.nodes.values()] + names = [ + name for g in [g1, g2, *bg_graphs] for n in g.nodes.values() for name in n.names + ] idf = build_idf(names) sv_name_sim = soft_tfidf("Dr. Priya Sharma", "Dr. Elena Vasquez", idf) assert sv_name_sim < 0.8 @@ -598,3 +601,54 @@ def test_positive_evidence_is_monotonically_nondecreasing(embedder): f"at max_iter={n_iter}" ) prev_conf = curr_conf + + +# --------------------------------------------------------------------------- +# Multi-label name seeding +# --------------------------------------------------------------------------- + + +def test_multi_label_entity_uses_best_name_pair(embedder): + """An entity with multiple names should seed similarity using the best + name pair across both entities' name lists. + + "Meridian Technologies" stored as names=["Meridian Technologies"] in g1, + and names=["Meridian Tech", "Meridian Technologies"] in g2. The best + pair is "Meridian Technologies"/"Meridian Technologies" (score ~1.0), + not "Meridian Technologies"/"Meridian Tech" (~0.88). + + Without multi-label support, only one name is stored and the closest + pair may be missed, under-estimating similarity.""" + g1 = Graph(id="g1") + m1 = g1.add_entity("Meridian Technologies") + dv1 = g1.add_entity("DataVault") + g1.add_edge(m1, dv1, "acquired") + + g2 = Graph(id="g2") + m2 = g2.add_entity(["Meridian Tech", "Meridian Technologies"]) + dv2 = g2.add_entity("DataVault") + g2.add_edge(m2, dv2, "purchased") + + confidence = match_graphs([g1, g2], embedder) + + # With multi-label, the best name pair is exact match → seed ~1.0 + # Without, if only "Meridian Tech" is stored, seed would be ~0.88 + assert confidence[(m1.id, m2.id)] > 0.8 + + +def test_multi_label_all_names_contribute_to_idf(embedder): + """All names in an entity's name list should contribute to IDF + computation, not just the first.""" + g1 = Graph(id="g1") + m1 = g1.add_entity(["Meridian Technologies", "Meridian Tech"]) + dv1 = g1.add_entity("DataVault") + g1.add_edge(m1, dv1, "acquired") + + g2 = Graph(id="g2") + m2 = g2.add_entity("Meridian Technologies") + dv2 = g2.add_entity("DataVault") + g2.add_edge(m2, dv2, "purchased") + + # Should not raise — multi-label names flow through the pipeline + confidence = match_graphs([g1, g2], embedder) + assert confidence[(m1.id, m2.id)] > 0.8 diff --git a/worldgraph/graph.py b/worldgraph/graph.py index 8838626..fe4320e 100644 --- a/worldgraph/graph.py +++ b/worldgraph/graph.py @@ -10,7 +10,7 @@ class Node: id: str graph_id: str - name: str + names: list[str] @dataclass @@ -26,9 +26,11 @@ class Graph: nodes: dict[str, Node] = field(default_factory=dict) edges: list[Edge] = field(default_factory=list) - def add_entity(self, name: str) -> Node: - """Add an entity node with the given name.""" - entity = Node(id=str(uuid.uuid4()), graph_id=self.id, name=name) + def add_entity(self, names: str | list[str]) -> Node: + """Add an entity node with the given name(s).""" + if isinstance(names, str): + names = [names] + entity = Node(id=str(uuid.uuid4()), graph_id=self.id, names=names) self.nodes[entity.id] = entity return entity @@ -50,7 +52,7 @@ def load_graph(path: Path) -> Graph: nodes[node_id] = Node( id=node_id, graph_id=node_data["graph_id"], - name=node_data["name"], + names=node_data["names"], ) edges: list[Edge] = [] @@ -74,7 +76,9 @@ def save_graph( """Write graph to JSON, with optional match groups.""" nodes_out = [] for node in graph.nodes.values(): - nodes_out.append({"id": node.id, "graph_id": node.graph_id, "name": node.name}) + nodes_out.append( + {"id": node.id, "graph_id": node.graph_id, "names": node.names} + ) edges_out = [] for edge in graph.edges: diff --git a/worldgraph/match.py b/worldgraph/match.py index f442710..b8d573b 100644 --- a/worldgraph/match.py +++ b/worldgraph/match.py @@ -1,8 +1,9 @@ """Stage 2: Entity alignment via PARIS-style similarity propagation. -Entity names are stored directly on nodes. Name similarity seeds the -confidence dict before the iteration loop, so structural evidence -propagates from iteration 1. Relation similarity is treated as binary +Entity names are stored as lists on nodes (multi-label). Name similarity +seeds the confidence dict before the iteration loop using the max over all +name pairs, so structural evidence propagates from iteration 1. Relation +similarity is treated as binary via a single threshold that defines equivalence classes over free-text relation phrases — above threshold = same relation, below = different. This threshold is used consistently for functionality pooling, positive @@ -116,8 +117,8 @@ def compute_functionality( phrase_pairs: dict[str, list[tuple[str, str]]] = defaultdict(list) for graph in graphs: for edge in graph.edges: - source_name = graph.nodes[edge.source].name - target_name = graph.nodes[edge.target].name + source_name = graph.nodes[edge.source].names[0] + target_name = graph.nodes[edge.target].names[0] phrase_pairs[edge.relation].append((source_name, target_name)) result: dict[str, Functionality] = {} @@ -380,9 +381,11 @@ def propagate_similarity( if graph.nodes[id_a].graph_id == graph.nodes[id_b].graph_id: continue name_sim = max( - 0.0, - soft_tfidf(graph.nodes[id_a].name, graph.nodes[id_b].name, idf), + soft_tfidf(na, nb, idf) + for na in graph.nodes[id_a].names + for nb in graph.nodes[id_b].names ) + name_sim = max(0.0, name_sim) confidence[(id_a, id_b)] = name_sim confidence[(id_b, id_a)] = name_sim pairs.append((id_a, id_b)) @@ -487,7 +490,9 @@ def match_graphs( """ unified = build_unified_graph(graphs) - all_names = [node.name for graph in graphs for node in graph.nodes.values()] + all_names = [ + name for graph in graphs for node in graph.nodes.values() for name in node.names + ] all_relations = sorted({edge.relation for graph in graphs for edge in graph.edges}) idf = build_idf(all_names) @@ -564,7 +569,7 @@ def run_matching( click.echo(f"\n{len(match_groups)} match groups:") for members in match_groups: - names = {unified.nodes[eid].name for eid in members} + names = {n for eid in members for n in unified.nodes[eid].names} click.echo(f" {' / '.join(sorted(names))}") click.echo(f"\nWrote {output_path}")