Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions tests/test_graph_io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tests for graph save/load round-trip, including per-node graph_id."""
"""Tests for graph save/load round-trip, including per-node graph_id and multi-label names."""

import json
from pathlib import Path
Expand Down Expand Up @@ -31,9 +31,9 @@ def test_save_load_roundtrip_unified_graph(tmp_path: Path):
"""Unified graph with nodes from different source graphs preserves graph_id."""
g = Graph(id="unified")
# Manually add nodes with different source graph_ids
g.nodes["n1"] = Node(id="n1", graph_id="article-1", name="Alice")
g.nodes["n2"] = Node(id="n2", graph_id="article-2", name="Bob")
g.nodes["n3"] = Node(id="n3", graph_id="unified", name="Carol")
g.nodes["n1"] = Node(id="n1", graph_id="article-1", names=["Alice"])
g.nodes["n2"] = Node(id="n2", graph_id="article-2", names=["Bob"])
g.nodes["n3"] = Node(id="n3", graph_id="unified", names=["Carol"])
g.edges.append(Edge(source="n1", target="n2", relation="knows"))

path = tmp_path / "unified.json"
Expand All @@ -52,3 +52,24 @@ def test_save_load_roundtrip_unified_graph(tmp_path: Path):
assert loaded.nodes["n1"].graph_id == "article-1"
assert loaded.nodes["n2"].graph_id == "article-2"
assert loaded.nodes["n3"].graph_id == "unified"


def test_save_load_roundtrip_multi_label_names(tmp_path: Path):
    """Entities with multiple names survive save/load round-trip."""
    graph = Graph(id="article-1")
    meridian = graph.add_entity(["Meridian Technologies", "Meridian Tech"])
    datavault = graph.add_entity("DataVault")
    graph.add_edge(meridian, datavault, "acquired")

    out_path = tmp_path / "g.json"
    save_graph(graph, out_path)

    # The on-disk JSON stores the full name list on each node record.
    raw = json.loads(out_path.read_text())
    serialized = {record["id"]: record for record in raw["nodes"]}
    assert serialized[meridian.id]["names"] == [
        "Meridian Technologies",
        "Meridian Tech",
    ]
    assert serialized[datavault.id]["names"] == ["DataVault"]

    # Loading restores the same name lists on the Node objects.
    restored = load_graph(out_path)
    assert restored.nodes[meridian.id].names == [
        "Meridian Technologies",
        "Meridian Tech",
    ]
    assert restored.nodes[datavault.id].names == ["DataVault"]
56 changes: 55 additions & 1 deletion tests/test_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Name variation with structural reinforcement (the core use case)
- Dangling entities get no structural evidence
- Exponential sum accumulates evidence from multiple paths (bidirectional > unidirectional)
- Multi-label entities use max similarity across all names during seeding
"""

from worldgraph.graph import Graph
Expand Down Expand Up @@ -427,7 +428,9 @@ def test_shared_anchor_does_not_override_name_dissimilarity(embedder):
# Premise: name similarity alone is below threshold
from worldgraph.names import build_idf, soft_tfidf

names = [n.name for g in [g1, g2, *bg_graphs] for n in g.nodes.values()]
names = [
name for g in [g1, g2, *bg_graphs] for n in g.nodes.values() for name in n.names
]
idf = build_idf(names)
sv_name_sim = soft_tfidf("Dr. Priya Sharma", "Dr. Elena Vasquez", idf)
assert sv_name_sim < 0.8
Expand Down Expand Up @@ -598,3 +601,54 @@ def test_positive_evidence_is_monotonically_nondecreasing(embedder):
f"at max_iter={n_iter}"
)
prev_conf = curr_conf


# ---------------------------------------------------------------------------
# Multi-label name seeding
# ---------------------------------------------------------------------------


def test_multi_label_entity_uses_best_name_pair(embedder):
    """Seeding should score an entity pair by the best name pair drawn
    from both entities' full name lists.

    In g1 the company is stored as names=["Meridian Technologies"]; in g2
    as names=["Meridian Tech", "Meridian Technologies"]. The best pair is
    the exact match "Meridian Technologies"/"Meridian Technologies"
    (score ~1.0), not "Meridian Technologies"/"Meridian Tech" (~0.88).

    Without multi-label support only a single name is stored, so the
    closest pair may be missed and similarity under-estimated."""
    g1 = Graph(id="g1")
    meridian_a = g1.add_entity("Meridian Technologies")
    vault_a = g1.add_entity("DataVault")
    g1.add_edge(meridian_a, vault_a, "acquired")

    g2 = Graph(id="g2")
    meridian_b = g2.add_entity(["Meridian Tech", "Meridian Technologies"])
    vault_b = g2.add_entity("DataVault")
    g2.add_edge(meridian_b, vault_b, "purchased")

    confidence = match_graphs([g1, g2], embedder)

    # Best name pair is an exact match, so the seed is ~1.0; a single
    # stored "Meridian Tech" would only seed ~0.88.
    assert confidence[(meridian_a.id, meridian_b.id)] > 0.8


def test_multi_label_all_names_contribute_to_idf(embedder):
    """Every name in an entity's name list should feed the IDF table,
    not just the first one."""
    g1 = Graph(id="g1")
    multi = g1.add_entity(["Meridian Technologies", "Meridian Tech"])
    vault_a = g1.add_entity("DataVault")
    g1.add_edge(multi, vault_a, "acquired")

    g2 = Graph(id="g2")
    single = g2.add_entity("Meridian Technologies")
    vault_b = g2.add_entity("DataVault")
    g2.add_edge(single, vault_b, "purchased")

    # Must not raise — multi-label names flow through the pipeline — and
    # the exact-match name pair still yields a high confidence.
    confidence = match_graphs([g1, g2], embedder)
    assert confidence[(multi.id, single.id)] > 0.8
16 changes: 10 additions & 6 deletions worldgraph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Node:
    # Unique identifier for this node; entities created via
    # Graph.add_entity use a random UUID string.
    id: str
    # ID of the source graph this node originated from (preserved when
    # nodes from several graphs are merged into a unified graph).
    graph_id: str
    # Alternative surface names (labels) for the entity. Expected to be
    # non-empty — match code reads names[0] — TODO confirm invariant.
    names: list[str]


@dataclass
Expand All @@ -26,9 +26,11 @@ class Graph:
nodes: dict[str, Node] = field(default_factory=dict)
edges: list[Edge] = field(default_factory=list)

def add_entity(self, name: str) -> Node:
"""Add an entity node with the given name."""
entity = Node(id=str(uuid.uuid4()), graph_id=self.id, name=name)
def add_entity(self, names: str | list[str]) -> Node:
"""Add an entity node with the given name(s)."""
if isinstance(names, str):
names = [names]
entity = Node(id=str(uuid.uuid4()), graph_id=self.id, names=names)
self.nodes[entity.id] = entity
return entity

Expand All @@ -50,7 +52,7 @@ def load_graph(path: Path) -> Graph:
nodes[node_id] = Node(
id=node_id,
graph_id=node_data["graph_id"],
name=node_data["name"],
names=node_data["names"],
)

edges: list[Edge] = []
Expand All @@ -74,7 +76,9 @@ def save_graph(
"""Write graph to JSON, with optional match groups."""
nodes_out = []
for node in graph.nodes.values():
nodes_out.append({"id": node.id, "graph_id": node.graph_id, "name": node.name})
nodes_out.append(
{"id": node.id, "graph_id": node.graph_id, "names": node.names}
)

edges_out = []
for edge in graph.edges:
Expand Down
23 changes: 14 additions & 9 deletions worldgraph/match.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Stage 2: Entity alignment via PARIS-style similarity propagation.

Entity names are stored directly on nodes. Name similarity seeds the
confidence dict before the iteration loop, so structural evidence
propagates from iteration 1. Relation similarity is treated as binary
Entity names are stored as lists on nodes (multi-label). Name similarity
seeds the confidence dict before the iteration loop using the max over all
name pairs, so structural evidence propagates from iteration 1. Relation
similarity is treated as binary
via a single threshold that defines equivalence classes over free-text
relation phrases — above threshold = same relation, below = different.
This threshold is used consistently for functionality pooling, positive
Expand Down Expand Up @@ -116,8 +117,8 @@ def compute_functionality(
phrase_pairs: dict[str, list[tuple[str, str]]] = defaultdict(list)
for graph in graphs:
for edge in graph.edges:
source_name = graph.nodes[edge.source].name
target_name = graph.nodes[edge.target].name
source_name = graph.nodes[edge.source].names[0]
target_name = graph.nodes[edge.target].names[0]
phrase_pairs[edge.relation].append((source_name, target_name))

result: dict[str, Functionality] = {}
Expand Down Expand Up @@ -380,9 +381,11 @@ def propagate_similarity(
if graph.nodes[id_a].graph_id == graph.nodes[id_b].graph_id:
continue
name_sim = max(
0.0,
soft_tfidf(graph.nodes[id_a].name, graph.nodes[id_b].name, idf),
soft_tfidf(na, nb, idf)
for na in graph.nodes[id_a].names
for nb in graph.nodes[id_b].names
)
name_sim = max(0.0, name_sim)
confidence[(id_a, id_b)] = name_sim
confidence[(id_b, id_a)] = name_sim
pairs.append((id_a, id_b))
Expand Down Expand Up @@ -487,7 +490,9 @@ def match_graphs(
"""
unified = build_unified_graph(graphs)

all_names = [node.name for graph in graphs for node in graph.nodes.values()]
all_names = [
name for graph in graphs for node in graph.nodes.values() for name in node.names
]
all_relations = sorted({edge.relation for graph in graphs for edge in graph.edges})

idf = build_idf(all_names)
Expand Down Expand Up @@ -564,7 +569,7 @@ def run_matching(

click.echo(f"\n{len(match_groups)} match groups:")
for members in match_groups:
names = {unified.nodes[eid].name for eid in members}
names = {n for eid in members for n in unified.nodes[eid].names}
click.echo(f" {' / '.join(sorted(names))}")

click.echo(f"\nWrote {output_path}")
Loading