From fbc8d52a7586f4795fa1122c4795a93b735c3eb2 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 17 Dec 2025 14:28:22 +0800 Subject: [PATCH 01/29] wip: refactor data evaluators & add kg evaluators --- graphgen/{models/evaluator => bases}/base_evaluator.py | 0 graphgen/models/evaluator/length_evaluator.py | 2 +- graphgen/models/evaluator/mtld_evaluator.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename graphgen/{models/evaluator => bases}/base_evaluator.py (100%) diff --git a/graphgen/models/evaluator/base_evaluator.py b/graphgen/bases/base_evaluator.py similarity index 100% rename from graphgen/models/evaluator/base_evaluator.py rename to graphgen/bases/base_evaluator.py diff --git a/graphgen/models/evaluator/length_evaluator.py b/graphgen/models/evaluator/length_evaluator.py index d5c33211..74716f70 100644 --- a/graphgen/models/evaluator/length_evaluator.py +++ b/graphgen/models/evaluator/length_evaluator.py @@ -1,5 +1,5 @@ +from graphgen.bases.base_evaluator import BaseEvaluator from graphgen.bases.datatypes import QAPair -from graphgen.models.evaluator.base_evaluator import BaseEvaluator from graphgen.models.tokenizer import Tokenizer from graphgen.utils import create_event_loop diff --git a/graphgen/models/evaluator/mtld_evaluator.py b/graphgen/models/evaluator/mtld_evaluator.py index c106d86c..8503ea4e 100644 --- a/graphgen/models/evaluator/mtld_evaluator.py +++ b/graphgen/models/evaluator/mtld_evaluator.py @@ -1,7 +1,7 @@ from typing import Set +from graphgen.bases.base_evaluator import BaseEvaluator from graphgen.bases.datatypes import QAPair -from graphgen.models.evaluator.base_evaluator import BaseEvaluator from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language nltk_helper = NLTKHelper() From 18be127d9cb95849724ab10aae2d2e1da0fb1886 Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Tue, 23 Dec 2025 13:26:07 +0800 Subject: [PATCH 02/29] feat: add KG quality evaluation module --- examples/evaluate_kg/evaluate_kg.sh | 6 + graphgen/models/__init__.py | 8 +- graphgen/models/evaluator/__init__.py | 1 + graphgen/models/evaluator/kg/README.md | 117 +++++++ graphgen/models/evaluator/kg/__init__.py | 14 + .../models/evaluator/kg/accuracy_evaluator.py | 297 ++++++++++++++++++ .../evaluator/kg/consistency_evaluator.py | 53 ++++ .../evaluator/kg/structure_evaluator.py | 141 +++++++++ graphgen/models/evaluator/kg/utils.py | 96 ++++++ .../models/evaluator/kg_quality_evaluator.py | 103 ++++++ graphgen/operators/evaluate_kg/__init__.py | 0 graphgen/operators/evaluate_kg/evaluate_kg.py | 208 ++++++++++++ 12 files changed, 1043 insertions(+), 1 deletion(-) create mode 100644 examples/evaluate_kg/evaluate_kg.sh create mode 100644 graphgen/models/evaluator/kg/README.md create mode 100644 graphgen/models/evaluator/kg/__init__.py create mode 100644 graphgen/models/evaluator/kg/accuracy_evaluator.py create mode 100644 graphgen/models/evaluator/kg/consistency_evaluator.py create mode 100644 graphgen/models/evaluator/kg/structure_evaluator.py create mode 100644 graphgen/models/evaluator/kg/utils.py create mode 100644 graphgen/models/evaluator/kg_quality_evaluator.py create mode 100644 graphgen/operators/evaluate_kg/__init__.py create mode 100644 graphgen/operators/evaluate_kg/evaluate_kg.py diff --git a/examples/evaluate_kg/evaluate_kg.sh b/examples/evaluate_kg/evaluate_kg.sh new file mode 100644 index 00000000..a846ee65 --- /dev/null +++ b/examples/evaluate_kg/evaluate_kg.sh @@ -0,0 +1,6 @@ +python3 -m graphgen.operators.evaluate_kg.evaluate_kg \ + --working_dir cache \ + --graph_backend kuzu \ + --kv_backend rocksdb \ + --sample_size 100 \ + --max_concurrent 10 diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 21344d74..127a4314 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,4 +1,10 @@ -from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator +from .evaluator import ( + KGQualityEvaluator, + LengthEvaluator, + MTLDEvaluator, + RewardEvaluator, + UniEvaluator, +) from .generator import ( AggregatedGenerator, AtomicGenerator, diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py index a9b445b4..5f2716fc 100644 --- a/graphgen/models/evaluator/__init__.py +++ b/graphgen/models/evaluator/__init__.py @@ -1,3 +1,4 @@ +from .kg_quality_evaluator import KGQualityEvaluator from .length_evaluator import LengthEvaluator from .mtld_evaluator import MTLDEvaluator from .reward_evaluator import RewardEvaluator diff --git a/graphgen/models/evaluator/kg/README.md b/graphgen/models/evaluator/kg/README.md new file mode 100644 index 00000000..71554c40 --- /dev/null +++ b/graphgen/models/evaluator/kg/README.md @@ -0,0 +1,117 @@ +# KG Quality Evaluation Module + +This module provides comprehensive quality evaluation for knowledge graphs built by GraphGen. + +## Module Structure + +The evaluation functionality has been split into modular components: + +- **`accuracy_evaluator.py`**: Entity/relation/triple accuracy evaluation using LLM-as-judge +- **`consistency_evaluator.py`**: Attribute value conflict detection +- **`structure_evaluator.py`**: Graph structural robustness metrics +- **`utils.py`**: Utility functions (NetworkX conversion, text retrieval, sampling) +- **`kg_quality_evaluator.py`**: Main evaluator class that integrates all modules + +## Features + +### 1. Accuracy Assessment +- **Entity Recognition Accuracy**: Samples entities and validates them using LLM +- **Relation Extraction Accuracy**: Samples relations and validates them using LLM +- **Triple Validation (RLC)**: Samples triples and validates them using LLM +- Calculates Precision, Recall, and F1 scores for each metric + +### 2. Consistency Assessment +- Detects attribute value conflicts (same entity, same attribute, different values) +- Calculates conflict rate: `conflict_entities_count / total_entities` +- Returns detailed conflict information + +### 3. Structural Robustness Assessment +- **Noise Ratio**: Isolated nodes / total nodes (threshold: < 15%) +- **Largest Connected Component Ratio**: Largest CC nodes / total nodes (threshold: > 90%) +- **Average Node Degree**: Average degree across all nodes (threshold: 2-5) +- **Power Law Distribution R²**: Degree distribution fit (threshold: > 0.75) + +## Usage + +### Command Line Usage + +```bash +# Run all evaluations +python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache + +# Run specific evaluation +python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache --accuracy_only + +# Custom configuration +python -m graphgen.operators.evaluate_kg.evaluate_kg \ + --working_dir cache \ + --sample_size 200 \ + --graph_backend networkx \ + --kv_backend json_kv +``` + +### Shell Script Usage + +```bash +# Basic usage +bash examples/evaluate_kg/evaluate_kg.sh + +# With custom options +bash examples/evaluate_kg/evaluate_kg.sh \ + --working_dir cache \ + --sample_size 200 \ + --accuracy_only +``` + +## Requirements + +- **NetworkX**: Required for structural evaluation +- **scipy**: Required for power law distribution fitting +- **numpy**: Required for numerical calculations +- **LLM Client**: Required for accuracy evaluation (configure via `TRAINEE_*` env vars) + +## Output Format + +The evaluation returns a dictionary with the following structure: + +```python +{ + "accuracy": { + "entity_accuracy": { + "precision": float, + "recall": float, + "f1": float, + "true_positives": int, + "false_positives": int, + "sample_size": int + }, + "relation_accuracy": { ... }, + "triple_accuracy": { ... } + }, + "consistency": { + "conflict_rate": float, + "conflict_entities_count": int, + "total_entities": int, + "conflicts": [ ... ] + }, + "structure": { + "total_nodes": int, + "total_edges": int, + "noise_ratio": float, + "largest_cc_ratio": float, + "avg_degree": float, + "powerlaw_r2": float | None, + "thresholds": { + "noise_ratio": { "value": float, "threshold": float, "pass": bool }, + ... + } + } +} +``` + +## Notes + +- Accuracy evaluation requires LLM API access and may be slow for large sample sizes +- Structural evaluation automatically converts Kuzu storage to NetworkX for analysis +- All evaluations include error handling and will return error messages if something fails +- The evaluator automatically loads graph and chunk storage from the working directory diff --git a/graphgen/models/evaluator/kg/__init__.py b/graphgen/models/evaluator/kg/__init__.py new file mode 100644 index 00000000..007f0c9d --- /dev/null +++ b/graphgen/models/evaluator/kg/__init__.py @@ -0,0 +1,14 @@ +from .accuracy_evaluator import AccuracyEvaluator +from .consistency_evaluator import ConsistencyEvaluator +from .structure_evaluator import StructureEvaluator +from .utils import convert_to_networkx, get_relevant_text, get_source_text, sample_items + +__all__ = [ + "AccuracyEvaluator", + "ConsistencyEvaluator", + "StructureEvaluator", + "convert_to_networkx", + "get_relevant_text", + "get_source_text", + "sample_items", +] diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py new file mode 100644 index 00000000..4066c92f --- /dev/null +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -0,0 +1,297 @@ +import asyncio +from typing import Any, Dict, List, Optional, Tuple + +from tqdm.asyncio import tqdm as tqdm_async + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.models.evaluator.kg.utils import get_relevant_text, sample_items +from graphgen.utils import create_event_loop, logger + + +class AccuracyEvaluator: + """Evaluates accuracy of entity recognition, relation extraction, and triple validation.""" + + def __init__( + self, + graph_storage: BaseGraphStorage, + chunk_storage: BaseKVStorage, + llm_client: BaseLLMWrapper, + sample_size: int = 100, + max_concurrent: int = 10, + ): + self.graph_storage = graph_storage + self.chunk_storage = chunk_storage + self.llm_client = llm_client + self.sample_size = sample_size + self.max_concurrent = max_concurrent + + def evaluate(self) -> Dict[str, Any]: + # Get all nodes and edges + all_nodes = self.graph_storage.get_all_nodes() or [] + all_edges = self.graph_storage.get_all_edges() or [] + + if not all_nodes and not all_edges: + return {"error": "Empty graph"} + + # Sample entities, relations, and triples + entity_samples = sample_items(all_nodes, self.sample_size) + relation_samples = sample_items(all_edges, self.sample_size) + triple_samples = sample_items(all_edges, self.sample_size) + + # Evaluate each type (async) + loop = create_event_loop() + entity_results = loop.run_until_complete(self._evaluate_entities(entity_samples)) + relation_results = loop.run_until_complete( + self._evaluate_relations(relation_samples) + ) + triple_results = loop.run_until_complete(self._evaluate_triples(triple_samples)) + + return { + "entity_accuracy": entity_results, + "relation_accuracy": relation_results, + "triple_accuracy": triple_results, + } + + async def _evaluate_entities( + self, entity_samples: List[Tuple[str, Dict]] + ) -> Dict[str, float]: + """Evaluate entity recognition accuracy.""" + source_text = get_relevant_text(self.chunk_storage) + + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def verify_with_semaphore(entity_sample): + async with semaphore: + entity_id, entity_data = entity_sample + return await self._verify_entity_with_llm( + entity_id, entity_data, source_text + ) + + results = [] + tasks = [verify_with_semaphore(sample) for sample in entity_samples] + for coro in tqdm_async( + asyncio.as_completed(tasks), total=len(tasks), desc="Verifying entities" + ): + result = await coro + results.append(result) + + # Calculate metrics + tp = sum(results) + fp = len(results) - tp + precision = tp / len(results) if results else 0.0 + recall = precision # Approximation: assume all sampled are relevant + f1 = ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0.0 + ) + + return { + "precision": precision, + "recall": recall, + "f1": f1, + "true_positives": tp, + "false_positives": fp, + "sample_size": len(results), + } + + async def _evaluate_relations( + self, relation_samples: List[Tuple[str, str, Dict]] + ) -> Dict[str, float]: + """Evaluate relation extraction accuracy.""" + source_text = get_relevant_text(self.chunk_storage) + + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def verify_with_semaphore(relation_sample): + async with semaphore: + src_id, dst_id, edge_data = relation_sample + return await self._verify_relation_with_llm( + src_id, dst_id, edge_data, source_text + ) + + results = [] + tasks = [verify_with_semaphore(sample) for sample in relation_samples] + for coro in tqdm_async( + asyncio.as_completed(tasks), total=len(tasks), desc="Verifying relations" + ): + result = await coro + results.append(result) + + tp = sum(results) + fp = len(results) - tp + precision = tp / len(results) if results else 0.0 + recall = precision + f1 = ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0.0 + ) + + return { + "precision": precision, + "recall": recall, + "f1": f1, + "true_positives": tp, + "false_positives": fp, + "sample_size": len(results), + } + + async def _evaluate_triples( + self, triple_samples: List[Tuple[str, str, Dict]] + ) -> Dict[str, float]: + """Evaluate triple validation accuracy (RLC).""" + source_text = get_relevant_text(self.chunk_storage) + + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def verify_with_semaphore(triple_sample): + async with semaphore: + src_id, dst_id, edge_data = triple_sample + return await self._verify_triple_with_llm( + src_id, dst_id, edge_data, source_text + ) + + results = [] + tasks = [verify_with_semaphore(sample) for sample in triple_samples] + for coro in tqdm_async( + asyncio.as_completed(tasks), total=len(tasks), desc="Verifying triples" + ): + result = await coro + results.append(result) + + tp = sum(results) + fp = len(results) - tp + precision = tp / len(results) if results else 0.0 + recall = precision + f1 = ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0.0 + ) + + return { + "precision": precision, + "recall": recall, + "f1": f1, + "true_positives": tp, + "false_positives": fp, + "sample_size": len(results), + } + + async def _verify_entity_with_llm( + self, entity_id: str, entity_data: Dict, source_text: str + ) -> bool: + """Verify entity correctness using LLM.""" + entity_name = entity_data.get("entity_name", entity_id) + entity_type = entity_data.get("entity_type", "unknown") + entity_summary = entity_data.get("entity_summary", entity_data.get("description", "")) + + # Try to get relevant text from source_id + source_id = entity_data.get("source_id") + if source_id: + relevant_text = get_relevant_text(self.chunk_storage, source_id) + if relevant_text: + source_text = relevant_text + + prompt = f"""给定以下文本和实体信息,请判断该实体是否在文本中正确识别。 + +文本:{source_text[:2000]} + +实体名称:{entity_name} +实体类型:{entity_type} +实体描述:{entity_summary} + +请回答:该实体是否在文本中正确识别?回答"是"或"否",并简要说明理由。""" + + try: + response = await self.llm_client.generate_answer(prompt) + response_lower = response.lower() + return ( + "是" in response_lower + or "yes" in response_lower + or "正确" in response_lower + ) + except Exception as e: + logger.error(f"LLM verification failed for entity {entity_id}: {e}") + return False + + async def _verify_relation_with_llm( + self, src_id: str, dst_id: str, edge_data: Dict, source_text: str + ) -> bool: + """Verify relation correctness using LLM.""" + src_node = self.graph_storage.get_node(src_id) or {} + dst_node = self.graph_storage.get_node(dst_id) or {} + source_entity = src_node.get("entity_name", src_id) + target_entity = dst_node.get("entity_name", dst_id) + relationship_summary = edge_data.get( + "relationship_summary", edge_data.get("description", "") + ) + + # Try to get relevant text from source_id + source_id = edge_data.get("source_id") + if source_id: + relevant_text = get_relevant_text(self.chunk_storage, source_id) + if relevant_text: + source_text = relevant_text + + prompt = f"""给定以下文本和关系信息,请判断该关系是否在文本中正确抽取。 + +文本:{source_text[:2000]} + +源实体:{source_entity} +目标实体:{target_entity} +关系描述:{relationship_summary} + +请回答:该关系是否在文本中正确抽取?回答"是"或"否",并简要说明理由。""" + + try: + response = await self.llm_client.generate_answer(prompt) + response_lower = response.lower() + return ( + "是" in response_lower + or "yes" in response_lower + or "正确" in response_lower + ) + except Exception as e: + logger.error( + f"LLM verification failed for relation {src_id}->{dst_id}: {e}" + ) + return False + + async def _verify_triple_with_llm( + self, src_id: str, dst_id: str, edge_data: Dict, source_text: str + ) -> bool: + """Verify triple correctness using LLM.""" + src_node = self.graph_storage.get_node(src_id) or {} + dst_node = self.graph_storage.get_node(dst_id) or {} + head = src_node.get("entity_name", src_id) + tail = dst_node.get("entity_name", dst_id) + relation = edge_data.get("relationship_summary", edge_data.get("description", "")) + + # Try to get relevant text from source_id + source_id = edge_data.get("source_id") + if source_id: + relevant_text = get_relevant_text(self.chunk_storage, source_id) + if relevant_text: + source_text = relevant_text + + prompt = f"""给定以下文本和三元组,请判断该三元组是否正确。 + +文本:{source_text[:2000]} + +三元组:(头实体: {head}, 关系: {relation}, 尾实体: {tail}) + +请回答:该三元组是否正确?回答"是"或"否",并简要说明理由。""" + + try: + response = await self.llm_client.generate_answer(prompt) + response_lower = response.lower() + return ( + "是" in response_lower + or "yes" in response_lower + or "正确" in response_lower + ) + except Exception as e: + logger.error(f"LLM verification failed for triple {src_id}->{dst_id}: {e}") + return False diff --git a/graphgen/models/evaluator/kg/consistency_evaluator.py b/graphgen/models/evaluator/kg/consistency_evaluator.py new file mode 100644 index 00000000..9beff1f4 --- /dev/null +++ b/graphgen/models/evaluator/kg/consistency_evaluator.py @@ -0,0 +1,53 @@ +from typing import Any, Dict, List, Tuple + +from graphgen.bases import BaseGraphStorage + + +class ConsistencyEvaluator: + """Evaluates consistency by detecting attribute value conflicts.""" + + def __init__(self, graph_storage: BaseGraphStorage): + self.graph_storage = graph_storage + + def evaluate(self) -> Dict[str, Any]: + all_nodes = self.graph_storage.get_all_nodes() or [] + if not all_nodes: + return {"error": "Empty graph"} + + conflicts = [] + conflict_entities = set() + + for node_id, node_data in all_nodes: + if not isinstance(node_data, dict): + continue + + # Check each attribute for multiple values + for attr_key, attr_value in node_data.items(): + # Skip special keys + if attr_key.startswith("_") or attr_key in ["id", "loss"]: + continue + + # If attribute has multiple values (list), check for conflicts + if isinstance(attr_value, list): + unique_values = set(str(v) for v in attr_value if v) + if len(unique_values) > 1: + conflicts.append( + { + "entity_id": node_id, + "attribute": attr_key, + "values": list(unique_values), + } + ) + conflict_entities.add(node_id) + + total_entities = len(all_nodes) + conflict_rate = ( + len(conflict_entities) / total_entities if total_entities > 0 else 0 + ) + + return { + "conflict_rate": conflict_rate, + "conflict_entities_count": len(conflict_entities), + "total_entities": total_entities, + "conflicts": conflicts[:100], # Limit to first 100 conflicts + } diff --git a/graphgen/models/evaluator/kg/structure_evaluator.py b/graphgen/models/evaluator/kg/structure_evaluator.py new file mode 100644 index 00000000..44a8cd25 --- /dev/null +++ b/graphgen/models/evaluator/kg/structure_evaluator.py @@ -0,0 +1,141 @@ +from typing import Any, Dict, Optional + +try: + import networkx as nx +except ImportError: + nx = None + +try: + from scipy import stats +except ImportError: + stats = None + +import numpy as np + +from graphgen.bases import BaseGraphStorage +from graphgen.models.evaluator.kg.utils import convert_to_networkx +from graphgen.utils import logger + + +class StructureEvaluator: + """Evaluates structural robustness of the graph.""" + + def __init__(self, graph_storage: BaseGraphStorage): + self.graph_storage = graph_storage + + def evaluate(self) -> Dict[str, Any]: + if nx is None: + return {"error": "NetworkX not installed"} + + # Convert graph to NetworkX + G = convert_to_networkx(self.graph_storage) + + if G.number_of_nodes() == 0: + return {"error": "Empty graph"} + + # Calculate metrics + total_nodes = G.number_of_nodes() + total_edges = G.number_of_edges() + + # Noise ratio: isolated nodes / total nodes + isolated_nodes = [n for n in G.nodes() if G.degree(n) == 0] + noise_ratio = len(isolated_nodes) / total_nodes if total_nodes > 0 else 0 + + # Largest connected component + if G.is_directed(): + G_undirected = G.to_undirected() + else: + G_undirected = G + + connected_components = list(nx.connected_components(G_undirected)) + if connected_components: + largest_cc = max(connected_components, key=len) + largest_cc_ratio = ( + len(largest_cc) / total_nodes if total_nodes > 0 else 0 + ) + else: + largest_cc_ratio = 0 + + # Average node degree + if total_nodes > 0: + total_degree = sum(G.degree(n) for n in G.nodes()) + avg_degree = total_degree / total_nodes + else: + avg_degree = 0 + + # Power law distribution R² + powerlaw_r2 = self._calculate_powerlaw_r2(G) + + # Check thresholds + thresholds = { + "noise_ratio": { + "value": noise_ratio, + "threshold": 0.15, + "pass": noise_ratio < 0.15, + }, + "largest_cc_ratio": { + "value": largest_cc_ratio, + "threshold": 0.90, + "pass": largest_cc_ratio > 0.90, + }, + "avg_degree": { + "value": avg_degree, + "threshold": (2, 5), + "pass": 2 <= avg_degree <= 5, + }, + "powerlaw_r2": { + "value": powerlaw_r2, + "threshold": 0.75, + "pass": powerlaw_r2 > 0.75 if powerlaw_r2 is not None else False, + }, + } + + return { + "total_nodes": total_nodes, + "total_edges": total_edges, + "isolated_nodes_count": len(isolated_nodes), + "noise_ratio": noise_ratio, + "largest_cc_ratio": largest_cc_ratio, + "avg_degree": avg_degree, + "powerlaw_r2": powerlaw_r2, + "thresholds": thresholds, + } + + def _calculate_powerlaw_r2(self, G: "nx.Graph") -> Optional[float]: + """ + Calculate R² for power law distribution of node degrees. + + Returns: + R² value if calculation successful, None otherwise + """ + if stats is None: + logger.warning("scipy not available, skipping power law R² calculation") + return None + + degrees = [G.degree(n) for n in G.nodes()] + if len(degrees) < 10: # Need sufficient data points + logger.warning("Insufficient nodes for power law fitting") + return None + + # Filter out zero degrees for log fitting + non_zero_degrees = [d for d in degrees if d > 0] + if len(non_zero_degrees) < 5: + return None + + try: + # Fit power law: log(y) = a * log(x) + b + log_degrees = np.log(non_zero_degrees) + sorted_log_degrees = np.sort(log_degrees) + x = np.arange(1, len(sorted_log_degrees) + 1) + log_x = np.log(x) + + # Linear regression on log-log scale + slope, intercept, r_value, p_value, std_err = stats.linregress( + log_x, sorted_log_degrees + ) + r2 = r_value ** 2 + + return float(r2) + except Exception as e: + logger.error(f"Power law R² calculation failed: {e}") + return None diff --git a/graphgen/models/evaluator/kg/utils.py b/graphgen/models/evaluator/kg/utils.py new file mode 100644 index 00000000..64963d2c --- /dev/null +++ b/graphgen/models/evaluator/kg/utils.py @@ -0,0 +1,96 @@ +from typing import Any, List, Optional + +try: + import networkx as nx +except ImportError: + nx = None + +from graphgen.bases import BaseGraphStorage, BaseKVStorage + + +def convert_to_networkx(graph_storage: BaseGraphStorage) -> "nx.Graph": + """Convert graph storage to NetworkX graph.""" + if nx is None: + raise ImportError("NetworkX is required for structural evaluation") + + G = nx.DiGraph() + + # Add nodes + nodes = graph_storage.get_all_nodes() or [] + for node_id, node_data in nodes: + if isinstance(node_data, dict): + G.add_node(node_id, **node_data) + else: + G.add_node(node_id) + + # Add edges + edges = graph_storage.get_all_edges() or [] + for src, dst, edge_data in edges: + if isinstance(edge_data, dict): + G.add_edge(src, dst, **edge_data) + else: + G.add_edge(src, dst) + + return G + + +def get_source_text(chunk_storage: BaseKVStorage, chunk_id: Optional[str] = None) -> str: + """ + Get source text from chunk storage. + + Args: + chunk_storage: KV storage containing chunks + chunk_id: Optional chunk ID. If None, returns concatenated text from all chunks. + + Returns: + Source text content + """ + if chunk_id: + chunk = chunk_storage.get_by_id(chunk_id) + if chunk and isinstance(chunk, dict): + return chunk.get("content", "") + return "" + + # Get all chunks and concatenate + all_chunks = chunk_storage.get_all() + texts = [] + for chunk_data in all_chunks.values(): + if isinstance(chunk_data, dict): + content = chunk_data.get("content", "") + if content: + texts.append(content) + return "\n\n".join(texts) + + +def get_relevant_text( + chunk_storage: BaseKVStorage, source_id: Optional[str] = None +) -> str: + """Get relevant source text from chunk storage using source_id.""" + if source_id: + # Try to get specific chunk + chunk = chunk_storage.get_by_id(source_id) + if chunk and isinstance(chunk, dict): + return chunk.get("content", "") + # If source_id contains , try multiple chunks + if "" in str(source_id): + chunk_ids = [sid.strip() for sid in str(source_id).split("") if sid.strip()] + texts = [] + for cid in chunk_ids: + chunk = chunk_storage.get_by_id(cid) + if chunk and isinstance(chunk, dict): + content = chunk.get("content", "") + if content: + texts.append(content) + return "\n\n".join(texts) if texts else "" + + # Fallback to all chunks + return get_source_text(chunk_storage) + + +def sample_items(items: List[Any], sample_size: int) -> List[Any]: + """Sample items from a list.""" + import random + + if len(items) <= sample_size: + return items + return random.sample(items, sample_size) diff --git a/graphgen/models/evaluator/kg_quality_evaluator.py b/graphgen/models/evaluator/kg_quality_evaluator.py new file mode 100644 index 00000000..0bb945be --- /dev/null +++ b/graphgen/models/evaluator/kg_quality_evaluator.py @@ -0,0 +1,103 @@ +""" +Knowledge Graph Quality Evaluator + +This module provides comprehensive quality evaluation for knowledge graphs, +1. accuracy assessment (entity/relation/triple validation), +2. consistency assessment (attribute conflict detection), and structural +3. robustness assessment (noise ratio, connectivity, degree distribution). +""" + +import os +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.models.evaluator.kg import ( + AccuracyEvaluator, + ConsistencyEvaluator, + StructureEvaluator, +) +from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger + + +@dataclass +class KGQualityEvaluator: + """Knowledge Graph Quality Evaluator.""" + + working_dir: str = "cache" + graph_backend: str = "kuzu" + kv_backend: str = "rocksdb" + llm_client: Optional[BaseLLMWrapper] = None + sample_size: int = 100 + max_concurrent: int = 10 + + def __post_init__(self): + """Initialize storage and LLM client.""" + # Lazy import to avoid circular dependency + from graphgen.common import init_llm, init_storage + + self.graph_storage: BaseGraphStorage = init_storage( + backend=self.graph_backend, + working_dir=self.working_dir, + namespace="graph", + ) + self.chunk_storage: BaseKVStorage = init_storage( + backend=self.kv_backend, + working_dir=self.working_dir, + namespace="chunk", + ) + + if self.llm_client is None: + self.llm_client = init_llm("trainee") + + def evaluate_all(self) -> Dict[str, Any]: + """Run all evaluation metrics and return comprehensive report.""" + CURRENT_LOGGER_VAR.get() + results = {} + + try: + logger.info("Starting accuracy evaluation...") + results["accuracy"] = self.evaluate_accuracy() + except Exception as e: + logger.error(f"Accuracy evaluation failed: {e}") + results["accuracy"] = {"error": str(e)} + + # Consistency evaluation + try: + logger.info("Starting consistency evaluation...") + consistency_evaluator = ConsistencyEvaluator( + graph_storage=self.graph_storage + ) + results["consistency"] = consistency_evaluator.evaluate() + except Exception as e: + logger.error(f"Consistency evaluation failed: {e}") + results["consistency"] = {"error": str(e)} + + # Structural robustness evaluation + try: + logger.info("Starting structural robustness evaluation...") + structure_evaluator = StructureEvaluator(graph_storage=self.graph_storage) + results["structure"] = structure_evaluator.evaluate() + except Exception as e: + logger.error(f"Structural evaluation failed: {e}") + results["structure"] = {"error": str(e)} + + return results + + def evaluate_accuracy(self) -> Dict[str, Any]: + accuracy_evaluator = AccuracyEvaluator( + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + sample_size=self.sample_size, + max_concurrent=self.max_concurrent, + ) + return accuracy_evaluator.evaluate() + + def evaluate_consistency(self) -> Dict[str, Any]: + consistency_evaluator = ConsistencyEvaluator(graph_storage=self.graph_storage) + return consistency_evaluator.evaluate() + + def evaluate_structure(self) -> Dict[str, Any]: + structure_evaluator = StructureEvaluator(graph_storage=self.graph_storage) + return structure_evaluator.evaluate() diff --git a/graphgen/operators/evaluate_kg/__init__.py b/graphgen/operators/evaluate_kg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/evaluate_kg/evaluate_kg.py b/graphgen/operators/evaluate_kg/evaluate_kg.py new file mode 100644 index 00000000..ab8a8031 --- /dev/null +++ b/graphgen/operators/evaluate_kg/evaluate_kg.py @@ -0,0 +1,208 @@ +import argparse +import json +import os +from pathlib import Path + +from dotenv import load_dotenv + +from graphgen.models import KGQualityEvaluator +from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger + +# Load environment variables +load_dotenv() + + +def main(): + """Main function to run KG quality evaluation.""" + parser = argparse.ArgumentParser( + description="Evaluate knowledge graph quality", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic evaluation + python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache + + # Custom sample size and output + python -m graphgen.operators.evaluate_kg.evaluate_kg \\ + --working_dir cache \\ + --sample_size 200 \\ + --output cache/kg_evaluation.json + + # Specify backends + python -m graphgen.operators.evaluate_kg.evaluate_kg \\ + --working_dir cache \\ + --graph_backend networkx \\ + --kv_backend json_kv + """, + ) + + parser.add_argument( + "--working_dir", + type=str, + default="cache", + help="Working directory containing graph and chunk storage (default: cache)", + ) + parser.add_argument( + "--graph_backend", + type=str, + default="kuzu", + choices=["kuzu", "networkx"], + help="Graph storage backend (default: kuzu)", + ) + parser.add_argument( + "--kv_backend", + type=str, + default="rocksdb", + choices=["rocksdb", "json_kv"], + help="KV storage backend (default: rocksdb)", + ) + parser.add_argument( + "--sample_size", + type=int, + default=100, + help="Sample size for accuracy evaluation (default: 100)", + ) + parser.add_argument( + "--max_concurrent", + type=int, + default=10, + help="Maximum concurrent LLM requests (default: 10)", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output file path for evaluation results (default: working_dir/kg_evaluation.json)", + ) + parser.add_argument( + "--accuracy_only", + action="store_true", + help="Only run accuracy evaluation", + ) + parser.add_argument( + "--consistency_only", + action="store_true", + help="Only run consistency evaluation", + ) + parser.add_argument( + "--structure_only", + action="store_true", + help="Only run structural robustness evaluation", + ) + + args = parser.parse_args() + + # Set up logging + log_dir = Path(args.working_dir) / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + default_logger = set_logger(str(log_dir / "evaluate_kg.log"), name="evaluate_kg") + CURRENT_LOGGER_VAR.set(default_logger) + + # Determine output path + if args.output is None: + output_path = Path(args.working_dir) / "kg_evaluation.json" + else: + output_path = Path(args.output) + + logger.info("Starting KG quality evaluation...") + logger.info(f"Working directory: {args.working_dir}") + logger.info(f"Graph backend: {args.graph_backend}") + logger.info(f"KV backend: {args.kv_backend}") + logger.info(f"Sample size: {args.sample_size}") + + # Initialize evaluator + try: + evaluator = KGQualityEvaluator( + working_dir=args.working_dir, + graph_backend=args.graph_backend, + kv_backend=args.kv_backend, + sample_size=args.sample_size, + max_concurrent=args.max_concurrent, + ) + except Exception as e: + logger.error(f"Failed to initialize evaluator: {e}") + raise + + # Run evaluation + try: + if args.accuracy_only: + logger.info("Running accuracy evaluation only...") + results = {"accuracy": evaluator.evaluate_accuracy()} + elif args.consistency_only: + logger.info("Running consistency evaluation only...") + results = {"consistency": evaluator.evaluate_consistency()} + elif args.structure_only: + logger.info("Running structural robustness evaluation only...") + results = {"structure": evaluator.evaluate_structure()} + else: + logger.info("Running all evaluations...") + results = evaluator.evaluate_all() + + # Save results + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + logger.info(f"Evaluation completed. Results saved to: {output_path}") + + # Print summary + print("\n" + "=" * 60) + print("KG Quality Evaluation Summary") + print("=" * 60) + + if "accuracy" in results: + acc = results["accuracy"] + if "error" not in acc: + print("\n[Accuracy]") + if "entity_accuracy" in acc: + e = acc["entity_accuracy"] + print(f" Entity - Precision: {e.get('precision', 0):.3f}, " + f"Recall: {e.get('recall', 0):.3f}, F1: {e.get('f1', 0):.3f}") + if "relation_accuracy" in acc: + r = acc["relation_accuracy"] + print(f" Relation - Precision: {r.get('precision', 0):.3f}, " + f"Recall: {r.get('recall', 0):.3f}, F1: {r.get('f1', 0):.3f}") + if "triple_accuracy" in acc: + t = acc["triple_accuracy"] + print(f" Triple (RLC) - Precision: {t.get('precision', 0):.3f}, " + f"Recall: {t.get('recall', 0):.3f}, F1: {t.get('f1', 0):.3f}") + else: + print(f"\n[Accuracy] Error: {acc['error']}") + + if "consistency" in results: + cons = results["consistency"] + if "error" not in cons: + print("\n[Consistency]") + print(f" Conflict Rate: {cons.get('conflict_rate', 0):.3f}") + print(f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / " + f"{cons.get('total_entities', 0)}") + else: + print(f"\n[Consistency] Error: {cons['error']}") + + if "structure" in results: + struct = results["structure"] + if "error" not in struct: + print("\n[Structural Robustness]") + print(f" Total Nodes: {struct.get('total_nodes', 0)}") + print(f" Total Edges: {struct.get('total_edges', 0)}") + print(f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} " + f"({'✓' if struct.get('noise_ratio', 1) < 0.15 else '✗'} < 15%)") + print(f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} " + f"({'✓' if struct.get('largest_cc_ratio', 0) > 0.90 else '✗'} > 90%)") + print(f" Avg Degree: {struct.get('avg_degree', 0):.2f} " + f"({'✓' if 2 <= struct.get('avg_degree', 0) <= 5 else '✗'} 2-5)") + if struct.get('powerlaw_r2') is not None: + print(f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} " + f"({'✓' if struct.get('powerlaw_r2', 0) > 0.75 else '✗'} > 0.75)") + else: + print(f"\n[Structural Robustness] Error: {struct['error']}") + + print("\n" + "=" * 60) + + except Exception as e: + logger.error(f"Evaluation failed: {e}", exc_info=True) + raise + + +if __name__ == "__main__": + main() From a44b1f3e670f8a75cd34eda587aa5d6c8486c859 Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Tue, 23 Dec 2025 14:11:06 +0800 Subject: [PATCH 03/29] refactor: removed repeated calculations and remove hardcoded params --- .../models/evaluator/kg/accuracy_evaluator.py | 170 +++++------------ .../evaluator/kg/consistency_evaluator.py | 2 +- .../evaluator/kg/structure_evaluator.py | 27 +-- .../models/evaluator/kg_quality_evaluator.py | 3 +- graphgen/operators/evaluate_kg/evaluate_kg.py | 171 +++++++++++------- 5 files changed, 167 insertions(+), 206 deletions(-) diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py index 4066c92f..357305a9 100644 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple from tqdm.asyncio import tqdm as tqdm_async @@ -9,7 +9,12 @@ class AccuracyEvaluator: - """Evaluates accuracy of entity recognition, relation extraction, and triple validation.""" + """Evaluates accuracy of entity recognition, relation extraction, and triple validation. + + Note: Recall is approximated as equal to precision since we cannot calculate true recall + (TP / (TP + FN)) without complete ground truth. The F1 score is therefore equal to precision. + Only precision should be considered as the primary metric. + """ def __init__( self, @@ -33,22 +38,17 @@ def evaluate(self) -> Dict[str, Any]: if not all_nodes and not all_edges: return {"error": "Empty graph"} - # Sample entities, relations, and triples + # Sample entities and triples (edges) entity_samples = sample_items(all_nodes, self.sample_size) - relation_samples = sample_items(all_edges, self.sample_size) triple_samples = sample_items(all_edges, self.sample_size) # Evaluate each type (async) loop = create_event_loop() entity_results = loop.run_until_complete(self._evaluate_entities(entity_samples)) - relation_results = loop.run_until_complete( - self._evaluate_relations(relation_samples) - ) triple_results = loop.run_until_complete(self._evaluate_triples(triple_samples)) return { "entity_accuracy": entity_results, - "relation_accuracy": relation_results, "triple_accuracy": triple_results, } @@ -75,67 +75,7 @@ async def verify_with_semaphore(entity_sample): result = await coro results.append(result) - # Calculate metrics - tp = sum(results) - fp = len(results) - tp - precision = tp / len(results) if results else 0.0 - recall = precision # Approximation: assume all sampled are relevant - f1 = ( - 2 * precision * recall / (precision + recall) - if (precision + recall) > 0 - else 0.0 - ) - - return { - "precision": precision, - "recall": recall, - "f1": f1, - "true_positives": tp, - "false_positives": fp, - "sample_size": len(results), - } - - async def _evaluate_relations( - self, relation_samples: List[Tuple[str, str, Dict]] - ) -> Dict[str, float]: - """Evaluate relation extraction accuracy.""" - source_text = get_relevant_text(self.chunk_storage) - - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def verify_with_semaphore(relation_sample): - async with semaphore: - src_id, dst_id, edge_data = relation_sample - return await self._verify_relation_with_llm( - src_id, dst_id, edge_data, source_text - ) - - results = [] - tasks = [verify_with_semaphore(sample) for sample in relation_samples] - for coro in tqdm_async( - asyncio.as_completed(tasks), total=len(tasks), desc="Verifying relations" - ): - result = await coro - results.append(result) - - tp = sum(results) - fp = len(results) - tp - precision = tp / len(results) if results else 0.0 - recall = precision - f1 = ( - 2 * precision * recall / (precision + recall) - if (precision + recall) > 0 - else 0.0 - ) - - return { - "precision": precision, - "recall": recall, - "f1": f1, - "true_positives": tp, - "false_positives": fp, - "sample_size": len(results), - } + return self._calculate_metrics(results) async def _evaluate_triples( self, triple_samples: List[Tuple[str, str, Dict]] @@ -160,24 +100,7 @@ async def verify_with_semaphore(triple_sample): result = await coro results.append(result) - tp = sum(results) - fp = len(results) - tp - precision = tp / len(results) if results else 0.0 - recall = precision - f1 = ( - 2 * precision * recall / (precision + recall) - if (precision + recall) > 0 - else 0.0 - ) - - return { - "precision": precision, - "recall": recall, - "f1": f1, - "true_positives": tp, - "false_positives": fp, - "sample_size": len(results), - } + return self._calculate_metrics(results) async def _verify_entity_with_llm( self, entity_id: str, entity_data: Dict, source_text: str @@ -216,48 +139,41 @@ async def _verify_entity_with_llm( logger.error(f"LLM verification failed for entity {entity_id}: {e}") return False - async def _verify_relation_with_llm( - self, src_id: str, dst_id: str, edge_data: Dict, source_text: str - ) -> bool: - """Verify relation correctness using LLM.""" - src_node = self.graph_storage.get_node(src_id) or {} - dst_node = self.graph_storage.get_node(dst_id) or {} - source_entity = src_node.get("entity_name", src_id) - target_entity = dst_node.get("entity_name", dst_id) - relationship_summary = edge_data.get( - "relationship_summary", edge_data.get("description", "") + def _calculate_metrics(self, results: List[bool]) -> Dict[str, float]: + """ + Calculate precision, recall, and F1 score from boolean verification results. + + Note: Recall is approximated as equal to precision since we cannot calculate + true recall (TP / (TP + FN)) without complete ground truth. The F1 score + is therefore equal to precision. Only precision should be considered as the + primary metric. + + Args: + results: List of boolean values indicating verification results (True = correct) + + Returns: + Dictionary containing precision, recall, f1, true_positives, false_positives, + and sample_size + """ + tp = sum(results) + fp = len(results) - tp + precision = tp / len(results) if results else 0.0 + # Approximation: assume all sampled are relevant (cannot calculate true recall) + recall = precision + f1 = ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0.0 ) - # Try to get relevant text from source_id - source_id = edge_data.get("source_id") - if source_id: - relevant_text = get_relevant_text(self.chunk_storage, source_id) - if relevant_text: - source_text = relevant_text - - prompt = f"""给定以下文本和关系信息,请判断该关系是否在文本中正确抽取。 - -文本:{source_text[:2000]} - -源实体:{source_entity} -目标实体:{target_entity} -关系描述:{relationship_summary} - -请回答:该关系是否在文本中正确抽取?回答"是"或"否",并简要说明理由。""" - - try: - response = await self.llm_client.generate_answer(prompt) - response_lower = response.lower() - return ( - "是" in response_lower - or "yes" in response_lower - or "正确" in response_lower - ) - except Exception as e: - logger.error( - f"LLM verification failed for relation {src_id}->{dst_id}: {e}" - ) - return False + return { + "precision": precision, + "recall": recall, + "f1": f1, + "true_positives": tp, + "false_positives": fp, + "sample_size": len(results), + } async def _verify_triple_with_llm( self, src_id: str, dst_id: str, edge_data: Dict, source_text: str diff --git a/graphgen/models/evaluator/kg/consistency_evaluator.py b/graphgen/models/evaluator/kg/consistency_evaluator.py index 9beff1f4..05d27d96 100644 --- a/graphgen/models/evaluator/kg/consistency_evaluator.py +++ b/graphgen/models/evaluator/kg/consistency_evaluator.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict from graphgen.bases import BaseGraphStorage diff --git a/graphgen/models/evaluator/kg/structure_evaluator.py b/graphgen/models/evaluator/kg/structure_evaluator.py index 44a8cd25..b8c5028d 100644 --- a/graphgen/models/evaluator/kg/structure_evaluator.py +++ b/graphgen/models/evaluator/kg/structure_evaluator.py @@ -20,6 +20,13 @@ class StructureEvaluator: """Evaluates structural robustness of the graph.""" + # Thresholds for structural metrics + NOISE_RATIO_THRESHOLD = 0.15 + LARGEST_CC_RATIO_THRESHOLD = 0.90 + AVG_DEGREE_MIN = 2 + AVG_DEGREE_MAX = 5 + POWERLAW_R2_THRESHOLD = 0.75 + def __init__(self, graph_storage: BaseGraphStorage): self.graph_storage = graph_storage @@ -70,23 +77,23 @@ def evaluate(self) -> Dict[str, Any]: thresholds = { "noise_ratio": { "value": noise_ratio, - "threshold": 0.15, - "pass": noise_ratio < 0.15, + "threshold": self.NOISE_RATIO_THRESHOLD, + "pass": noise_ratio < self.NOISE_RATIO_THRESHOLD, }, "largest_cc_ratio": { "value": largest_cc_ratio, - "threshold": 0.90, - "pass": largest_cc_ratio > 0.90, + "threshold": self.LARGEST_CC_RATIO_THRESHOLD, + "pass": largest_cc_ratio > self.LARGEST_CC_RATIO_THRESHOLD, }, "avg_degree": { "value": avg_degree, - "threshold": (2, 5), - "pass": 2 <= avg_degree <= 5, + "threshold": (self.AVG_DEGREE_MIN, self.AVG_DEGREE_MAX), + "pass": self.AVG_DEGREE_MIN <= avg_degree <= self.AVG_DEGREE_MAX, }, "powerlaw_r2": { "value": powerlaw_r2, - "threshold": 0.75, - "pass": powerlaw_r2 > 0.75 if powerlaw_r2 is not None else False, + "threshold": self.POWERLAW_R2_THRESHOLD, + "pass": powerlaw_r2 > self.POWERLAW_R2_THRESHOLD if powerlaw_r2 is not None else False, }, } @@ -130,9 +137,7 @@ def _calculate_powerlaw_r2(self, G: "nx.Graph") -> Optional[float]: log_x = np.log(x) # Linear regression on log-log scale - slope, intercept, r_value, p_value, std_err = stats.linregress( - log_x, sorted_log_degrees - ) + r_value, *_ = stats.linregress(log_x, sorted_log_degrees) r2 = r_value ** 2 return float(r2) diff --git a/graphgen/models/evaluator/kg_quality_evaluator.py b/graphgen/models/evaluator/kg_quality_evaluator.py index 0bb945be..23ae6187 100644 --- a/graphgen/models/evaluator/kg_quality_evaluator.py +++ b/graphgen/models/evaluator/kg_quality_evaluator.py @@ -7,7 +7,6 @@ 3. robustness assessment (noise ratio, connectivity, degree distribution). """ -import os from dataclasses import dataclass from typing import Any, Dict, Optional @@ -17,7 +16,7 @@ ConsistencyEvaluator, StructureEvaluator, ) -from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger +from graphgen.utils import CURRENT_LOGGER_VAR, logger @dataclass diff --git a/graphgen/operators/evaluate_kg/evaluate_kg.py b/graphgen/operators/evaluate_kg/evaluate_kg.py index ab8a8031..2b7e74fd 100644 --- a/graphgen/operators/evaluate_kg/evaluate_kg.py +++ b/graphgen/operators/evaluate_kg/evaluate_kg.py @@ -1,6 +1,5 @@ import argparse import json -import os from pathlib import Path from dotenv import load_dotenv @@ -12,6 +11,110 @@ load_dotenv() +def _run_evaluation(evaluator, args): + """Run the evaluation based on arguments.""" + if args.accuracy_only: + logger.info("Running accuracy evaluation only...") + return {"accuracy": evaluator.evaluate_accuracy()} + if args.consistency_only: + logger.info("Running consistency evaluation only...") + return {"consistency": evaluator.evaluate_consistency()} + if args.structure_only: + logger.info("Running structural robustness evaluation only...") + return {"structure": evaluator.evaluate_structure()} + logger.info("Running all evaluations...") + return evaluator.evaluate_all() + + +def _print_accuracy_summary(acc): + """Print accuracy evaluation summary.""" + if "error" not in acc: + print("\n[Accuracy]") + if "entity_accuracy" in acc: + e = acc["entity_accuracy"] + print(f" Entity - Precision: {e.get('precision', 0):.3f}, " + f"Recall: {e.get('recall', 0):.3f}, F1: {e.get('f1', 0):.3f}") + if "triple_accuracy" in acc: + t = acc["triple_accuracy"] + print(f" Triple (RLC) - Precision: {t.get('precision', 0):.3f}, " + f"Recall: {t.get('recall', 0):.3f}, F1: {t.get('f1', 0):.3f}") + else: + print(f"\n[Accuracy] Error: {acc['error']}") + + +def _print_consistency_summary(cons): + """Print consistency evaluation summary.""" + if "error" not in cons: + print("\n[Consistency]") + print(f" Conflict Rate: {cons.get('conflict_rate', 0):.3f}") + print(f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / " + f"{cons.get('total_entities', 0)}") + else: + print(f"\n[Consistency] Error: {cons['error']}") + + +def _print_structure_summary(struct): + """Print structural robustness evaluation summary.""" + if "error" not in struct: + print("\n[Structural Robustness]") + print(f" Total Nodes: {struct.get('total_nodes', 0)}") + print(f" Total Edges: {struct.get('total_edges', 0)}") + + thresholds = struct.get("thresholds", {}) + + # Noise Ratio + noise_check = thresholds.get("noise_ratio", {}) + noise_threshold = noise_check.get("threshold", "N/A") + noise_pass = noise_check.get("pass", False) + print(f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} " + f"({'✓' if noise_pass else '✗'} < {noise_threshold})") + + # Largest CC Ratio + lcc_check = thresholds.get("largest_cc_ratio", {}) + lcc_threshold = lcc_check.get("threshold", "N/A") + lcc_pass = lcc_check.get("pass", False) + print(f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} " + f"({'✓' if lcc_pass else '✗'} > {lcc_threshold})") + + # Avg Degree + avg_degree_check = thresholds.get("avg_degree", {}) + avg_degree_threshold = avg_degree_check.get("threshold", "N/A") + avg_degree_pass = avg_degree_check.get("pass", False) + # Format threshold for display (handle tuple case) + if isinstance(avg_degree_threshold, tuple): + threshold_str = f"{avg_degree_threshold[0]}-{avg_degree_threshold[1]}" + else: + threshold_str = str(avg_degree_threshold) + print(f" Avg Degree: {struct.get('avg_degree', 0):.2f} " + f"({'✓' if avg_degree_pass else '✗'} {threshold_str})") + + # Power Law R² + if struct.get('powerlaw_r2') is not None: + powerlaw_check = thresholds.get("powerlaw_r2", {}) + powerlaw_threshold = powerlaw_check.get("threshold", "N/A") + powerlaw_pass = powerlaw_check.get("pass", False) + print(f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} " + f"({'✓' if powerlaw_pass else '✗'} > {powerlaw_threshold})") + else: + print(f"\n[Structural Robustness] Error: {struct['error']}") + + +def _print_summary(results): + """Print evaluation summary.""" + print("\n" + "=" * 60) + print("KG Quality Evaluation Summary") + print("=" * 60) + + if "accuracy" in results: + _print_accuracy_summary(results["accuracy"]) + if "consistency" in results: + _print_consistency_summary(results["consistency"]) + if "structure" in results: + _print_structure_summary(results["structure"]) + + print("\n" + "=" * 60) + + def main(): """Main function to run KG quality evaluation.""" parser = argparse.ArgumentParser( @@ -125,18 +228,7 @@ def main(): # Run evaluation try: - if args.accuracy_only: - logger.info("Running accuracy evaluation only...") - results = {"accuracy": evaluator.evaluate_accuracy()} - elif args.consistency_only: - logger.info("Running consistency evaluation only...") - results = {"consistency": evaluator.evaluate_consistency()} - elif args.structure_only: - logger.info("Running structural robustness evaluation only...") - results = {"structure": evaluator.evaluate_structure()} - else: - logger.info("Running all evaluations...") - results = evaluator.evaluate_all() + results = _run_evaluation(evaluator, args) # Save results output_path.parent.mkdir(parents=True, exist_ok=True) @@ -146,58 +238,7 @@ def main(): logger.info(f"Evaluation completed. Results saved to: {output_path}") # Print summary - print("\n" + "=" * 60) - print("KG Quality Evaluation Summary") - print("=" * 60) - - if "accuracy" in results: - acc = results["accuracy"] - if "error" not in acc: - print("\n[Accuracy]") - if "entity_accuracy" in acc: - e = acc["entity_accuracy"] - print(f" Entity - Precision: {e.get('precision', 0):.3f}, " - f"Recall: {e.get('recall', 0):.3f}, F1: {e.get('f1', 0):.3f}") - if "relation_accuracy" in acc: - r = acc["relation_accuracy"] - print(f" Relation - Precision: {r.get('precision', 0):.3f}, " - f"Recall: {r.get('recall', 0):.3f}, F1: {r.get('f1', 0):.3f}") - if "triple_accuracy" in acc: - t = acc["triple_accuracy"] - print(f" Triple (RLC) - Precision: {t.get('precision', 0):.3f}, " - f"Recall: {t.get('recall', 0):.3f}, F1: {t.get('f1', 0):.3f}") - else: - print(f"\n[Accuracy] Error: {acc['error']}") - - if "consistency" in results: - cons = results["consistency"] - if "error" not in cons: - print("\n[Consistency]") - print(f" Conflict Rate: {cons.get('conflict_rate', 0):.3f}") - print(f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / " - f"{cons.get('total_entities', 0)}") - else: - print(f"\n[Consistency] Error: {cons['error']}") - - if "structure" in results: - struct = results["structure"] - if "error" not in struct: - print("\n[Structural Robustness]") - print(f" Total Nodes: {struct.get('total_nodes', 0)}") - print(f" Total Edges: {struct.get('total_edges', 0)}") - print(f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} " - f"({'✓' if struct.get('noise_ratio', 1) < 0.15 else '✗'} < 15%)") - print(f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} " - f"({'✓' if struct.get('largest_cc_ratio', 0) > 0.90 else '✗'} > 90%)") - print(f" Avg Degree: {struct.get('avg_degree', 0):.2f} " - f"({'✓' if 2 <= struct.get('avg_degree', 0) <= 5 else '✗'} 2-5)") - if struct.get('powerlaw_r2') is not None: - print(f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} " - f"({'✓' if struct.get('powerlaw_r2', 0) > 0.75 else '✗'} > 0.75)") - else: - print(f"\n[Structural Robustness] Error: {struct['error']}") - - print("\n" + "=" * 60) + _print_summary(results) except Exception as e: logger.error(f"Evaluation failed: {e}", exc_info=True) From 6c777340b9f453248401da21cca7a95fae7265b5 Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Tue, 23 Dec 2025 15:29:42 +0800 Subject: [PATCH 04/29] add: add kg_evaluate config file for params --- examples/evaluate_kg/evaluate_kg_config.yaml | 12 ++++++ .../evaluator/kg/structure_evaluator.py | 39 +++++++++++-------- 2 files changed, 34 insertions(+), 17 deletions(-) create mode 100644 examples/evaluate_kg/evaluate_kg_config.yaml diff --git a/examples/evaluate_kg/evaluate_kg_config.yaml b/examples/evaluate_kg/evaluate_kg_config.yaml new file mode 100644 index 00000000..d9a54f94 --- /dev/null +++ b/examples/evaluate_kg/evaluate_kg_config.yaml @@ -0,0 +1,12 @@ +source_text_paths: + - data/protein_source.txt + - data/dna_source.txt + - data/rna_source.txt + +structure_thresholds: + noise_ratio: 0.15 + largest_cc_ratio: 0.90 + avg_degree_min: 2.0 + avg_degree_max: 5.0 + powerlaw_r2: 0.75 + diff --git a/graphgen/models/evaluator/kg/structure_evaluator.py b/graphgen/models/evaluator/kg/structure_evaluator.py index b8c5028d..3aee3475 100644 --- a/graphgen/models/evaluator/kg/structure_evaluator.py +++ b/graphgen/models/evaluator/kg/structure_evaluator.py @@ -20,15 +20,21 @@ class StructureEvaluator: """Evaluates structural robustness of the graph.""" - # Thresholds for structural metrics - NOISE_RATIO_THRESHOLD = 0.15 - LARGEST_CC_RATIO_THRESHOLD = 0.90 - AVG_DEGREE_MIN = 2 - AVG_DEGREE_MAX = 5 - POWERLAW_R2_THRESHOLD = 0.75 - - def __init__(self, graph_storage: BaseGraphStorage): + def __init__( + self, + graph_storage: BaseGraphStorage, + noise_ratio_threshold: float = 0.15, + largest_cc_ratio_threshold: float = 0.90, + avg_degree_min: float = 2.0, + avg_degree_max: float = 5.0, + powerlaw_r2_threshold: float = 0.75, + ): self.graph_storage = graph_storage + self.noise_ratio_threshold = noise_ratio_threshold + self.largest_cc_ratio_threshold = largest_cc_ratio_threshold + self.avg_degree_min = avg_degree_min + self.avg_degree_max = avg_degree_max + self.powerlaw_r2_threshold = powerlaw_r2_threshold def evaluate(self) -> Dict[str, Any]: if nx is None: @@ -73,27 +79,26 @@ def evaluate(self) -> Dict[str, Any]: # Power law distribution R² powerlaw_r2 = self._calculate_powerlaw_r2(G) - # Check thresholds thresholds = { "noise_ratio": { "value": noise_ratio, - "threshold": self.NOISE_RATIO_THRESHOLD, - "pass": noise_ratio < self.NOISE_RATIO_THRESHOLD, + "threshold": self.noise_ratio_threshold, + "pass": noise_ratio < self.noise_ratio_threshold, }, "largest_cc_ratio": { "value": largest_cc_ratio, - "threshold": self.LARGEST_CC_RATIO_THRESHOLD, - "pass": largest_cc_ratio > self.LARGEST_CC_RATIO_THRESHOLD, + "threshold": self.largest_cc_ratio_threshold, + "pass": largest_cc_ratio > self.largest_cc_ratio_threshold, }, "avg_degree": { "value": avg_degree, - "threshold": (self.AVG_DEGREE_MIN, self.AVG_DEGREE_MAX), - "pass": self.AVG_DEGREE_MIN <= avg_degree <= self.AVG_DEGREE_MAX, + "threshold": (self.avg_degree_min, self.avg_degree_max), + "pass": self.avg_degree_min <= avg_degree <= self.avg_degree_max, }, "powerlaw_r2": { "value": powerlaw_r2, - "threshold": self.POWERLAW_R2_THRESHOLD, - "pass": powerlaw_r2 > self.POWERLAW_R2_THRESHOLD if powerlaw_r2 is not None else False, + "threshold": self.powerlaw_r2_threshold, + "pass": powerlaw_r2 > self.powerlaw_r2_threshold if powerlaw_r2 is not None else False, }, } From 93abd005cbbad45dd392a3568d4cb38ee2538d78 Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Tue, 23 Dec 2025 16:28:26 +0800 Subject: [PATCH 05/29] fix: correct relation acc evaluation logic --- .../models/evaluator/kg/accuracy_evaluator.py | 369 ++++++++++-------- 1 file changed, 216 insertions(+), 153 deletions(-) diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py index 357305a9..517d5a35 100644 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -1,170 +1,202 @@ import asyncio -from typing import Any, Dict, List, Tuple - -from tqdm.asyncio import tqdm as tqdm_async +from pathlib import Path +from typing import Any, Dict, List, Set, Tuple from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper -from graphgen.models.evaluator.kg.utils import get_relevant_text, sample_items +from graphgen.models.evaluator.kg.utils import load_text_content from graphgen.utils import create_event_loop, logger class AccuracyEvaluator: """Evaluates accuracy of entity recognition, relation extraction, and triple validation. - Note: Recall is approximated as equal to precision since we cannot calculate true recall - (TP / (TP + FN)) without complete ground truth. The F1 score is therefore equal to precision. - Only precision should be considered as the primary metric. + Uses LLM to extract ground truth (entities, relations, and triples) from source texts, + then compares with entities, relations, and triples in the knowledge graph to calculate + true precision, recall, and F1 scores. + + Three evaluation dimensions: + 1. Entity recognition accuracy: compares extracted entities with KG entities + 2. Relation extraction accuracy: compares extracted relation descriptions with KG relation descriptions + 3. Triple validation (RLC): validates complete triples (head entity + relation + tail entity) """ - + def __init__( self, graph_storage: BaseGraphStorage, chunk_storage: BaseKVStorage, llm_client: BaseLLMWrapper, - sample_size: int = 100, + source_text_paths: List[str], max_concurrent: int = 10, ): + if not source_text_paths: + raise ValueError("source_text_paths is required and cannot be empty") self.graph_storage = graph_storage self.chunk_storage = chunk_storage self.llm_client = llm_client - self.sample_size = sample_size self.max_concurrent = max_concurrent + self.source_text_paths = source_text_paths def evaluate(self) -> Dict[str, Any]: - # Get all nodes and edges + """Evaluate entity, relation, and triple accuracy. + + Returns: + Dictionary containing entity_accuracy, relation_accuracy, and triple_accuracy metrics. + """ all_nodes = self.graph_storage.get_all_nodes() or [] all_edges = self.graph_storage.get_all_edges() or [] if not all_nodes and not all_edges: return {"error": "Empty graph"} - # Sample entities and triples (edges) - entity_samples = sample_items(all_nodes, self.sample_size) - triple_samples = sample_items(all_edges, self.sample_size) - - # Evaluate each type (async) loop = create_event_loop() - entity_results = loop.run_until_complete(self._evaluate_entities(entity_samples)) - triple_results = loop.run_until_complete(self._evaluate_triples(triple_samples)) + source_text = self._load_source_texts() + entity_ground_truth, relation_ground_truth, triple_ground_truth = loop.run_until_complete( + self._extract_ground_truth(source_text) + ) + entity_results = loop.run_until_complete( + self._evaluate_entities_with_ground_truth( + all_nodes, entity_ground_truth + ) + ) + relation_results = loop.run_until_complete( + self._evaluate_relations_with_ground_truth( + all_edges, relation_ground_truth + ) + ) + triple_results = loop.run_until_complete( + self._evaluate_triples_with_ground_truth( + all_edges, triple_ground_truth + ) + ) return { "entity_accuracy": entity_results, + "relation_accuracy": relation_results, "triple_accuracy": triple_results, } - async def _evaluate_entities( - self, entity_samples: List[Tuple[str, Dict]] - ) -> Dict[str, float]: - """Evaluate entity recognition accuracy.""" - source_text = get_relevant_text(self.chunk_storage) - + def _load_source_texts(self) -> str: + """Load and concatenate source text files. + + Supports .txt, .json, and .jsonl formats using the utility function + from graphgen.models.evaluator.kg.utils. + """ + texts = [] + for path in self.source_text_paths: + file_path = Path(path) + if not file_path.exists(): + logger.warning(f"Source text file not found: {path}") + continue + try: + content = load_text_content(file_path) + if content: + texts.append(content) + except Exception as e: + logger.error(f"Failed to read {path}: {e}") + return "\n\n".join(texts) + + async def _extract_ground_truth( + self, source_text: str + ) -> Tuple[Set[str], Set[str], Set[Tuple[str, str, str]]]: + """Extract entities, relations, and triples from source text using LLM as ground truth.""" semaphore = asyncio.Semaphore(self.max_concurrent) + chunk_size = 2000 + chunks = [ + source_text[i : i + chunk_size] + for i in range(0, len(source_text), chunk_size) + ] - async def verify_with_semaphore(entity_sample): + async def extract_from_chunk(chunk): async with semaphore: - entity_id, entity_data = entity_sample - return await self._verify_entity_with_llm( - entity_id, entity_data, source_text - ) - - results = [] - tasks = [verify_with_semaphore(sample) for sample in entity_samples] - for coro in tqdm_async( - asyncio.as_completed(tasks), total=len(tasks), desc="Verifying entities" - ): - result = await coro - results.append(result) - - return self._calculate_metrics(results) - - async def _evaluate_triples( - self, triple_samples: List[Tuple[str, str, Dict]] - ) -> Dict[str, float]: - """Evaluate triple validation accuracy (RLC).""" - source_text = get_relevant_text(self.chunk_storage) + return await self._extract_entities_relations_and_triples(chunk) - semaphore = asyncio.Semaphore(self.max_concurrent) + tasks = [extract_from_chunk(chunk) for chunk in chunks] + results = await asyncio.gather(*tasks) - async def verify_with_semaphore(triple_sample): - async with semaphore: - src_id, dst_id, edge_data = triple_sample - return await self._verify_triple_with_llm( - src_id, dst_id, edge_data, source_text - ) - - results = [] - tasks = [verify_with_semaphore(sample) for sample in triple_samples] - for coro in tqdm_async( - asyncio.as_completed(tasks), total=len(tasks), desc="Verifying triples" - ): - result = await coro - results.append(result) - - return self._calculate_metrics(results) - - async def _verify_entity_with_llm( - self, entity_id: str, entity_data: Dict, source_text: str - ) -> bool: - """Verify entity correctness using LLM.""" - entity_name = entity_data.get("entity_name", entity_id) - entity_type = entity_data.get("entity_type", "unknown") - entity_summary = entity_data.get("entity_summary", entity_data.get("description", "")) - - # Try to get relevant text from source_id - source_id = entity_data.get("source_id") - if source_id: - relevant_text = get_relevant_text(self.chunk_storage, source_id) - if relevant_text: - source_text = relevant_text - - prompt = f"""给定以下文本和实体信息,请判断该实体是否在文本中正确识别。 - -文本:{source_text[:2000]} - -实体名称:{entity_name} -实体类型:{entity_type} -实体描述:{entity_summary} - -请回答:该实体是否在文本中正确识别?回答"是"或"否",并简要说明理由。""" + all_entities = set() + all_relations = set() + all_triples = set() + for entities, relations, triples in results: + all_entities.update(entities) + all_relations.update(relations) + all_triples.update(triples) + + return all_entities, all_relations, all_triples + + async def _extract_entities_relations_and_triples( + self, text: str + ) -> Tuple[Set[str], Set[str], Set[Tuple[str, str, str]]]: + """Extract entities, relations, and triples from a text chunk using LLM.""" + entity_prompt = f"""从以下文本中提取所有实体名称,每行一个实体名称。 + +文本: +{text[:2000]} + +请只返回实体名称列表,每行一个,不要其他内容:""" + + relation_prompt = f"""从以下文本中提取所有关系描述,每行一个关系描述。 +关系描述是指描述两个实体之间关系的词语或短语,例如"设计"、"位于"、"属于"等。 + +文本: +{text[:2000]} + +请只返回关系描述列表,每行一个,不要其他内容:""" + + triple_prompt = f"""从以下文本中提取所有三元组,格式为:头实体|关系|尾实体,每行一个三元组。 + +文本: +{text[:2000]} + +请只返回三元组列表,每行一个,格式为:头实体|关系|尾实体,不要其他内容:""" try: - response = await self.llm_client.generate_answer(prompt) - response_lower = response.lower() - return ( - "是" in response_lower - or "yes" in response_lower - or "正确" in response_lower - ) + entity_response = await self.llm_client.generate_answer(entity_prompt) + relation_response = await self.llm_client.generate_answer(relation_prompt) + triple_response = await self.llm_client.generate_answer(triple_prompt) + + entities = { + line.strip() + for line in entity_response.split("\n") + if line.strip() and not line.strip().startswith("#") + } + + relations = { + line.strip() + for line in relation_response.split("\n") + if line.strip() and not line.strip().startswith("#") + } + + triples = set() + for line in triple_response.split("\n"): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split("|") + if len(parts) == 3: + triples.add((parts[0].strip(), parts[1].strip(), parts[2].strip())) + + return entities, relations, triples except Exception as e: - logger.error(f"LLM verification failed for entity {entity_id}: {e}") - return False + logger.error(f"Failed to extract ground truth: {e}") + return set(), set(), set() - def _calculate_metrics(self, results: List[bool]) -> Dict[str, float]: - """ - Calculate precision, recall, and F1 score from boolean verification results. - - Note: Recall is approximated as equal to precision since we cannot calculate - true recall (TP / (TP + FN)) without complete ground truth. The F1 score - is therefore equal to precision. Only precision should be considered as the - primary metric. - - Args: - results: List of boolean values indicating verification results (True = correct) - - Returns: - Dictionary containing precision, recall, f1, true_positives, false_positives, - and sample_size - """ - tp = sum(results) - fp = len(results) - tp - precision = tp / len(results) if results else 0.0 - # Approximation: assume all sampled are relevant (cannot calculate true recall) - recall = precision - f1 = ( - 2 * precision * recall / (precision + recall) - if (precision + recall) > 0 - else 0.0 - ) + async def _evaluate_entities_with_ground_truth( + self, all_nodes: List[Tuple[str, Dict]], ground_truth: Set[str] + ) -> Dict[str, float]: + """Evaluate entity accuracy by comparing KG entities with ground truth.""" + kg_entities = { + node_data.get("entity_name", node_id).lower() + for node_id, node_data in all_nodes + if isinstance(node_data, dict) + } + + tp = len(kg_entities & {e.lower() for e in ground_truth}) + fp = len(kg_entities - {e.lower() for e in ground_truth}) + fn = len({e.lower() for e in ground_truth} - kg_entities) + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 return { "precision": precision, @@ -172,42 +204,73 @@ def _calculate_metrics(self, results: List[bool]) -> Dict[str, float]: "f1": f1, "true_positives": tp, "false_positives": fp, - "sample_size": len(results), + "false_negatives": fn, + "total_ground_truth": len(ground_truth), + "total_kg_entities": len(kg_entities), } - async def _verify_triple_with_llm( - self, src_id: str, dst_id: str, edge_data: Dict, source_text: str - ) -> bool: - """Verify triple correctness using LLM.""" - src_node = self.graph_storage.get_node(src_id) or {} - dst_node = self.graph_storage.get_node(dst_id) or {} - head = src_node.get("entity_name", src_id) - tail = dst_node.get("entity_name", dst_id) - relation = edge_data.get("relationship_summary", edge_data.get("description", "")) + async def _evaluate_relations_with_ground_truth( + self, all_edges: List[Tuple[str, str, Dict]], ground_truth: Set[str] + ) -> Dict[str, float]: + """Evaluate relation extraction accuracy by comparing KG relation descriptions with ground truth.""" + kg_relations = set() + for src_id, dst_id, edge_data in all_edges: + relation = edge_data.get("relationship_summary", edge_data.get("description", "")) + if relation: + kg_relations.add(relation.lower().strip()) - # Try to get relevant text from source_id - source_id = edge_data.get("source_id") - if source_id: - relevant_text = get_relevant_text(self.chunk_storage, source_id) - if relevant_text: - source_text = relevant_text + gt_normalized = {r.lower().strip() for r in ground_truth if r.strip()} - prompt = f"""给定以下文本和三元组,请判断该三元组是否正确。 + tp = len(kg_relations & gt_normalized) + fp = len(kg_relations - gt_normalized) + fn = len(gt_normalized - kg_relations) -文本:{source_text[:2000]} + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 -三元组:(头实体: {head}, 关系: {relation}, 尾实体: {tail}) + return { + "precision": precision, + "recall": recall, + "f1": f1, + "true_positives": tp, + "false_positives": fp, + "false_negatives": fn, + "total_ground_truth": len(ground_truth), + "total_kg_relations": len(kg_relations), + } -请回答:该三元组是否正确?回答"是"或"否",并简要说明理由。""" + async def _evaluate_triples_with_ground_truth( + self, all_edges: List[Tuple[str, str, Dict]], ground_truth: Set[Tuple[str, str, str]] + ) -> Dict[str, float]: + """Evaluate triple accuracy by comparing KG triples with ground truth.""" + kg_triples = set() + for src_id, dst_id, edge_data in all_edges: + src_node = self.graph_storage.get_node(src_id) or {} + dst_node = self.graph_storage.get_node(dst_id) or {} + head = src_node.get("entity_name", src_id).lower() + tail = dst_node.get("entity_name", dst_id).lower() + relation = edge_data.get("relationship_summary", edge_data.get("description", "")).lower() + kg_triples.add((head, relation, tail)) + + gt_normalized = {(h.lower(), r.lower(), t.lower()) for h, r, t in ground_truth} + + tp = len(kg_triples & gt_normalized) + fp = len(kg_triples - gt_normalized) + fn = len(gt_normalized - kg_triples) + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 + + return { + "precision": precision, + "recall": recall, + "f1": f1, + "true_positives": tp, + "false_positives": fp, + "false_negatives": fn, + "total_ground_truth": len(ground_truth), + "total_kg_triples": len(kg_triples), + } - try: - response = await self.llm_client.generate_answer(prompt) - response_lower = response.lower() - return ( - "是" in response_lower - or "yes" in response_lower - or "正确" in response_lower - ) - except Exception as e: - logger.error(f"LLM verification failed for triple {src_id}->{dst_id}: {e}") - return False From 777cb251e5195f9dc1527bd26d420fcfa51f188d Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Wed, 24 Dec 2025 02:54:46 +0800 Subject: [PATCH 06/29] refactor: enhance KG evaluator to use llm-as judge; remove evaluate_kg_config --- examples/evaluate_kg/evaluate_kg.sh | 1 - examples/evaluate_kg/evaluate_kg_config.yaml | 12 - graphgen/models/evaluator/kg/README.md | 161 ++++- graphgen/models/evaluator/kg/__init__.py | 5 - .../models/evaluator/kg/accuracy_evaluator.py | 586 +++++++++++------- .../evaluator/kg/consistency_evaluator.py | 438 ++++++++++++- .../evaluator/kg/structure_evaluator.py | 37 +- graphgen/models/evaluator/kg/utils.py | 96 --- .../models/evaluator/kg_quality_evaluator.py | 27 +- graphgen/operators/evaluate_kg/evaluate_kg.py | 63 +- 10 files changed, 998 insertions(+), 428 deletions(-) delete mode 100644 examples/evaluate_kg/evaluate_kg_config.yaml delete mode 100644 graphgen/models/evaluator/kg/utils.py diff --git a/examples/evaluate_kg/evaluate_kg.sh b/examples/evaluate_kg/evaluate_kg.sh index a846ee65..cda034bc 100644 --- a/examples/evaluate_kg/evaluate_kg.sh +++ b/examples/evaluate_kg/evaluate_kg.sh @@ -2,5 +2,4 @@ python3 -m graphgen.operators.evaluate_kg.evaluate_kg \ --working_dir cache \ --graph_backend kuzu \ --kv_backend rocksdb \ - --sample_size 100 \ --max_concurrent 10 diff --git a/examples/evaluate_kg/evaluate_kg_config.yaml b/examples/evaluate_kg/evaluate_kg_config.yaml deleted file mode 100644 index d9a54f94..00000000 --- a/examples/evaluate_kg/evaluate_kg_config.yaml +++ /dev/null @@ -1,12 +0,0 @@ -source_text_paths: - - data/protein_source.txt - - data/dna_source.txt - - data/rna_source.txt - -structure_thresholds: - noise_ratio: 0.15 - largest_cc_ratio: 0.90 - avg_degree_min: 2.0 - avg_degree_max: 5.0 - powerlaw_r2: 0.75 - diff --git a/graphgen/models/evaluator/kg/README.md b/graphgen/models/evaluator/kg/README.md index 71554c40..833e9ad6 100644 --- a/graphgen/models/evaluator/kg/README.md +++ b/graphgen/models/evaluator/kg/README.md @@ -6,24 +6,33 @@ This module provides comprehensive quality evaluation for knowledge graphs built The evaluation functionality has been split into modular components: -- **`accuracy_evaluator.py`**: Entity/relation/triple accuracy evaluation using LLM-as-judge +- **`accuracy_evaluator.py`**: Entity/relation extraction quality evaluation using LLM-as-a-Judge - **`consistency_evaluator.py`**: Attribute value conflict detection - **`structure_evaluator.py`**: Graph structural robustness metrics -- **`utils.py`**: Utility functions (NetworkX conversion, text retrieval, sampling) - **`kg_quality_evaluator.py`**: Main evaluator class that integrates all modules ## Features ### 1. Accuracy Assessment -- **Entity Recognition Accuracy**: Samples entities and validates them using LLM -- **Relation Extraction Accuracy**: Samples relations and validates them using LLM -- **Triple Validation (RLC)**: Samples triples and validates them using LLM -- Calculates Precision, Recall, and F1 scores for each metric +- **Entity Extraction Quality**: Uses LLM-as-a-Judge to evaluate the quality of entity extraction from chunks + - Evaluates accuracy (correctness of extracted entities) + - Evaluates completeness (whether important entities are missed) + - Evaluates precision (naming accuracy and specificity) +- **Relation Extraction Quality**: Uses LLM-as-a-Judge to evaluate the quality of relation extraction from chunks + - Evaluates accuracy (correctness of extracted relations) + - Evaluates completeness (whether important relations are missed) + - Evaluates precision (relation description accuracy) +- Provides multi-dimensional quality scores (0-1 scale) with detailed reasoning for each chunk ### 2. Consistency Assessment -- Detects attribute value conflicts (same entity, same attribute, different values) +- **Semantic Conflict Detection**: Uses LLM-as-a-Judge to detect semantic conflicts in entity attributes + - **Entity Type Conflicts**: Detects when the same entity is extracted with different types across chunks + - **Entity Description Conflicts**: Detects when entity descriptions from different chunks are semantically inconsistent + - **Relation Conflicts**: Detects when the same entity pair has conflicting relation descriptions +- Only evaluates entities with multiple source chunks (entities appearing in multiple chunks) +- Uses LLM to extract entity attributes from each chunk and compare them semantically - Calculates conflict rate: `conflict_entities_count / total_entities` -- Returns detailed conflict information +- Returns detailed conflict information including conflict severity and reasoning ### 3. Structural Robustness Assessment - **Noise Ratio**: Isolated nodes / total nodes (threshold: < 15%) @@ -42,10 +51,9 @@ python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache # Run specific evaluation python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache --accuracy_only -# Custom configuration +# Specify backends python -m graphgen.operators.evaluate_kg.evaluate_kg \ --working_dir cache \ - --sample_size 200 \ --graph_backend networkx \ --kv_backend json_kv ``` @@ -59,10 +67,22 @@ bash examples/evaluate_kg/evaluate_kg.sh # With custom options bash examples/evaluate_kg/evaluate_kg.sh \ --working_dir cache \ - --sample_size 200 \ --accuracy_only ``` +## Configuration + +All evaluation thresholds use default values defined in the evaluator classes: + +- **Structure thresholds**: Defined in `StructureEvaluator` with defaults: + - `noise_ratio_threshold`: 0.15 + - `largest_cc_ratio_threshold`: 0.90 + - `avg_degree_min`: 2.0 + - `avg_degree_max`: 5.0 + - `powerlaw_r2_threshold`: 0.75 + +**Note**: Accuracy evaluation automatically loads chunks from the chunk storage and evaluates the quality of entity/relation extraction using LLM-as-a-Judge. No configuration file is needed. + ## Requirements - **NetworkX**: Required for structural evaluation @@ -78,21 +98,117 @@ The evaluation returns a dictionary with the following structure: { "accuracy": { "entity_accuracy": { - "precision": float, - "recall": float, - "f1": float, - "true_positives": int, - "false_positives": int, - "sample_size": int + "overall_score": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "accuracy": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "completeness": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "precision": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "total_chunks": int, + "detailed_results": [ + { + "chunk_id": str, + "chunk_content": str, + "extracted_entities_count": int, + "accuracy": float, + "completeness": float, + "precision": float, + "overall_score": float, + "accuracy_reasoning": str, + "completeness_reasoning": str, + "precision_reasoning": str, + "issues": [str] + }, + ... + ] }, - "relation_accuracy": { ... }, - "triple_accuracy": { ... } + "relation_accuracy": { + "overall_score": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "accuracy": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "completeness": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "precision": { + "mean": float, + "median": float, + "min": float, + "max": float, + "std": float + }, + "total_chunks": int, + "detailed_results": [ + { + "chunk_id": str, + "chunk_content": str, + "extracted_relations_count": int, + "accuracy": float, + "completeness": float, + "precision": float, + "overall_score": float, + "accuracy_reasoning": str, + "completeness_reasoning": str, + "precision_reasoning": str, + "issues": [str] + }, + ... + ] + } }, "consistency": { "conflict_rate": float, "conflict_entities_count": int, "total_entities": int, - "conflicts": [ ... ] + "entities_checked": int, + "conflicts": [ + { + "entity_id": str, + "conflict_type": str, # "entity_type" or "description" + "conflict_severity": float, # 0-1, severity of the conflict + "conflict_reasoning": str, + "conflicting_values": [str], + "recommended_value": str, # for entity_type conflicts + "conflict_details": str # for description conflicts + }, + ... + ] }, "structure": { "total_nodes": int, @@ -111,7 +227,10 @@ The evaluation returns a dictionary with the following structure: ## Notes -- Accuracy evaluation requires LLM API access and may be slow for large sample sizes +- Accuracy evaluation uses LLM-as-a-Judge to evaluate extraction quality from chunks +- Accuracy evaluation automatically loads chunks from chunk storage (no need for source_text_paths) +- The evaluator associates extracted entities/relations with their source chunks using the `source_id` field - Structural evaluation automatically converts Kuzu storage to NetworkX for analysis - All evaluations include error handling and will return error messages if something fails - The evaluator automatically loads graph and chunk storage from the working directory +- LLM evaluation may take time for large numbers of chunks (controlled by `max_concurrent` parameter) diff --git a/graphgen/models/evaluator/kg/__init__.py b/graphgen/models/evaluator/kg/__init__.py index 007f0c9d..4a7f794b 100644 --- a/graphgen/models/evaluator/kg/__init__.py +++ b/graphgen/models/evaluator/kg/__init__.py @@ -1,14 +1,9 @@ from .accuracy_evaluator import AccuracyEvaluator from .consistency_evaluator import ConsistencyEvaluator from .structure_evaluator import StructureEvaluator -from .utils import convert_to_networkx, get_relevant_text, get_source_text, sample_items __all__ = [ "AccuracyEvaluator", "ConsistencyEvaluator", "StructureEvaluator", - "convert_to_networkx", - "get_relevant_text", - "get_source_text", - "sample_items", ] diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py index 517d5a35..ea40788a 100644 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -1,23 +1,93 @@ import asyncio -from pathlib import Path -from typing import Any, Dict, List, Set, Tuple +import json +import re +from typing import Any, Dict, List from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper -from graphgen.models.evaluator.kg.utils import load_text_content +from graphgen.bases.datatypes import Chunk from graphgen.utils import create_event_loop, logger +# LLM-as-a-Judge evaluation prompts +ENTITY_EVALUATION_PROMPT = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的实体列表,评估实体提取的质量。 + +评估维度: +1. ACCURACY (准确性, 权重: 40%): 提取的实体是否正确,是否有误提取或错误识别 +2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要实体 +3. PRECISION (精确性, 权重: 20%): 提取的实体是否精确,命名是否准确 + +评分标准(每个维度 0-1 分): +- EXCELLENT (0.8-1.0): 高质量提取 +- GOOD (0.6-0.79): 良好质量,有少量问题 +- ACCEPTABLE (0.4-0.59): 可接受,有明显问题 +- POOR (0.0-0.39): 质量差,需要改进 + +综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +请评估以下内容: + +原始文本块: +{chunk_content} + +提取的实体列表: +{extracted_entities} + +请以 JSON 格式返回评估结果: +{{ + "accuracy": <0-1之间的浮点数>, + "completeness": <0-1之间的浮点数>, + "precision": <0-1之间的浮点数>, + "overall_score": <综合评分>, + "accuracy_reasoning": "<准确性评估理由>", + "completeness_reasoning": "<完整性评估理由,包括遗漏的重要实体>", + "precision_reasoning": "<精确性评估理由>", + "issues": ["<发现的问题列表>"] +}} +""" + +RELATION_EVALUATION_PROMPT = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的关系列表,评估关系抽取的质量。 + +评估维度: +1. ACCURACY (准确性, 权重: 40%): 提取的关系是否正确,关系描述是否准确 +2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要关系 +3. PRECISION (精确性, 权重: 20%): 关系描述是否精确,是否过于宽泛 + +评分标准(每个维度 0-1 分): +- EXCELLENT (0.8-1.0): 高质量提取 +- GOOD (0.6-0.79): 良好质量,有少量问题 +- ACCEPTABLE (0.4-0.59): 可接受,有明显问题 +- POOR (0.0-0.39): 质量差,需要改进 + +综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +请评估以下内容: + +原始文本块: +{chunk_content} + +提取的关系列表: +{extracted_relations} + +请以 JSON 格式返回评估结果: +{{ + "accuracy": <0-1之间的浮点数>, + "completeness": <0-1之间的浮点数>, + "precision": <0-1之间的浮点数>, + "overall_score": <综合评分>, + "accuracy_reasoning": "<准确性评估理由>", + "completeness_reasoning": "<完整性评估理由,包括遗漏的重要关系>", + "precision_reasoning": "<精确性评估理由>", + "issues": ["<发现的问题列表>"] +}} +""" + + class AccuracyEvaluator: - """Evaluates accuracy of entity recognition, relation extraction, and triple validation. - - Uses LLM to extract ground truth (entities, relations, and triples) from source texts, - then compares with entities, relations, and triples in the knowledge graph to calculate - true precision, recall, and F1 scores. + """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge. - Three evaluation dimensions: - 1. Entity recognition accuracy: compares extracted entities with KG entities - 2. Relation extraction accuracy: compares extracted relation descriptions with KG relation descriptions - 3. Triple validation (RLC): validates complete triples (head entity + relation + tail entity) + For each chunk, uses LLM to evaluate the quality of extracted entities and relations + by comparing them with the original chunk content. Provides multi-dimensional quality + scores (accuracy, completeness, precision). """ def __init__( @@ -25,252 +95,304 @@ def __init__( graph_storage: BaseGraphStorage, chunk_storage: BaseKVStorage, llm_client: BaseLLMWrapper, - source_text_paths: List[str], max_concurrent: int = 10, ): - if not source_text_paths: - raise ValueError("source_text_paths is required and cannot be empty") self.graph_storage = graph_storage self.chunk_storage = chunk_storage self.llm_client = llm_client self.max_concurrent = max_concurrent - self.source_text_paths = source_text_paths def evaluate(self) -> Dict[str, Any]: - """Evaluate entity, relation, and triple accuracy. + """Evaluate entity and relation extraction quality using LLM-as-a-Judge. Returns: - Dictionary containing entity_accuracy, relation_accuracy, and triple_accuracy metrics. + Dictionary containing entity_accuracy and relation_accuracy metrics. """ - all_nodes = self.graph_storage.get_all_nodes() or [] - all_edges = self.graph_storage.get_all_edges() or [] - - if not all_nodes and not all_edges: - return {"error": "Empty graph"} - + # 1. Load all chunks from storage + chunks = self._load_chunks_from_storage() + + if not chunks: + logger.warning("No chunks found in storage") + return {"error": "No chunks found in storage"} + + logger.info(f"Found {len(chunks)} chunks to evaluate") + + # 2. Evaluate each chunk loop = create_event_loop() - source_text = self._load_source_texts() - entity_ground_truth, relation_ground_truth, triple_ground_truth = loop.run_until_complete( - self._extract_ground_truth(source_text) - ) - entity_results = loop.run_until_complete( - self._evaluate_entities_with_ground_truth( - all_nodes, entity_ground_truth - ) + entity_evaluations, relation_evaluations = loop.run_until_complete( + self._evaluate_all_chunks(chunks) ) - relation_results = loop.run_until_complete( - self._evaluate_relations_with_ground_truth( - all_edges, relation_ground_truth - ) - ) - triple_results = loop.run_until_complete( - self._evaluate_triples_with_ground_truth( - all_edges, triple_ground_truth - ) - ) - - return { - "entity_accuracy": entity_results, - "relation_accuracy": relation_results, - "triple_accuracy": triple_results, - } + + # 3. Aggregate results + return self._aggregate_evaluation_results(entity_evaluations, relation_evaluations) - def _load_source_texts(self) -> str: - """Load and concatenate source text files. + def _load_chunks_from_storage(self) -> List[Chunk]: + """Load all chunks from chunk storage.""" + chunks = [] + all_chunk_data = self.chunk_storage.get_all() - Supports .txt, .json, and .jsonl formats using the utility function - from graphgen.models.evaluator.kg.utils. - """ - texts = [] - for path in self.source_text_paths: - file_path = Path(path) - if not file_path.exists(): - logger.warning(f"Source text file not found: {path}") - continue + for chunk_id, chunk_data in all_chunk_data.items(): try: - content = load_text_content(file_path) - if content: - texts.append(content) + chunk = Chunk.from_dict(chunk_id, chunk_data) + chunks.append(chunk) except Exception as e: - logger.error(f"Failed to read {path}: {e}") - return "\n\n".join(texts) - - async def _extract_ground_truth( - self, source_text: str - ) -> Tuple[Set[str], Set[str], Set[Tuple[str, str, str]]]: - """Extract entities, relations, and triples from source text using LLM as ground truth.""" - semaphore = asyncio.Semaphore(self.max_concurrent) - chunk_size = 2000 - chunks = [ - source_text[i : i + chunk_size] - for i in range(0, len(source_text), chunk_size) - ] - - async def extract_from_chunk(chunk): - async with semaphore: - return await self._extract_entities_relations_and_triples(chunk) - - tasks = [extract_from_chunk(chunk) for chunk in chunks] - results = await asyncio.gather(*tasks) - - all_entities = set() - all_relations = set() - all_triples = set() - for entities, relations, triples in results: - all_entities.update(entities) - all_relations.update(relations) - all_triples.update(triples) - - return all_entities, all_relations, all_triples - - async def _extract_entities_relations_and_triples( - self, text: str - ) -> Tuple[Set[str], Set[str], Set[Tuple[str, str, str]]]: - """Extract entities, relations, and triples from a text chunk using LLM.""" - entity_prompt = f"""从以下文本中提取所有实体名称,每行一个实体名称。 - -文本: -{text[:2000]} - -请只返回实体名称列表,每行一个,不要其他内容:""" - - relation_prompt = f"""从以下文本中提取所有关系描述,每行一个关系描述。 -关系描述是指描述两个实体之间关系的词语或短语,例如"设计"、"位于"、"属于"等。 - -文本: -{text[:2000]} - -请只返回关系描述列表,每行一个,不要其他内容:""" + logger.warning(f"Failed to load chunk {chunk_id}: {e}") + continue + + return chunks - triple_prompt = f"""从以下文本中提取所有三元组,格式为:头实体|关系|尾实体,每行一个三元组。 + def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]: + """Get all entities extracted from the specified chunk.""" + entities = [] + all_nodes = self.graph_storage.get_all_nodes() or [] + + for node_id, node_data in all_nodes: + if not isinstance(node_data, dict): + continue + source_ids = node_data.get("source_id", "").split("") + # Check if this chunk_id is in the source_ids + if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]: + entities.append({ + "entity_name": node_data.get("entity_name", node_id), + "entity_type": node_data.get("entity_type", ""), + "description": node_data.get("description", "") + }) + + return entities -文本: -{text[:2000]} + def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]: + """Get all relations extracted from the specified chunk.""" + relations = [] + all_edges = self.graph_storage.get_all_edges() or [] + + for src_id, dst_id, edge_data in all_edges: + if not isinstance(edge_data, dict): + continue + source_ids = edge_data.get("source_id", "").split("") + # Check if this chunk_id is in the source_ids + if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]: + src_node = self.graph_storage.get_node(src_id) or {} + dst_node = self.graph_storage.get_node(dst_id) or {} + relations.append({ + "source_entity": src_node.get("entity_name", src_id), + "target_entity": dst_node.get("entity_name", dst_id), + "relationship_summary": edge_data.get("description", "") + }) + + return relations -请只返回三元组列表,每行一个,格式为:头实体|关系|尾实体,不要其他内容:""" + async def _evaluate_all_chunks( + self, chunks: List[Chunk] + ) -> tuple[List[Dict], List[Dict]]: + """Evaluate all chunks concurrently.""" + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def evaluate_chunk(chunk: Chunk): + async with semaphore: + entities = self._get_extracted_entities_for_chunk(chunk.id) + relations = self._get_extracted_relations_for_chunk(chunk.id) + + entity_eval = await self._evaluate_entity_extraction(chunk, entities) + relation_eval = await self._evaluate_relation_extraction(chunk, relations) + + return entity_eval, relation_eval + + tasks = [evaluate_chunk(chunk) for chunk in chunks] + results = await asyncio.gather(*tasks, return_exceptions=True) + + entity_evaluations = [] + relation_evaluations = [] + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Failed to evaluate chunk {chunks[i].id}: {result}") + continue + entity_eval, relation_eval = result + entity_evaluations.append(entity_eval) + relation_evaluations.append(relation_eval) + + return entity_evaluations, relation_evaluations + async def _evaluate_entity_extraction( + self, chunk: Chunk, extracted_entities: List[Dict] + ) -> Dict[str, Any]: + """Use LLM to evaluate entity extraction quality.""" try: - entity_response = await self.llm_client.generate_answer(entity_prompt) - relation_response = await self.llm_client.generate_answer(relation_prompt) - triple_response = await self.llm_client.generate_answer(triple_prompt) - - entities = { - line.strip() - for line in entity_response.split("\n") - if line.strip() and not line.strip().startswith("#") + prompt = ENTITY_EVALUATION_PROMPT.format( + chunk_content=chunk.content, + extracted_entities=json.dumps(extracted_entities, ensure_ascii=False, indent=2) + ) + + response = await self.llm_client.generate_answer(prompt) + + # Try to parse JSON response + try: + evaluation_result = json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from markdown code blocks or other formats + json_match = re.search(r'\{.*\}', response, re.DOTALL) + if json_match: + evaluation_result = json.loads(json_match.group(0)) + else: + logger.warning(f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}") + # Return default evaluation + evaluation_result = { + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": "Failed to parse LLM response", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": ["LLM response parsing failed"] + } + + # Validate and calculate overall_score if not provided + if "overall_score" not in evaluation_result: + accuracy = float(evaluation_result.get("accuracy", 0.0)) + completeness = float(evaluation_result.get("completeness", 0.0)) + precision = float(evaluation_result.get("precision", 0.0)) + evaluation_result["overall_score"] = 0.4 * accuracy + 0.4 * completeness + 0.2 * precision + + return { + "chunk_id": chunk.id, + "chunk_content": chunk.content[:200] if chunk.content else "", # First 200 chars for debugging + "extracted_entities_count": len(extracted_entities), + **evaluation_result } - - relations = { - line.strip() - for line in relation_response.split("\n") - if line.strip() and not line.strip().startswith("#") + except Exception as e: + logger.error(f"Error evaluating entity extraction for chunk {chunk.id}: {e}") + return { + "chunk_id": chunk.id, + "chunk_content": chunk.content[:200] if chunk.content else "", + "extracted_entities_count": len(extracted_entities), + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": f"Evaluation failed: {str(e)}", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": [f"Evaluation error: {str(e)}"] } - triples = set() - for line in triple_response.split("\n"): - line = line.strip() - if not line or line.startswith("#"): - continue - parts = line.split("|") - if len(parts) == 3: - triples.add((parts[0].strip(), parts[1].strip(), parts[2].strip())) - - return entities, relations, triples + async def _evaluate_relation_extraction( + self, chunk: Chunk, extracted_relations: List[Dict] + ) -> Dict[str, Any]: + """Use LLM to evaluate relation extraction quality.""" + try: + prompt = RELATION_EVALUATION_PROMPT.format( + chunk_content=chunk.content, + extracted_relations=json.dumps(extracted_relations, ensure_ascii=False, indent=2) + ) + + response = await self.llm_client.generate_answer(prompt) + + # Try to parse JSON response + try: + evaluation_result = json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from markdown code blocks or other formats + json_match = re.search(r'\{.*\}', response, re.DOTALL) + if json_match: + evaluation_result = json.loads(json_match.group(0)) + else: + logger.warning(f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}") + # Return default evaluation + evaluation_result = { + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": "Failed to parse LLM response", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": ["LLM response parsing failed"] + } + + # Validate and calculate overall_score if not provided + if "overall_score" not in evaluation_result: + accuracy = float(evaluation_result.get("accuracy", 0.0)) + completeness = float(evaluation_result.get("completeness", 0.0)) + precision = float(evaluation_result.get("precision", 0.0)) + evaluation_result["overall_score"] = 0.4 * accuracy + 0.4 * completeness + 0.2 * precision + + return { + "chunk_id": chunk.id, + "chunk_content": chunk.content[:200] if chunk.content else "", + "extracted_relations_count": len(extracted_relations), + **evaluation_result + } except Exception as e: - logger.error(f"Failed to extract ground truth: {e}") - return set(), set(), set() - - async def _evaluate_entities_with_ground_truth( - self, all_nodes: List[Tuple[str, Dict]], ground_truth: Set[str] - ) -> Dict[str, float]: - """Evaluate entity accuracy by comparing KG entities with ground truth.""" - kg_entities = { - node_data.get("entity_name", node_id).lower() - for node_id, node_data in all_nodes - if isinstance(node_data, dict) - } - - tp = len(kg_entities & {e.lower() for e in ground_truth}) - fp = len(kg_entities - {e.lower() for e in ground_truth}) - fn = len({e.lower() for e in ground_truth} - kg_entities) - - precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 - recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 - f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 - - return { - "precision": precision, - "recall": recall, - "f1": f1, - "true_positives": tp, - "false_positives": fp, - "false_negatives": fn, - "total_ground_truth": len(ground_truth), - "total_kg_entities": len(kg_entities), - } - - async def _evaluate_relations_with_ground_truth( - self, all_edges: List[Tuple[str, str, Dict]], ground_truth: Set[str] - ) -> Dict[str, float]: - """Evaluate relation extraction accuracy by comparing KG relation descriptions with ground truth.""" - kg_relations = set() - for src_id, dst_id, edge_data in all_edges: - relation = edge_data.get("relationship_summary", edge_data.get("description", "")) - if relation: - kg_relations.add(relation.lower().strip()) - - gt_normalized = {r.lower().strip() for r in ground_truth if r.strip()} - - tp = len(kg_relations & gt_normalized) - fp = len(kg_relations - gt_normalized) - fn = len(gt_normalized - kg_relations) - - precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 - recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 - f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 - - return { - "precision": precision, - "recall": recall, - "f1": f1, - "true_positives": tp, - "false_positives": fp, - "false_negatives": fn, - "total_ground_truth": len(ground_truth), - "total_kg_relations": len(kg_relations), - } - - async def _evaluate_triples_with_ground_truth( - self, all_edges: List[Tuple[str, str, Dict]], ground_truth: Set[Tuple[str, str, str]] - ) -> Dict[str, float]: - """Evaluate triple accuracy by comparing KG triples with ground truth.""" - kg_triples = set() - for src_id, dst_id, edge_data in all_edges: - src_node = self.graph_storage.get_node(src_id) or {} - dst_node = self.graph_storage.get_node(dst_id) or {} - head = src_node.get("entity_name", src_id).lower() - tail = dst_node.get("entity_name", dst_id).lower() - relation = edge_data.get("relationship_summary", edge_data.get("description", "")).lower() - kg_triples.add((head, relation, tail)) - - gt_normalized = {(h.lower(), r.lower(), t.lower()) for h, r, t in ground_truth} - - tp = len(kg_triples & gt_normalized) - fp = len(kg_triples - gt_normalized) - fn = len(gt_normalized - kg_triples) - - precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 - recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 - f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 + logger.error(f"Error evaluating relation extraction for chunk {chunk.id}: {e}") + return { + "chunk_id": chunk.id, + "chunk_content": chunk.content[:200] if chunk.content else "", + "extracted_relations_count": len(extracted_relations), + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": f"Evaluation failed: {str(e)}", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": [f"Evaluation error: {str(e)}"] + } + def _aggregate_evaluation_results( + self, entity_evaluations: List[Dict], relation_evaluations: List[Dict] + ) -> Dict[str, Any]: + """Aggregate evaluation results from all chunks.""" + def calculate_stats(scores: List[float]) -> Dict[str, float]: + if not scores: + return { + "mean": 0.0, + "median": 0.0, + "min": 0.0, + "max": 0.0, + "std": 0.0 + } + sorted_scores = sorted(scores) + n = len(scores) + mean = sum(scores) / n + median = sorted_scores[n // 2] if n % 2 == 1 else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2 + variance = sum((x - mean) ** 2 for x in scores) / n + std = variance ** 0.5 + + return { + "mean": mean, + "median": median, + "min": min(scores), + "max": max(scores), + "std": std + } + + # Extract scores + entity_overall_scores = [e.get("overall_score", 0.0) for e in entity_evaluations] + entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations] + entity_completeness_scores = [e.get("completeness", 0.0) for e in entity_evaluations] + entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations] + + relation_overall_scores = [r.get("overall_score", 0.0) for r in relation_evaluations] + relation_accuracy_scores = [r.get("accuracy", 0.0) for r in relation_evaluations] + relation_completeness_scores = [r.get("completeness", 0.0) for r in relation_evaluations] + relation_precision_scores = [r.get("precision", 0.0) for r in relation_evaluations] + return { - "precision": precision, - "recall": recall, - "f1": f1, - "true_positives": tp, - "false_positives": fp, - "false_negatives": fn, - "total_ground_truth": len(ground_truth), - "total_kg_triples": len(kg_triples), + "entity_accuracy": { + "overall_score": calculate_stats(entity_overall_scores), + "accuracy": calculate_stats(entity_accuracy_scores), + "completeness": calculate_stats(entity_completeness_scores), + "precision": calculate_stats(entity_precision_scores), + "total_chunks": len(entity_evaluations), + "detailed_results": entity_evaluations + }, + "relation_accuracy": { + "overall_score": calculate_stats(relation_overall_scores), + "accuracy": calculate_stats(relation_accuracy_scores), + "completeness": calculate_stats(relation_completeness_scores), + "precision": calculate_stats(relation_precision_scores), + "total_chunks": len(relation_evaluations), + "detailed_results": relation_evaluations + } } - diff --git a/graphgen/models/evaluator/kg/consistency_evaluator.py b/graphgen/models/evaluator/kg/consistency_evaluator.py index 05d27d96..f616b98e 100644 --- a/graphgen/models/evaluator/kg/consistency_evaluator.py +++ b/graphgen/models/evaluator/kg/consistency_evaluator.py @@ -1,44 +1,187 @@ -from typing import Any, Dict +import asyncio +import json +import re +from typing import Any, Dict, List -from graphgen.bases import BaseGraphStorage +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.bases.datatypes import Chunk +from graphgen.utils import create_event_loop, logger + + +# LLM prompts for conflict detection +ENTITY_TYPE_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中被提取为不同的类型,是否存在语义冲突。 + +实体名称:{entity_name} + +在不同文本块中的类型提取结果: +{type_extractions} + +预设的实体类型列表(供参考): +concept, date, location, keyword, organization, person, event, work, nature, artificial, science, technology, mission, gene + +请判断这些类型是否存在语义冲突(即它们是否描述的是同一类事物,还是存在矛盾)。 +注意:如果类型只是同一概念的不同表述(如 concept 和 keyword),可能不算严重冲突。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数,0表示无冲突,1表示严重冲突>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_types": ["<存在冲突的类型对>"], + "recommended_type": "<如果存在冲突,推荐的正确类型(必须是预设类型之一)>" +}} +""" + +ENTITY_DESCRIPTION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中的描述是否存在语义冲突。 + +实体名称:{entity_name} + +在不同文本块中的描述: +{descriptions} + +请判断这些描述是否存在语义冲突(即它们是否描述的是同一个实体,还是存在矛盾的信息)。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_descriptions": ["<存在冲突的描述对>"], + "conflict_details": "<具体的冲突内容>" +}} +""" + +RELATION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一对实体在不同文本块中的关系描述是否存在语义冲突。 + +实体对:{source_entity} -> {target_entity} + +在不同文本块中的关系描述: +{relation_descriptions} + +请判断这些关系描述是否存在语义冲突。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_relations": ["<存在冲突的关系描述对>"] +}} +""" + +ENTITY_EXTRACTION_PROMPT = """从以下文本块中提取指定实体的类型和描述。 + +**重要**:你只需要提取指定的实体,不要提取其他实体。 + +实体名称:{entity_name} + +文本块: +{chunk_content} + +请从文本块中找到并提取**仅此实体**(实体名称:{entity_name})的以下信息: + +1. entity_type: 实体类型,必须是以下预设类型之一(小写): + - concept: 概念 + - date: 日期 + - location: 地点 + - keyword: 关键词 + - organization: 组织 + - person: 人物 + - event: 事件 + - work: 作品/工作 + - nature: 自然 + - artificial: 人工 + - science: 科学 + - technology: 技术 + - mission: 任务 + - gene: 基因 + + 如果无法确定类型,请使用 "concept" 作为默认值。 + +2. description: 实体描述(简要描述该实体在文本中的作用和特征) + +请以 JSON 格式返回: +{{ + "entity_type": "<实体类型(必须是上述预设类型之一)>", + "description": "<实体描述>" +}} +""" class ConsistencyEvaluator: - """Evaluates consistency by detecting attribute value conflicts.""" + """Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge. + + For entities with multiple source chunks, compares entity_type and description + extracted from different chunks to detect semantic conflicts. + """ - def __init__(self, graph_storage: BaseGraphStorage): + def __init__( + self, + graph_storage: BaseGraphStorage, + chunk_storage: BaseKVStorage, + llm_client: BaseLLMWrapper, + max_concurrent: int = 10, + ): self.graph_storage = graph_storage + self.chunk_storage = chunk_storage + self.llm_client = llm_client + self.max_concurrent = max_concurrent def evaluate(self) -> Dict[str, Any]: + """Evaluate consistency by detecting semantic conflicts.""" all_nodes = self.graph_storage.get_all_nodes() or [] if not all_nodes: return {"error": "Empty graph"} - conflicts = [] - conflict_entities = set() + loop = create_event_loop() + return loop.run_until_complete(self._evaluate_consistency(all_nodes)) + async def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: + """Async evaluation of consistency.""" + # Filter entities with multiple source chunks + entities_with_multiple_sources = [] for node_id, node_data in all_nodes: if not isinstance(node_data, dict): continue + source_ids = node_data.get("source_id", "").split("") + source_ids = [sid.strip() for sid in source_ids if sid.strip()] + if len(source_ids) > 1: # Only check entities from multiple chunks + entities_with_multiple_sources.append((node_id, node_data, source_ids)) - # Check each attribute for multiple values - for attr_key, attr_value in node_data.items(): - # Skip special keys - if attr_key.startswith("_") or attr_key in ["id", "loss"]: - continue + if not entities_with_multiple_sources: + logger.info("No entities with multiple sources found, skipping consistency check") + return { + "conflict_rate": 0.0, + "conflict_entities_count": 0, + "total_entities": len(all_nodes), + "conflicts": [], + } + + logger.info(f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources") + + # Evaluate entities concurrently + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def evaluate_entity(entity_info): + async with semaphore: + return await self._evaluate_entity_consistency(entity_info) + + tasks = [evaluate_entity(entity_info) for entity_info in entities_with_multiple_sources] + results = await asyncio.gather(*tasks, return_exceptions=True) - # If attribute has multiple values (list), check for conflicts - if isinstance(attr_value, list): - unique_values = set(str(v) for v in attr_value if v) - if len(unique_values) > 1: - conflicts.append( - { - "entity_id": node_id, - "attribute": attr_key, - "values": list(unique_values), - } - ) - conflict_entities.add(node_id) + # Aggregate results + conflicts = [] + conflict_entities = set() + + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"Failed to evaluate entity {entities_with_multiple_sources[i][0]}: {result}") + continue + + entity_id, entity_conflicts = result + if entity_conflicts: + conflicts.extend(entity_conflicts) + conflict_entities.add(entity_id) total_entities = len(all_nodes) conflict_rate = ( @@ -49,5 +192,254 @@ def evaluate(self) -> Dict[str, Any]: "conflict_rate": conflict_rate, "conflict_entities_count": len(conflict_entities), "total_entities": total_entities, + "entities_checked": len(entities_with_multiple_sources), "conflicts": conflicts[:100], # Limit to first 100 conflicts } + + def _clean_entity_id(self, entity_id: str) -> str: + """Clean entity ID by removing surrounding quotes.""" + clean_id = entity_id.strip() + if (clean_id.startswith('"') and clean_id.endswith('"')) or \ + (clean_id.startswith("'") and clean_id.endswith("'")): + clean_id = clean_id[1:-1].strip() + return clean_id + + async def _evaluate_entity_consistency( + self, entity_info: tuple + ) -> tuple[str, List[Dict]]: + """Evaluate consistency for a single entity.""" + entity_id, node_data, source_ids = entity_info + # Clean entity_id for display + clean_entity_id = self._clean_entity_id(entity_id) + conflicts = [] + + # Get chunks for this entity + chunks = self._get_entity_chunks(source_ids) + if len(chunks) < 2: + return entity_id, [] + + # Extract entity attributes from each chunk + entity_extractions = {} + for chunk in chunks: + extraction = await self._extract_entity_from_chunk(entity_id, chunk) + if extraction: + entity_extractions[chunk.id] = extraction + + if len(entity_extractions) < 2: + return entity_id, [] + + # Check entity type consistency + type_extractions = { + chunk_id: ext.get("entity_type", "") + for chunk_id, ext in entity_extractions.items() + } + type_conflict = await self._check_entity_type_consistency( + entity_id, type_extractions + ) + if type_conflict and type_conflict.get("has_conflict", False): + conflicts.append({ + "entity_id": clean_entity_id, + "conflict_type": "entity_type", + "conflict_severity": type_conflict.get("conflict_severity", 0.0), + "conflict_reasoning": type_conflict.get("conflict_reasoning", ""), + "conflicting_values": type_conflict.get("conflicting_types", []), + "recommended_value": type_conflict.get("recommended_type", ""), + }) + + # Check entity description consistency + descriptions = { + chunk_id: ext.get("description", "") + for chunk_id, ext in entity_extractions.items() + } + desc_conflict = await self._check_entity_description_consistency( + entity_id, descriptions + ) + if desc_conflict and desc_conflict.get("has_conflict", False): + conflicts.append({ + "entity_id": clean_entity_id, + "conflict_type": "description", + "conflict_severity": desc_conflict.get("conflict_severity", 0.0), + "conflict_reasoning": desc_conflict.get("conflict_reasoning", ""), + "conflicting_values": desc_conflict.get("conflicting_descriptions", []), + "conflict_details": desc_conflict.get("conflict_details", ""), + }) + + return entity_id, conflicts + + def _get_entity_chunks(self, source_ids: List[str]) -> List[Chunk]: + """Get all chunks related to an entity.""" + chunks = [] + for chunk_id in source_ids: + chunk_data = self.chunk_storage.get_by_id(chunk_id) + if chunk_data: + try: + chunk = Chunk.from_dict(chunk_id, chunk_data) + chunks.append(chunk) + except Exception as e: + logger.warning(f"Failed to load chunk {chunk_id}: {e}") + continue + return chunks + + async def _extract_entity_from_chunk( + self, entity_id: str, chunk: Chunk + ) -> Dict[str, str]: + """Extract entity attributes from a chunk using LLM.""" + try: + # Clean entity_id: remove surrounding quotes if present + clean_entity_id = self._clean_entity_id(entity_id) + + prompt = ENTITY_EXTRACTION_PROMPT.format( + entity_name=clean_entity_id, + chunk_content=chunk.content[:2000] if chunk.content else "" # Limit content length + ) + + response = await self.llm_client.generate_answer(prompt) + + # Try to parse JSON response + try: + extraction = json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from markdown code blocks + json_match = re.search(r'\{.*\}', response, re.DOTALL) + if json_match: + extraction = json.loads(json_match.group(0)) + else: + logger.warning(f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}") + return {} + + # Normalize entity_type to lowercase and validate + entity_type = extraction.get("entity_type", "").lower().strip() + # Valid preset types + valid_types = { + "concept", "date", "location", "keyword", "organization", + "person", "event", "work", "nature", "artificial", + "science", "technology", "mission", "gene" + } + # If entity_type is not in valid types, default to "concept" + if entity_type not in valid_types: + if entity_type: # If LLM provided a type but it's invalid + logger.warning( + f"Invalid entity_type '{entity_type}' for entity {clean_entity_id} in chunk {chunk.id}, " + f"defaulting to 'concept'" + ) + entity_type = "concept" + + return { + "entity_type": entity_type, + "description": extraction.get("description", ""), + } + except Exception as e: + logger.error(f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}") + return {} + + async def _check_entity_type_consistency( + self, entity_id: str, type_extractions: Dict[str, str] + ) -> Dict[str, Any]: + """Check entity type consistency using LLM.""" + if len(set(type_extractions.values())) <= 1: + # All types are the same, no conflict + return {"has_conflict": False} + + try: + type_list = [f"Chunk {chunk_id}: {entity_type}" + for chunk_id, entity_type in type_extractions.items() + if entity_type] + + prompt = ENTITY_TYPE_CONFLICT_PROMPT.format( + entity_name=entity_id, + type_extractions="\n".join(type_list) + ) + + response = await self.llm_client.generate_answer(prompt) + + # Parse JSON response + try: + result = json.loads(response) + except json.JSONDecodeError: + json_match = re.search(r'\{.*\}', response, re.DOTALL) + if json_match: + result = json.loads(json_match.group(0)) + else: + logger.warning(f"Failed to parse conflict detection response for {entity_id}") + return {"has_conflict": False} + + return result + except Exception as e: + logger.error(f"Error checking type consistency for {entity_id}: {e}") + return {"has_conflict": False} + + async def _check_entity_description_consistency( + self, entity_id: str, descriptions: Dict[str, str] + ) -> Dict[str, Any]: + """Check entity description consistency using LLM.""" + # Filter out empty descriptions + valid_descriptions = {k: v for k, v in descriptions.items() if v} + if len(valid_descriptions) < 2: + return {"has_conflict": False} + + if len(set(valid_descriptions.values())) <= 1: + # All descriptions are the same, no conflict + return {"has_conflict": False} + + try: + desc_list = [f"Chunk {chunk_id}: {description}" + for chunk_id, description in valid_descriptions.items()] + + prompt = ENTITY_DESCRIPTION_CONFLICT_PROMPT.format( + entity_name=entity_id, + descriptions="\n".join(desc_list) + ) + + response = await self.llm_client.generate_answer(prompt) + + # Parse JSON response + try: + result = json.loads(response) + except json.JSONDecodeError: + json_match = re.search(r'\{.*\}', response, re.DOTALL) + if json_match: + result = json.loads(json_match.group(0)) + else: + logger.warning(f"Failed to parse conflict detection response for {entity_id}") + return {"has_conflict": False} + + return result + except Exception as e: + logger.error(f"Error checking description consistency for {entity_id}: {e}") + return {"has_conflict": False} + + async def _check_relation_consistency( + self, src_id: str, dst_id: str, relation_extractions: Dict[str, str] + ) -> Dict[str, Any]: + """Check relation consistency using LLM.""" + if len(set(relation_extractions.values())) <= 1: + return {"has_conflict": False} + + try: + rel_list = [f"Chunk {chunk_id}: {relation}" + for chunk_id, relation in relation_extractions.items() + if relation] + + prompt = RELATION_CONFLICT_PROMPT.format( + source_entity=src_id, + target_entity=dst_id, + relation_descriptions="\n".join(rel_list) + ) + + response = await self.llm_client.generate_answer(prompt) + + # Parse JSON response + try: + result = json.loads(response) + except json.JSONDecodeError: + json_match = re.search(r'\{.*\}', response, re.DOTALL) + if json_match: + result = json.loads(json_match.group(0)) + else: + logger.warning(f"Failed to parse relation conflict response for {src_id}->{dst_id}") + return {"has_conflict": False} + + return result + except Exception as e: + logger.error(f"Error checking relation consistency for {src_id}->{dst_id}: {e}") + return {"has_conflict": False} diff --git a/graphgen/models/evaluator/kg/structure_evaluator.py b/graphgen/models/evaluator/kg/structure_evaluator.py index 3aee3475..24207c53 100644 --- a/graphgen/models/evaluator/kg/structure_evaluator.py +++ b/graphgen/models/evaluator/kg/structure_evaluator.py @@ -1,22 +1,40 @@ from typing import Any, Dict, Optional -try: - import networkx as nx -except ImportError: - nx = None +import networkx as nx +import numpy as np try: from scipy import stats except ImportError: stats = None -import numpy as np - from graphgen.bases import BaseGraphStorage -from graphgen.models.evaluator.kg.utils import convert_to_networkx from graphgen.utils import logger +def _convert_to_networkx(graph_storage: BaseGraphStorage) -> nx.DiGraph: + """Convert graph storage to NetworkX graph.""" + G = nx.DiGraph() + + # Add nodes + nodes = graph_storage.get_all_nodes() or [] + for node_id, node_data in nodes: + if isinstance(node_data, dict): + G.add_node(node_id, **node_data) + else: + G.add_node(node_id) + + # Add edges + edges = graph_storage.get_all_edges() or [] + for src, dst, edge_data in edges: + if isinstance(edge_data, dict): + G.add_edge(src, dst, **edge_data) + else: + G.add_edge(src, dst) + + return G + + class StructureEvaluator: """Evaluates structural robustness of the graph.""" @@ -37,11 +55,8 @@ def __init__( self.powerlaw_r2_threshold = powerlaw_r2_threshold def evaluate(self) -> Dict[str, Any]: - if nx is None: - return {"error": "NetworkX not installed"} - # Convert graph to NetworkX - G = convert_to_networkx(self.graph_storage) + G = _convert_to_networkx(self.graph_storage) if G.number_of_nodes() == 0: return {"error": "Empty graph"} diff --git a/graphgen/models/evaluator/kg/utils.py b/graphgen/models/evaluator/kg/utils.py deleted file mode 100644 index 64963d2c..00000000 --- a/graphgen/models/evaluator/kg/utils.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Any, List, Optional - -try: - import networkx as nx -except ImportError: - nx = None - -from graphgen.bases import BaseGraphStorage, BaseKVStorage - - -def convert_to_networkx(graph_storage: BaseGraphStorage) -> "nx.Graph": - """Convert graph storage to NetworkX graph.""" - if nx is None: - raise ImportError("NetworkX is required for structural evaluation") - - G = nx.DiGraph() - - # Add nodes - nodes = graph_storage.get_all_nodes() or [] - for node_id, node_data in nodes: - if isinstance(node_data, dict): - G.add_node(node_id, **node_data) - else: - G.add_node(node_id) - - # Add edges - edges = graph_storage.get_all_edges() or [] - for src, dst, edge_data in edges: - if isinstance(edge_data, dict): - G.add_edge(src, dst, **edge_data) - else: - G.add_edge(src, dst) - - return G - - -def get_source_text(chunk_storage: BaseKVStorage, chunk_id: Optional[str] = None) -> str: - """ - Get source text from chunk storage. - - Args: - chunk_storage: KV storage containing chunks - chunk_id: Optional chunk ID. If None, returns concatenated text from all chunks. - - Returns: - Source text content - """ - if chunk_id: - chunk = chunk_storage.get_by_id(chunk_id) - if chunk and isinstance(chunk, dict): - return chunk.get("content", "") - return "" - - # Get all chunks and concatenate - all_chunks = chunk_storage.get_all() - texts = [] - for chunk_data in all_chunks.values(): - if isinstance(chunk_data, dict): - content = chunk_data.get("content", "") - if content: - texts.append(content) - return "\n\n".join(texts) - - -def get_relevant_text( - chunk_storage: BaseKVStorage, source_id: Optional[str] = None -) -> str: - """Get relevant source text from chunk storage using source_id.""" - if source_id: - # Try to get specific chunk - chunk = chunk_storage.get_by_id(source_id) - if chunk and isinstance(chunk, dict): - return chunk.get("content", "") - # If source_id contains , try multiple chunks - if "" in str(source_id): - chunk_ids = [sid.strip() for sid in str(source_id).split("") if sid.strip()] - texts = [] - for cid in chunk_ids: - chunk = chunk_storage.get_by_id(cid) - if chunk and isinstance(chunk, dict): - content = chunk.get("content", "") - if content: - texts.append(content) - return "\n\n".join(texts) if texts else "" - - # Fallback to all chunks - return get_source_text(chunk_storage) - - -def sample_items(items: List[Any], sample_size: int) -> List[Any]: - """Sample items from a list.""" - import random - - if len(items) <= sample_size: - return items - return random.sample(items, sample_size) diff --git a/graphgen/models/evaluator/kg_quality_evaluator.py b/graphgen/models/evaluator/kg_quality_evaluator.py index 23ae6187..d8df9095 100644 --- a/graphgen/models/evaluator/kg_quality_evaluator.py +++ b/graphgen/models/evaluator/kg_quality_evaluator.py @@ -7,8 +7,8 @@ 3. robustness assessment (noise ratio, connectivity, degree distribution). """ -from dataclasses import dataclass -from typing import Any, Dict, Optional +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper from graphgen.models.evaluator.kg import ( @@ -27,7 +27,6 @@ class KGQualityEvaluator: graph_backend: str = "kuzu" kv_backend: str = "rocksdb" llm_client: Optional[BaseLLMWrapper] = None - sample_size: int = 100 max_concurrent: int = 10 def __post_init__(self): @@ -65,17 +64,21 @@ def evaluate_all(self) -> Dict[str, Any]: try: logger.info("Starting consistency evaluation...") consistency_evaluator = ConsistencyEvaluator( - graph_storage=self.graph_storage + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + max_concurrent=self.max_concurrent, ) results["consistency"] = consistency_evaluator.evaluate() except Exception as e: logger.error(f"Consistency evaluation failed: {e}") results["consistency"] = {"error": str(e)} - # Structural robustness evaluation try: logger.info("Starting structural robustness evaluation...") - structure_evaluator = StructureEvaluator(graph_storage=self.graph_storage) + structure_evaluator = StructureEvaluator( + graph_storage=self.graph_storage + ) results["structure"] = structure_evaluator.evaluate() except Exception as e: logger.error(f"Structural evaluation failed: {e}") @@ -88,15 +91,21 @@ def evaluate_accuracy(self) -> Dict[str, Any]: graph_storage=self.graph_storage, chunk_storage=self.chunk_storage, llm_client=self.llm_client, - sample_size=self.sample_size, max_concurrent=self.max_concurrent, ) return accuracy_evaluator.evaluate() def evaluate_consistency(self) -> Dict[str, Any]: - consistency_evaluator = ConsistencyEvaluator(graph_storage=self.graph_storage) + consistency_evaluator = ConsistencyEvaluator( + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + max_concurrent=self.max_concurrent, + ) return consistency_evaluator.evaluate() def evaluate_structure(self) -> Dict[str, Any]: - structure_evaluator = StructureEvaluator(graph_storage=self.graph_storage) + structure_evaluator = StructureEvaluator( + graph_storage=self.graph_storage + ) return structure_evaluator.evaluate() diff --git a/graphgen/operators/evaluate_kg/evaluate_kg.py b/graphgen/operators/evaluate_kg/evaluate_kg.py index 2b7e74fd..4be8e91b 100644 --- a/graphgen/operators/evaluate_kg/evaluate_kg.py +++ b/graphgen/operators/evaluate_kg/evaluate_kg.py @@ -1,7 +1,6 @@ import argparse import json from pathlib import Path - from dotenv import load_dotenv from graphgen.models import KGQualityEvaluator @@ -32,12 +31,39 @@ def _print_accuracy_summary(acc): print("\n[Accuracy]") if "entity_accuracy" in acc: e = acc["entity_accuracy"] - print(f" Entity - Precision: {e.get('precision', 0):.3f}, " - f"Recall: {e.get('recall', 0):.3f}, F1: {e.get('f1', 0):.3f}") - if "triple_accuracy" in acc: - t = acc["triple_accuracy"] - print(f" Triple (RLC) - Precision: {t.get('precision', 0):.3f}, " - f"Recall: {t.get('recall', 0):.3f}, F1: {t.get('f1', 0):.3f}") + overall = e.get("overall_score", {}) + accuracy = e.get("accuracy", {}) + completeness = e.get("completeness", {}) + precision = e.get("precision", {}) + + print(f" Entity Extraction Quality:") + print(f" Overall Score: {overall.get('mean', 0):.3f} (mean), " + f"{overall.get('median', 0):.3f} (median)") + print(f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " + f"{accuracy.get('median', 0):.3f} (median)") + print(f" Completeness: {completeness.get('mean', 0):.3f} (mean), " + f"{completeness.get('median', 0):.3f} (median)") + print(f" Precision: {precision.get('mean', 0):.3f} (mean), " + f"{precision.get('median', 0):.3f} (median)") + print(f" Total Chunks Evaluated: {e.get('total_chunks', 0)}") + + if "relation_accuracy" in acc: + r = acc["relation_accuracy"] + overall = r.get("overall_score", {}) + accuracy = r.get("accuracy", {}) + completeness = r.get("completeness", {}) + precision = r.get("precision", {}) + + print(f" Relation Extraction Quality:") + print(f" Overall Score: {overall.get('mean', 0):.3f} (mean), " + f"{overall.get('median', 0):.3f} (median)") + print(f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " + f"{accuracy.get('median', 0):.3f} (median)") + print(f" Completeness: {completeness.get('mean', 0):.3f} (mean), " + f"{completeness.get('median', 0):.3f} (median)") + print(f" Precision: {precision.get('mean', 0):.3f} (mean), " + f"{precision.get('median', 0):.3f} (median)") + print(f" Total Chunks Evaluated: {r.get('total_chunks', 0)}") else: print(f"\n[Accuracy] Error: {acc['error']}") @@ -49,6 +75,17 @@ def _print_consistency_summary(cons): print(f" Conflict Rate: {cons.get('conflict_rate', 0):.3f}") print(f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / " f"{cons.get('total_entities', 0)}") + entities_checked = cons.get('entities_checked', 0) + if entities_checked > 0: + print(f" Entities Checked: {entities_checked} (entities with multiple sources)") + conflicts = cons.get('conflicts', []) + if conflicts: + print(f" Total Conflicts Found: {len(conflicts)}") + # Show sample conflicts + sample_conflicts = conflicts[:3] + for conflict in sample_conflicts: + print(f" - {conflict.get('entity_id', 'N/A')}: {conflict.get('conflict_type', 'N/A')} " + f"(severity: {conflict.get('conflict_severity', 0):.2f})") else: print(f"\n[Consistency] Error: {cons['error']}") @@ -125,10 +162,9 @@ def main(): # Basic evaluation python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache - # Custom sample size and output + # Custom output python -m graphgen.operators.evaluate_kg.evaluate_kg \\ --working_dir cache \\ - --sample_size 200 \\ --output cache/kg_evaluation.json # Specify backends @@ -159,12 +195,6 @@ def main(): choices=["rocksdb", "json_kv"], help="KV storage backend (default: rocksdb)", ) - parser.add_argument( - "--sample_size", - type=int, - default=100, - help="Sample size for accuracy evaluation (default: 100)", - ) parser.add_argument( "--max_concurrent", type=int, @@ -211,15 +241,12 @@ def main(): logger.info(f"Working directory: {args.working_dir}") logger.info(f"Graph backend: {args.graph_backend}") logger.info(f"KV backend: {args.kv_backend}") - logger.info(f"Sample size: {args.sample_size}") - # Initialize evaluator try: evaluator = KGQualityEvaluator( working_dir=args.working_dir, graph_backend=args.graph_backend, kv_backend=args.kv_backend, - sample_size=args.sample_size, max_concurrent=args.max_concurrent, ) except Exception as e: From 5bfdc0a8130372a798bad682185c5a4724513d6c Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Wed, 24 Dec 2025 03:27:04 +0800 Subject: [PATCH 07/29] fix: fix format and clean up imports --- .../models/evaluator/kg/accuracy_evaluator.py | 65 ++++++++++--------- .../evaluator/kg/consistency_evaluator.py | 64 +++++++++--------- .../models/evaluator/kg_quality_evaluator.py | 4 +- graphgen/operators/evaluate_kg/evaluate_kg.py | 20 +++--- webui/utils/count_tokens.py | 4 +- 5 files changed, 80 insertions(+), 77 deletions(-) diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py index ea40788a..afdfc175 100644 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -84,12 +84,12 @@ class AccuracyEvaluator: """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge. - + For each chunk, uses LLM to evaluate the quality of extracted entities and relations by comparing them with the original chunk content. Provides multi-dimensional quality scores (accuracy, completeness, precision). """ - + def __init__( self, graph_storage: BaseGraphStorage, @@ -104,25 +104,25 @@ def __init__( def evaluate(self) -> Dict[str, Any]: """Evaluate entity and relation extraction quality using LLM-as-a-Judge. - + Returns: Dictionary containing entity_accuracy and relation_accuracy metrics. """ # 1. Load all chunks from storage chunks = self._load_chunks_from_storage() - + if not chunks: logger.warning("No chunks found in storage") return {"error": "No chunks found in storage"} - + logger.info(f"Found {len(chunks)} chunks to evaluate") - + # 2. Evaluate each chunk loop = create_event_loop() entity_evaluations, relation_evaluations = loop.run_until_complete( self._evaluate_all_chunks(chunks) ) - + # 3. Aggregate results return self._aggregate_evaluation_results(entity_evaluations, relation_evaluations) @@ -130,7 +130,7 @@ def _load_chunks_from_storage(self) -> List[Chunk]: """Load all chunks from chunk storage.""" chunks = [] all_chunk_data = self.chunk_storage.get_all() - + for chunk_id, chunk_data in all_chunk_data.items(): try: chunk = Chunk.from_dict(chunk_id, chunk_data) @@ -138,14 +138,14 @@ def _load_chunks_from_storage(self) -> List[Chunk]: except Exception as e: logger.warning(f"Failed to load chunk {chunk_id}: {e}") continue - + return chunks def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]: """Get all entities extracted from the specified chunk.""" entities = [] all_nodes = self.graph_storage.get_all_nodes() or [] - + for node_id, node_data in all_nodes: if not isinstance(node_data, dict): continue @@ -157,14 +157,14 @@ def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]: "entity_type": node_data.get("entity_type", ""), "description": node_data.get("description", "") }) - + return entities def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]: """Get all relations extracted from the specified chunk.""" relations = [] all_edges = self.graph_storage.get_all_edges() or [] - + for src_id, dst_id, edge_data in all_edges: if not isinstance(edge_data, dict): continue @@ -178,7 +178,7 @@ def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]: "target_entity": dst_node.get("entity_name", dst_id), "relationship_summary": edge_data.get("description", "") }) - + return relations async def _evaluate_all_chunks( @@ -186,31 +186,32 @@ async def _evaluate_all_chunks( ) -> tuple[List[Dict], List[Dict]]: """Evaluate all chunks concurrently.""" semaphore = asyncio.Semaphore(self.max_concurrent) - + async def evaluate_chunk(chunk: Chunk): async with semaphore: entities = self._get_extracted_entities_for_chunk(chunk.id) relations = self._get_extracted_relations_for_chunk(chunk.id) - + entity_eval = await self._evaluate_entity_extraction(chunk, entities) relation_eval = await self._evaluate_relation_extraction(chunk, relations) - + return entity_eval, relation_eval - + tasks = [evaluate_chunk(chunk) for chunk in chunks] results = await asyncio.gather(*tasks, return_exceptions=True) - + entity_evaluations = [] relation_evaluations = [] - + for i, result in enumerate(results): if isinstance(result, Exception): logger.error(f"Failed to evaluate chunk {chunks[i].id}: {result}") continue + entity_eval, relation_eval = result entity_evaluations.append(entity_eval) relation_evaluations.append(relation_eval) - + return entity_evaluations, relation_evaluations async def _evaluate_entity_extraction( @@ -222,9 +223,9 @@ async def _evaluate_entity_extraction( chunk_content=chunk.content, extracted_entities=json.dumps(extracted_entities, ensure_ascii=False, indent=2) ) - + response = await self.llm_client.generate_answer(prompt) - + # Try to parse JSON response try: evaluation_result = json.loads(response) @@ -246,14 +247,14 @@ async def _evaluate_entity_extraction( "precision_reasoning": "", "issues": ["LLM response parsing failed"] } - + # Validate and calculate overall_score if not provided if "overall_score" not in evaluation_result: accuracy = float(evaluation_result.get("accuracy", 0.0)) completeness = float(evaluation_result.get("completeness", 0.0)) precision = float(evaluation_result.get("precision", 0.0)) evaluation_result["overall_score"] = 0.4 * accuracy + 0.4 * completeness + 0.2 * precision - + return { "chunk_id": chunk.id, "chunk_content": chunk.content[:200] if chunk.content else "", # First 200 chars for debugging @@ -285,9 +286,9 @@ async def _evaluate_relation_extraction( chunk_content=chunk.content, extracted_relations=json.dumps(extracted_relations, ensure_ascii=False, indent=2) ) - + response = await self.llm_client.generate_answer(prompt) - + # Try to parse JSON response try: evaluation_result = json.loads(response) @@ -309,14 +310,14 @@ async def _evaluate_relation_extraction( "precision_reasoning": "", "issues": ["LLM response parsing failed"] } - + # Validate and calculate overall_score if not provided if "overall_score" not in evaluation_result: accuracy = float(evaluation_result.get("accuracy", 0.0)) completeness = float(evaluation_result.get("completeness", 0.0)) precision = float(evaluation_result.get("precision", 0.0)) evaluation_result["overall_score"] = 0.4 * accuracy + 0.4 * completeness + 0.2 * precision - + return { "chunk_id": chunk.id, "chunk_content": chunk.content[:200] if chunk.content else "", @@ -358,7 +359,7 @@ def calculate_stats(scores: List[float]) -> Dict[str, float]: median = sorted_scores[n // 2] if n % 2 == 1 else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2 variance = sum((x - mean) ** 2 for x in scores) / n std = variance ** 0.5 - + return { "mean": mean, "median": median, @@ -366,18 +367,18 @@ def calculate_stats(scores: List[float]) -> Dict[str, float]: "max": max(scores), "std": std } - + # Extract scores entity_overall_scores = [e.get("overall_score", 0.0) for e in entity_evaluations] entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations] entity_completeness_scores = [e.get("completeness", 0.0) for e in entity_evaluations] entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations] - + relation_overall_scores = [r.get("overall_score", 0.0) for r in relation_evaluations] relation_accuracy_scores = [r.get("accuracy", 0.0) for r in relation_evaluations] relation_completeness_scores = [r.get("completeness", 0.0) for r in relation_evaluations] relation_precision_scores = [r.get("precision", 0.0) for r in relation_evaluations] - + return { "entity_accuracy": { "overall_score": calculate_stats(entity_overall_scores), diff --git a/graphgen/models/evaluator/kg/consistency_evaluator.py b/graphgen/models/evaluator/kg/consistency_evaluator.py index f616b98e..c67b6d32 100644 --- a/graphgen/models/evaluator/kg/consistency_evaluator.py +++ b/graphgen/models/evaluator/kg/consistency_evaluator.py @@ -95,7 +95,7 @@ - technology: 技术 - mission: 任务 - gene: 基因 - + 如果无法确定类型,请使用 "concept" 作为默认值。 2. description: 实体描述(简要描述该实体在文本中的作用和特征) @@ -110,7 +110,7 @@ class ConsistencyEvaluator: """Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge. - + For entities with multiple source chunks, compares entity_type and description extracted from different chunks to detect semantic conflicts. """ @@ -161,7 +161,7 @@ async def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: # Evaluate entities concurrently semaphore = asyncio.Semaphore(self.max_concurrent) - + async def evaluate_entity(entity_info): async with semaphore: return await self._evaluate_entity_consistency(entity_info) @@ -172,12 +172,12 @@ async def evaluate_entity(entity_info): # Aggregate results conflicts = [] conflict_entities = set() - + for i, result in enumerate(results): if isinstance(result, Exception): logger.error(f"Failed to evaluate entity {entities_with_multiple_sources[i][0]}: {result}") continue - + entity_id, entity_conflicts = result if entity_conflicts: conflicts.extend(entity_conflicts) @@ -208,7 +208,7 @@ async def _evaluate_entity_consistency( self, entity_info: tuple ) -> tuple[str, List[Dict]]: """Evaluate consistency for a single entity.""" - entity_id, node_data, source_ids = entity_info + entity_id, _node_data, source_ids = entity_info # Clean entity_id for display clean_entity_id = self._clean_entity_id(entity_id) conflicts = [] @@ -287,14 +287,14 @@ async def _extract_entity_from_chunk( try: # Clean entity_id: remove surrounding quotes if present clean_entity_id = self._clean_entity_id(entity_id) - + prompt = ENTITY_EXTRACTION_PROMPT.format( entity_name=clean_entity_id, chunk_content=chunk.content[:2000] if chunk.content else "" # Limit content length ) - + response = await self.llm_client.generate_answer(prompt) - + # Try to parse JSON response try: extraction = json.loads(response) @@ -306,7 +306,7 @@ async def _extract_entity_from_chunk( else: logger.warning(f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}") return {} - + # Normalize entity_type to lowercase and validate entity_type = extraction.get("entity_type", "").lower().strip() # Valid preset types @@ -323,7 +323,7 @@ async def _extract_entity_from_chunk( f"defaulting to 'concept'" ) entity_type = "concept" - + return { "entity_type": entity_type, "description": extraction.get("description", ""), @@ -339,19 +339,19 @@ async def _check_entity_type_consistency( if len(set(type_extractions.values())) <= 1: # All types are the same, no conflict return {"has_conflict": False} - + try: - type_list = [f"Chunk {chunk_id}: {entity_type}" - for chunk_id, entity_type in type_extractions.items() + type_list = [f"Chunk {chunk_id}: {entity_type}" + for chunk_id, entity_type in type_extractions.items() if entity_type] - + prompt = ENTITY_TYPE_CONFLICT_PROMPT.format( entity_name=entity_id, type_extractions="\n".join(type_list) ) - + response = await self.llm_client.generate_answer(prompt) - + # Parse JSON response try: result = json.loads(response) @@ -362,7 +362,7 @@ async def _check_entity_type_consistency( else: logger.warning(f"Failed to parse conflict detection response for {entity_id}") return {"has_conflict": False} - + return result except Exception as e: logger.error(f"Error checking type consistency for {entity_id}: {e}") @@ -376,22 +376,22 @@ async def _check_entity_description_consistency( valid_descriptions = {k: v for k, v in descriptions.items() if v} if len(valid_descriptions) < 2: return {"has_conflict": False} - + if len(set(valid_descriptions.values())) <= 1: # All descriptions are the same, no conflict return {"has_conflict": False} - + try: - desc_list = [f"Chunk {chunk_id}: {description}" + desc_list = [f"Chunk {chunk_id}: {description}" for chunk_id, description in valid_descriptions.items()] - + prompt = ENTITY_DESCRIPTION_CONFLICT_PROMPT.format( entity_name=entity_id, descriptions="\n".join(desc_list) ) - + response = await self.llm_client.generate_answer(prompt) - + # Parse JSON response try: result = json.loads(response) @@ -402,7 +402,7 @@ async def _check_entity_description_consistency( else: logger.warning(f"Failed to parse conflict detection response for {entity_id}") return {"has_conflict": False} - + return result except Exception as e: logger.error(f"Error checking description consistency for {entity_id}: {e}") @@ -414,20 +414,20 @@ async def _check_relation_consistency( """Check relation consistency using LLM.""" if len(set(relation_extractions.values())) <= 1: return {"has_conflict": False} - + try: - rel_list = [f"Chunk {chunk_id}: {relation}" - for chunk_id, relation in relation_extractions.items() + rel_list = [f"Chunk {chunk_id}: {relation}" + for chunk_id, relation in relation_extractions.items() if relation] - + prompt = RELATION_CONFLICT_PROMPT.format( source_entity=src_id, target_entity=dst_id, relation_descriptions="\n".join(rel_list) ) - + response = await self.llm_client.generate_answer(prompt) - + # Parse JSON response try: result = json.loads(response) @@ -438,7 +438,7 @@ async def _check_relation_consistency( else: logger.warning(f"Failed to parse relation conflict response for {src_id}->{dst_id}") return {"has_conflict": False} - + return result except Exception as e: logger.error(f"Error checking relation consistency for {src_id}->{dst_id}: {e}") diff --git a/graphgen/models/evaluator/kg_quality_evaluator.py b/graphgen/models/evaluator/kg_quality_evaluator.py index d8df9095..019fbf68 100644 --- a/graphgen/models/evaluator/kg_quality_evaluator.py +++ b/graphgen/models/evaluator/kg_quality_evaluator.py @@ -7,8 +7,8 @@ 3. robustness assessment (noise ratio, connectivity, degree distribution). """ -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from typing import Any, Dict, Optional from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper from graphgen.models.evaluator.kg import ( diff --git a/graphgen/operators/evaluate_kg/evaluate_kg.py b/graphgen/operators/evaluate_kg/evaluate_kg.py index 4be8e91b..b0621d4c 100644 --- a/graphgen/operators/evaluate_kg/evaluate_kg.py +++ b/graphgen/operators/evaluate_kg/evaluate_kg.py @@ -35,8 +35,8 @@ def _print_accuracy_summary(acc): accuracy = e.get("accuracy", {}) completeness = e.get("completeness", {}) precision = e.get("precision", {}) - - print(f" Entity Extraction Quality:") + + print(" Entity Extraction Quality:") print(f" Overall Score: {overall.get('mean', 0):.3f} (mean), " f"{overall.get('median', 0):.3f} (median)") print(f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " @@ -46,15 +46,15 @@ def _print_accuracy_summary(acc): print(f" Precision: {precision.get('mean', 0):.3f} (mean), " f"{precision.get('median', 0):.3f} (median)") print(f" Total Chunks Evaluated: {e.get('total_chunks', 0)}") - + if "relation_accuracy" in acc: r = acc["relation_accuracy"] overall = r.get("overall_score", {}) accuracy = r.get("accuracy", {}) completeness = r.get("completeness", {}) precision = r.get("precision", {}) - - print(f" Relation Extraction Quality:") + + print(" Relation Extraction Quality:") print(f" Overall Score: {overall.get('mean', 0):.3f} (mean), " f"{overall.get('median', 0):.3f} (median)") print(f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " @@ -96,23 +96,23 @@ def _print_structure_summary(struct): print("\n[Structural Robustness]") print(f" Total Nodes: {struct.get('total_nodes', 0)}") print(f" Total Edges: {struct.get('total_edges', 0)}") - + thresholds = struct.get("thresholds", {}) - + # Noise Ratio noise_check = thresholds.get("noise_ratio", {}) noise_threshold = noise_check.get("threshold", "N/A") noise_pass = noise_check.get("pass", False) print(f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} " f"({'✓' if noise_pass else '✗'} < {noise_threshold})") - + # Largest CC Ratio lcc_check = thresholds.get("largest_cc_ratio", {}) lcc_threshold = lcc_check.get("threshold", "N/A") lcc_pass = lcc_check.get("pass", False) print(f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} " f"({'✓' if lcc_pass else '✗'} > {lcc_threshold})") - + # Avg Degree avg_degree_check = thresholds.get("avg_degree", {}) avg_degree_threshold = avg_degree_check.get("threshold", "N/A") @@ -124,7 +124,7 @@ def _print_structure_summary(struct): threshold_str = str(avg_degree_threshold) print(f" Avg Degree: {struct.get('avg_degree', 0):.2f} " f"({'✓' if avg_degree_pass else '✗'} {threshold_str})") - + # Power Law R² if struct.get('powerlaw_r2') is not None: powerlaw_check = thresholds.get("powerlaw_r2", {}) diff --git a/webui/utils/count_tokens.py b/webui/utils/count_tokens.py index 82b5522c..3016ac5c 100644 --- a/webui/utils/count_tokens.py +++ b/webui/utils/count_tokens.py @@ -7,10 +7,12 @@ # pylint: disable=wrong-import-position root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(root_dir) -from graphgen.models import Tokenizer def count_tokens(file, tokenizer_name, data_frame): + # Lazy import to avoid circular dependency + from graphgen.models import Tokenizer # pylint: disable=import-outside-toplevel + if not file or not os.path.exists(file): return data_frame From 42693dfa467f981a0d467879dd55b505447bf6ba Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 24 Dec 2025 15:57:20 +0800 Subject: [PATCH 08/29] wip: refactor evaluator structure --- examples/evaluate/evaluate.sh | 3 - examples/evaluate/evaluate_kg/evaluate_kg.sh | 2 + .../evaluate_kg/kg_evaluation_config.yaml | 41 ++++++++ examples/evaluate/evaluate_qa/evaluate.sh | 2 + .../evaluate_qa/qa_evaluation_config.yaml | 90 ++++++++++++++++++ examples/evaluate_kg/evaluate_kg.sh | 5 - .../{evaluate_kg => evaluate}/evaluate_kg.py | 95 ++++++++++++------- .../evaluate/{evaluate.py => evaluate_qa.py} | 0 .../operators/evaluate/evaluate_service.py | 22 +++++ graphgen/operators/evaluate_kg/__init__.py | 0 10 files changed, 220 insertions(+), 40 deletions(-) delete mode 100644 examples/evaluate/evaluate.sh create mode 100644 examples/evaluate/evaluate_kg/evaluate_kg.sh create mode 100644 examples/evaluate/evaluate_kg/kg_evaluation_config.yaml create mode 100644 examples/evaluate/evaluate_qa/evaluate.sh create mode 100644 examples/evaluate/evaluate_qa/qa_evaluation_config.yaml delete mode 100644 examples/evaluate_kg/evaluate_kg.sh rename graphgen/operators/{evaluate_kg => evaluate}/evaluate_kg.py (75%) rename graphgen/operators/evaluate/{evaluate.py => evaluate_qa.py} (100%) create mode 100644 graphgen/operators/evaluate/evaluate_service.py delete mode 100644 graphgen/operators/evaluate_kg/__init__.py diff --git a/examples/evaluate/evaluate.sh b/examples/evaluate/evaluate.sh deleted file mode 100644 index 2b352669..00000000 --- a/examples/evaluate/evaluate.sh +++ /dev/null @@ -1,3 +0,0 @@ -python3 -m graphgen.evaluate --folder cache/data \ - --reward "OpenAssistant/reward-model-deberta-v3-large-v2,BAAI/IndustryCorpus2_DataRater" \ - --uni MingZhong/unieval-sum \ diff --git a/examples/evaluate/evaluate_kg/evaluate_kg.sh b/examples/evaluate/evaluate_kg/evaluate_kg.sh new file mode 100644 index 00000000..ac40b0f6 --- /dev/null +++ b/examples/evaluate/evaluate_kg/evaluate_kg.sh @@ -0,0 +1,2 @@ +python3 -m graphgen.run \ +--config_file examples/evaluate/evaluate_kg/evaluate_kg_config.yaml \ No newline at end of file diff --git a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml new file mode 100644 index 00000000..334c2a85 --- /dev/null +++ b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml @@ -0,0 +1,41 @@ +global_params: + working_dir: cache + graph_backend: kuzu # graph database backend, support: kuzu, networkx + kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/extract_demo.txt + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 20480 # larger chunk size for better context + chunk_overlap: 2000 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk_documents + execution_params: + replicas: 1 + batch_size: 128 + + - id: evaluate + op_name: evaluate + type: aggregate + dependencies: + - build_kg + params: + metrics: diff --git a/examples/evaluate/evaluate_qa/evaluate.sh b/examples/evaluate/evaluate_qa/evaluate.sh new file mode 100644 index 00000000..8c637d1f --- /dev/null +++ b/examples/evaluate/evaluate_qa/evaluate.sh @@ -0,0 +1,2 @@ +python3 -m graphgen.run \ +--config_file examples/evaluate/evaluate_qa/evaluate_qa_config.yaml \ No newline at end of file diff --git a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml new file mode 100644 index 00000000..b62e9ee5 --- /dev/null +++ b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml @@ -0,0 +1,90 @@ +global_params: + working_dir: cache + graph_backend: kuzu # graph database backend, support: kuzu, networkx + kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv + +nodes: + - id: read_files # id is unique in the pipeline, and can be referenced by other steps + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples + + - id: chunk_documents + op_name: chunk + type: map_batch + dependencies: + - read_files + execution_params: + replicas: 4 + params: + chunk_size: 1024 # chunk size for text splitting + chunk_overlap: 100 # chunk overlap for text splitting + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk_documents + execution_params: + replicas: 1 + batch_size: 128 + + - id: quiz + op_name: quiz + type: aggregate + dependencies: + - build_kg + execution_params: + replicas: 1 + batch_size: 128 + params: + quiz_samples: 2 # number of quiz samples to generate + concurrency_limit: 200 + + - id: judge + op_name: judge + type: map_batch + dependencies: + - quiz + execution_params: + replicas: 1 + batch_size: 128 + + - id: partition + op_name: partition + type: aggregate + dependencies: + - judge + params: + method: ece # ece is a custom partition method based on comprehension loss + method_params: + max_units_per_community: 20 # max nodes and edges per community + min_units_per_community: 5 # min nodes and edges per community + max_tokens_per_community: 10240 # max tokens per community + unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: aggregated # atomic, aggregated, multi_hop, cot, vqa + data_format: ChatML # Alpaca, Sharegpt, ChatML + + - id: evaluate + op_name: evaluate + type: map_batch + dependencies: + - generate + execution_params: + replicas: 1 + batch_size: 128 + params: + metrics: diff --git a/examples/evaluate_kg/evaluate_kg.sh b/examples/evaluate_kg/evaluate_kg.sh deleted file mode 100644 index cda034bc..00000000 --- a/examples/evaluate_kg/evaluate_kg.sh +++ /dev/null @@ -1,5 +0,0 @@ -python3 -m graphgen.operators.evaluate_kg.evaluate_kg \ - --working_dir cache \ - --graph_backend kuzu \ - --kv_backend rocksdb \ - --max_concurrent 10 diff --git a/graphgen/operators/evaluate_kg/evaluate_kg.py b/graphgen/operators/evaluate/evaluate_kg.py similarity index 75% rename from graphgen/operators/evaluate_kg/evaluate_kg.py rename to graphgen/operators/evaluate/evaluate_kg.py index b0621d4c..4d3a62c8 100644 --- a/graphgen/operators/evaluate_kg/evaluate_kg.py +++ b/graphgen/operators/evaluate/evaluate_kg.py @@ -1,6 +1,7 @@ import argparse import json from pathlib import Path + from dotenv import load_dotenv from graphgen.models import KGQualityEvaluator @@ -37,14 +38,22 @@ def _print_accuracy_summary(acc): precision = e.get("precision", {}) print(" Entity Extraction Quality:") - print(f" Overall Score: {overall.get('mean', 0):.3f} (mean), " - f"{overall.get('median', 0):.3f} (median)") - print(f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " - f"{accuracy.get('median', 0):.3f} (median)") - print(f" Completeness: {completeness.get('mean', 0):.3f} (mean), " - f"{completeness.get('median', 0):.3f} (median)") - print(f" Precision: {precision.get('mean', 0):.3f} (mean), " - f"{precision.get('median', 0):.3f} (median)") + print( + f" Overall Score: {overall.get('mean', 0):.3f} (mean), " + f"{overall.get('median', 0):.3f} (median)" + ) + print( + f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " + f"{accuracy.get('median', 0):.3f} (median)" + ) + print( + f" Completeness: {completeness.get('mean', 0):.3f} (mean), " + f"{completeness.get('median', 0):.3f} (median)" + ) + print( + f" Precision: {precision.get('mean', 0):.3f} (mean), " + f"{precision.get('median', 0):.3f} (median)" + ) print(f" Total Chunks Evaluated: {e.get('total_chunks', 0)}") if "relation_accuracy" in acc: @@ -55,14 +64,22 @@ def _print_accuracy_summary(acc): precision = r.get("precision", {}) print(" Relation Extraction Quality:") - print(f" Overall Score: {overall.get('mean', 0):.3f} (mean), " - f"{overall.get('median', 0):.3f} (median)") - print(f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " - f"{accuracy.get('median', 0):.3f} (median)") - print(f" Completeness: {completeness.get('mean', 0):.3f} (mean), " - f"{completeness.get('median', 0):.3f} (median)") - print(f" Precision: {precision.get('mean', 0):.3f} (mean), " - f"{precision.get('median', 0):.3f} (median)") + print( + f" Overall Score: {overall.get('mean', 0):.3f} (mean), " + f"{overall.get('median', 0):.3f} (median)" + ) + print( + f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " + f"{accuracy.get('median', 0):.3f} (median)" + ) + print( + f" Completeness: {completeness.get('mean', 0):.3f} (mean), " + f"{completeness.get('median', 0):.3f} (median)" + ) + print( + f" Precision: {precision.get('mean', 0):.3f} (mean), " + f"{precision.get('median', 0):.3f} (median)" + ) print(f" Total Chunks Evaluated: {r.get('total_chunks', 0)}") else: print(f"\n[Accuracy] Error: {acc['error']}") @@ -73,19 +90,25 @@ def _print_consistency_summary(cons): if "error" not in cons: print("\n[Consistency]") print(f" Conflict Rate: {cons.get('conflict_rate', 0):.3f}") - print(f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / " - f"{cons.get('total_entities', 0)}") - entities_checked = cons.get('entities_checked', 0) + print( + f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / " + f"{cons.get('total_entities', 0)}" + ) + entities_checked = cons.get("entities_checked", 0) if entities_checked > 0: - print(f" Entities Checked: {entities_checked} (entities with multiple sources)") - conflicts = cons.get('conflicts', []) + print( + f" Entities Checked: {entities_checked} (entities with multiple sources)" + ) + conflicts = cons.get("conflicts", []) if conflicts: print(f" Total Conflicts Found: {len(conflicts)}") # Show sample conflicts sample_conflicts = conflicts[:3] for conflict in sample_conflicts: - print(f" - {conflict.get('entity_id', 'N/A')}: {conflict.get('conflict_type', 'N/A')} " - f"(severity: {conflict.get('conflict_severity', 0):.2f})") + print( + f" - {conflict.get('entity_id', 'N/A')}: {conflict.get('conflict_type', 'N/A')} " + f"(severity: {conflict.get('conflict_severity', 0):.2f})" + ) else: print(f"\n[Consistency] Error: {cons['error']}") @@ -103,15 +126,19 @@ def _print_structure_summary(struct): noise_check = thresholds.get("noise_ratio", {}) noise_threshold = noise_check.get("threshold", "N/A") noise_pass = noise_check.get("pass", False) - print(f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} " - f"({'✓' if noise_pass else '✗'} < {noise_threshold})") + print( + f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} " + f"({'✓' if noise_pass else '✗'} < {noise_threshold})" + ) # Largest CC Ratio lcc_check = thresholds.get("largest_cc_ratio", {}) lcc_threshold = lcc_check.get("threshold", "N/A") lcc_pass = lcc_check.get("pass", False) - print(f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} " - f"({'✓' if lcc_pass else '✗'} > {lcc_threshold})") + print( + f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} " + f"({'✓' if lcc_pass else '✗'} > {lcc_threshold})" + ) # Avg Degree avg_degree_check = thresholds.get("avg_degree", {}) @@ -122,16 +149,20 @@ def _print_structure_summary(struct): threshold_str = f"{avg_degree_threshold[0]}-{avg_degree_threshold[1]}" else: threshold_str = str(avg_degree_threshold) - print(f" Avg Degree: {struct.get('avg_degree', 0):.2f} " - f"({'✓' if avg_degree_pass else '✗'} {threshold_str})") + print( + f" Avg Degree: {struct.get('avg_degree', 0):.2f} " + f"({'✓' if avg_degree_pass else '✗'} {threshold_str})" + ) # Power Law R² - if struct.get('powerlaw_r2') is not None: + if struct.get("powerlaw_r2") is not None: powerlaw_check = thresholds.get("powerlaw_r2", {}) powerlaw_threshold = powerlaw_check.get("threshold", "N/A") powerlaw_pass = powerlaw_check.get("pass", False) - print(f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} " - f"({'✓' if powerlaw_pass else '✗'} > {powerlaw_threshold})") + print( + f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} " + f"({'✓' if powerlaw_pass else '✗'} > {powerlaw_threshold})" + ) else: print(f"\n[Structural Robustness] Error: {struct['error']}") diff --git a/graphgen/operators/evaluate/evaluate.py b/graphgen/operators/evaluate/evaluate_qa.py similarity index 100% rename from graphgen/operators/evaluate/evaluate.py rename to graphgen/operators/evaluate/evaluate_qa.py diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py new file mode 100644 index 00000000..385a7184 --- /dev/null +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -0,0 +1,22 @@ +import pandas as pd + +from graphgen.bases import BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm + + +class EvaluateService(BaseOperator): + """ + 1. KG Quality Evaluation + 2. QA Quality Evaluation + """ + + def __init__(self, working_dir: str = "cache"): + super().__init__(working_dir=working_dir, op_name="evaluate_service") + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") + return pd.DataFrame(self.evaluate(items)) + + def evaluate(self, items: list[dict]) -> list[dict]: + pass diff --git a/graphgen/operators/evaluate_kg/__init__.py b/graphgen/operators/evaluate_kg/__init__.py deleted file mode 100644 index e69de29b..00000000 From a2572468fb56314d65d2b5b38e7185244b83b8fd Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Wed, 24 Dec 2025 23:44:00 +0800 Subject: [PATCH 09/29] wip: add annotations --- .../evaluate/evaluate_kg/kg_evaluation_config.yaml | 1 + graphgen/operators/evaluate/evaluate_service.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml index 334c2a85..4ff65818 100644 --- a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml +++ b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml @@ -35,6 +35,7 @@ nodes: - id: evaluate op_name: evaluate type: aggregate + save_output: true dependencies: - build_kg params: diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index 385a7184..49d45703 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -10,13 +10,24 @@ class EvaluateService(BaseOperator): 2. QA Quality Evaluation """ - def __init__(self, working_dir: str = "cache"): + def __init__(self, working_dir: str = "cache", metrics: list[str] = None): + # optional 传入 graph super().__init__(working_dir=working_dir, op_name="evaluate_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.metrics = metrics or [] + + self.evaluators = { + "xxx": "xxxEvaluator" + } + + self.graph_storage = init_storage( + xx, xx, xx + ) def process(self, batch: pd.DataFrame) -> pd.DataFrame: items = batch.to_dict(orient="records") return pd.DataFrame(self.evaluate(items)) def evaluate(self, items: list[dict]) -> list[dict]: + # 用evaluators 评估 items pass From 41015a2fd3ac4552524abc961f7deecc39b811f6 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 25 Dec 2025 13:42:06 +0800 Subject: [PATCH 10/29] refactor: refactor proj structure & configs --- examples/evaluate/evaluate_kg/kg_evaluation_config.yaml | 3 +++ examples/evaluate/evaluate_qa/qa_evaluation_config.yaml | 6 ++++++ graphgen/models/evaluator/__init__.py | 7 ++----- .../models/evaluator/{ => kg}/kg_quality_evaluator.py | 8 ++------ graphgen/models/evaluator/qa/__init__.py | 0 graphgen/models/evaluator/{ => qa}/length_evaluator.py | 0 graphgen/models/evaluator/{ => qa}/mtld_evaluator.py | 0 graphgen/models/evaluator/{ => qa}/reward_evaluator.py | 0 graphgen/models/evaluator/{ => qa}/uni_evaluator.py | 0 9 files changed, 13 insertions(+), 11 deletions(-) rename graphgen/models/evaluator/{ => kg}/kg_quality_evaluator.py (94%) create mode 100644 graphgen/models/evaluator/qa/__init__.py rename graphgen/models/evaluator/{ => qa}/length_evaluator.py (100%) rename graphgen/models/evaluator/{ => qa}/mtld_evaluator.py (100%) rename graphgen/models/evaluator/{ => qa}/reward_evaluator.py (100%) rename graphgen/models/evaluator/{ => qa}/uni_evaluator.py (100%) diff --git a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml index 4ff65818..57c6f307 100644 --- a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml +++ b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml @@ -40,3 +40,6 @@ nodes: - build_kg params: metrics: + - kg_accuracy + - kg_consistency + - kg_structure diff --git a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml index b62e9ee5..45e9d3a7 100644 --- a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml +++ b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml @@ -74,6 +74,7 @@ nodes: execution_params: replicas: 1 batch_size: 128 + save_output: true params: method: aggregated # atomic, aggregated, multi_hop, cot, vqa data_format: ChatML # Alpaca, Sharegpt, ChatML @@ -86,5 +87,10 @@ nodes: execution_params: replicas: 1 batch_size: 128 + save_output: true params: metrics: + - qa_length + - qa_mtld + - qa_reward_score + - qa_uni_score diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py index 5f2716fc..83a48aaa 100644 --- a/graphgen/models/evaluator/__init__.py +++ b/graphgen/models/evaluator/__init__.py @@ -1,5 +1,2 @@ -from .kg_quality_evaluator import KGQualityEvaluator -from .length_evaluator import LengthEvaluator -from .mtld_evaluator import MTLDEvaluator -from .reward_evaluator import RewardEvaluator -from .uni_evaluator import UniEvaluator +from graphgen.models.evaluator.kg.kg_quality_evaluator import KGQualityEvaluator +from graphgen.models.evaluator.qa.uni_evaluator import UniEvaluator diff --git a/graphgen/models/evaluator/kg_quality_evaluator.py b/graphgen/models/evaluator/kg/kg_quality_evaluator.py similarity index 94% rename from graphgen/models/evaluator/kg_quality_evaluator.py rename to graphgen/models/evaluator/kg/kg_quality_evaluator.py index 019fbf68..0a1e4e9a 100644 --- a/graphgen/models/evaluator/kg_quality_evaluator.py +++ b/graphgen/models/evaluator/kg/kg_quality_evaluator.py @@ -76,9 +76,7 @@ def evaluate_all(self) -> Dict[str, Any]: try: logger.info("Starting structural robustness evaluation...") - structure_evaluator = StructureEvaluator( - graph_storage=self.graph_storage - ) + structure_evaluator = StructureEvaluator(graph_storage=self.graph_storage) results["structure"] = structure_evaluator.evaluate() except Exception as e: logger.error(f"Structural evaluation failed: {e}") @@ -105,7 +103,5 @@ def evaluate_consistency(self) -> Dict[str, Any]: return consistency_evaluator.evaluate() def evaluate_structure(self) -> Dict[str, Any]: - structure_evaluator = StructureEvaluator( - graph_storage=self.graph_storage - ) + structure_evaluator = StructureEvaluator(graph_storage=self.graph_storage) return structure_evaluator.evaluate() diff --git a/graphgen/models/evaluator/qa/__init__.py b/graphgen/models/evaluator/qa/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/evaluator/length_evaluator.py b/graphgen/models/evaluator/qa/length_evaluator.py similarity index 100% rename from graphgen/models/evaluator/length_evaluator.py rename to graphgen/models/evaluator/qa/length_evaluator.py diff --git a/graphgen/models/evaluator/mtld_evaluator.py b/graphgen/models/evaluator/qa/mtld_evaluator.py similarity index 100% rename from graphgen/models/evaluator/mtld_evaluator.py rename to graphgen/models/evaluator/qa/mtld_evaluator.py diff --git a/graphgen/models/evaluator/reward_evaluator.py b/graphgen/models/evaluator/qa/reward_evaluator.py similarity index 100% rename from graphgen/models/evaluator/reward_evaluator.py rename to graphgen/models/evaluator/qa/reward_evaluator.py diff --git a/graphgen/models/evaluator/uni_evaluator.py b/graphgen/models/evaluator/qa/uni_evaluator.py similarity index 100% rename from graphgen/models/evaluator/uni_evaluator.py rename to graphgen/models/evaluator/qa/uni_evaluator.py From 978b76c80f97f651c24d9878893b057e1590db24 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 25 Dec 2025 16:47:34 +0800 Subject: [PATCH 11/29] wip: split prompts --- graphgen/models/evaluator/kg/__init__.py | 9 + .../models/evaluator/kg/accuracy_evaluator.py | 212 +++++++-------- .../evaluator/kg/consistency_evaluator.py | 243 +++++++----------- .../evaluator/kg/kg_quality_evaluator.py | 107 -------- .../operators/evaluate/evaluate_service.py | 7 +- graphgen/templates/__init__.py | 1 + graphgen/templates/evaluation/__init__.py | 1 + graphgen/templates/evaluation/kg/__init__.py | 1 + .../evaluation/kg/accuracy_evaluation.py | 156 +++++++++++ .../evaluation/kg/consistency_evaluation.py | 97 +++++++ 10 files changed, 451 insertions(+), 383 deletions(-) delete mode 100644 graphgen/models/evaluator/kg/kg_quality_evaluator.py create mode 100644 graphgen/templates/evaluation/__init__.py create mode 100644 graphgen/templates/evaluation/kg/__init__.py create mode 100644 graphgen/templates/evaluation/kg/accuracy_evaluation.py create mode 100644 graphgen/templates/evaluation/kg/consistency_evaluation.py diff --git a/graphgen/models/evaluator/kg/__init__.py b/graphgen/models/evaluator/kg/__init__.py index 4a7f794b..375cbc50 100644 --- a/graphgen/models/evaluator/kg/__init__.py +++ b/graphgen/models/evaluator/kg/__init__.py @@ -1,3 +1,12 @@ +""" +Knowledge Graph Quality Evaluator + +This module provides comprehensive quality evaluation for knowledge graphs, +1. accuracy assessment (entity/relation/triple validation), +2. consistency assessment (attribute conflict detection), and structural +3. robustness assessment (noise ratio, connectivity, degree distribution). +""" + from .accuracy_evaluator import AccuracyEvaluator from .consistency_evaluator import ConsistencyEvaluator from .structure_evaluator import StructureEvaluator diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py index afdfc175..f9d2e405 100644 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -5,81 +5,8 @@ from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper from graphgen.bases.datatypes import Chunk -from graphgen.utils import create_event_loop, logger - - -# LLM-as-a-Judge evaluation prompts -ENTITY_EVALUATION_PROMPT = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的实体列表,评估实体提取的质量。 - -评估维度: -1. ACCURACY (准确性, 权重: 40%): 提取的实体是否正确,是否有误提取或错误识别 -2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要实体 -3. PRECISION (精确性, 权重: 20%): 提取的实体是否精确,命名是否准确 - -评分标准(每个维度 0-1 分): -- EXCELLENT (0.8-1.0): 高质量提取 -- GOOD (0.6-0.79): 良好质量,有少量问题 -- ACCEPTABLE (0.4-0.59): 可接受,有明显问题 -- POOR (0.0-0.39): 质量差,需要改进 - -综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision - -请评估以下内容: - -原始文本块: -{chunk_content} - -提取的实体列表: -{extracted_entities} - -请以 JSON 格式返回评估结果: -{{ - "accuracy": <0-1之间的浮点数>, - "completeness": <0-1之间的浮点数>, - "precision": <0-1之间的浮点数>, - "overall_score": <综合评分>, - "accuracy_reasoning": "<准确性评估理由>", - "completeness_reasoning": "<完整性评估理由,包括遗漏的重要实体>", - "precision_reasoning": "<精确性评估理由>", - "issues": ["<发现的问题列表>"] -}} -""" - -RELATION_EVALUATION_PROMPT = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的关系列表,评估关系抽取的质量。 - -评估维度: -1. ACCURACY (准确性, 权重: 40%): 提取的关系是否正确,关系描述是否准确 -2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要关系 -3. PRECISION (精确性, 权重: 20%): 关系描述是否精确,是否过于宽泛 - -评分标准(每个维度 0-1 分): -- EXCELLENT (0.8-1.0): 高质量提取 -- GOOD (0.6-0.79): 良好质量,有少量问题 -- ACCEPTABLE (0.4-0.59): 可接受,有明显问题 -- POOR (0.0-0.39): 质量差,需要改进 - -综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision - -请评估以下内容: - -原始文本块: -{chunk_content} - -提取的关系列表: -{extracted_relations} - -请以 JSON 格式返回评估结果: -{{ - "accuracy": <0-1之间的浮点数>, - "completeness": <0-1之间的浮点数>, - "precision": <0-1之间的浮点数>, - "overall_score": <综合评分>, - "accuracy_reasoning": "<准确性评估理由>", - "completeness_reasoning": "<完整性评估理由,包括遗漏的重要关系>", - "precision_reasoning": "<精确性评估理由>", - "issues": ["<发现的问题列表>"] -}} -""" +from graphgen.templates import ACCURACY_EVALUATION_PROMPT +from graphgen.utils import create_event_loop, detect_main_language, logger class AccuracyEvaluator: @@ -95,12 +22,10 @@ def __init__( graph_storage: BaseGraphStorage, chunk_storage: BaseKVStorage, llm_client: BaseLLMWrapper, - max_concurrent: int = 10, ): self.graph_storage = graph_storage self.chunk_storage = chunk_storage self.llm_client = llm_client - self.max_concurrent = max_concurrent def evaluate(self) -> Dict[str, Any]: """Evaluate entity and relation extraction quality using LLM-as-a-Judge. @@ -124,7 +49,9 @@ def evaluate(self) -> Dict[str, Any]: ) # 3. Aggregate results - return self._aggregate_evaluation_results(entity_evaluations, relation_evaluations) + return self._aggregate_evaluation_results( + entity_evaluations, relation_evaluations + ) def _load_chunks_from_storage(self) -> List[Chunk]: """Load all chunks from chunk storage.""" @@ -152,11 +79,13 @@ def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]: source_ids = node_data.get("source_id", "").split("") # Check if this chunk_id is in the source_ids if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]: - entities.append({ - "entity_name": node_data.get("entity_name", node_id), - "entity_type": node_data.get("entity_type", ""), - "description": node_data.get("description", "") - }) + entities.append( + { + "entity_name": node_data.get("entity_name", node_id), + "entity_type": node_data.get("entity_type", ""), + "description": node_data.get("description", ""), + } + ) return entities @@ -173,11 +102,13 @@ def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]: if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]: src_node = self.graph_storage.get_node(src_id) or {} dst_node = self.graph_storage.get_node(dst_id) or {} - relations.append({ - "source_entity": src_node.get("entity_name", src_id), - "target_entity": dst_node.get("entity_name", dst_id), - "relationship_summary": edge_data.get("description", "") - }) + relations.append( + { + "source_entity": src_node.get("entity_name", src_id), + "target_entity": dst_node.get("entity_name", dst_id), + "relationship_summary": edge_data.get("description", ""), + } + ) return relations @@ -193,7 +124,9 @@ async def evaluate_chunk(chunk: Chunk): relations = self._get_extracted_relations_for_chunk(chunk.id) entity_eval = await self._evaluate_entity_extraction(chunk, entities) - relation_eval = await self._evaluate_relation_extraction(chunk, relations) + relation_eval = await self._evaluate_relation_extraction( + chunk, relations + ) return entity_eval, relation_eval @@ -221,7 +154,9 @@ async def _evaluate_entity_extraction( try: prompt = ENTITY_EVALUATION_PROMPT.format( chunk_content=chunk.content, - extracted_entities=json.dumps(extracted_entities, ensure_ascii=False, indent=2) + extracted_entities=json.dumps( + extracted_entities, ensure_ascii=False, indent=2 + ), ) response = await self.llm_client.generate_answer(prompt) @@ -231,11 +166,13 @@ async def _evaluate_entity_extraction( evaluation_result = json.loads(response) except json.JSONDecodeError: # Try to extract JSON from markdown code blocks or other formats - json_match = re.search(r'\{.*\}', response, re.DOTALL) + json_match = re.search(r"\{.*\}", response, re.DOTALL) if json_match: evaluation_result = json.loads(json_match.group(0)) else: - logger.warning(f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}") + logger.warning( + f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}" + ) # Return default evaluation evaluation_result = { "accuracy": 0.0, @@ -245,7 +182,7 @@ async def _evaluate_entity_extraction( "accuracy_reasoning": "Failed to parse LLM response", "completeness_reasoning": "", "precision_reasoning": "", - "issues": ["LLM response parsing failed"] + "issues": ["LLM response parsing failed"], } # Validate and calculate overall_score if not provided @@ -253,16 +190,22 @@ async def _evaluate_entity_extraction( accuracy = float(evaluation_result.get("accuracy", 0.0)) completeness = float(evaluation_result.get("completeness", 0.0)) precision = float(evaluation_result.get("precision", 0.0)) - evaluation_result["overall_score"] = 0.4 * accuracy + 0.4 * completeness + 0.2 * precision + evaluation_result["overall_score"] = ( + 0.4 * accuracy + 0.4 * completeness + 0.2 * precision + ) return { "chunk_id": chunk.id, - "chunk_content": chunk.content[:200] if chunk.content else "", # First 200 chars for debugging + "chunk_content": chunk.content[:200] + if chunk.content + else "", # First 200 chars for debugging "extracted_entities_count": len(extracted_entities), - **evaluation_result + **evaluation_result, } except Exception as e: - logger.error(f"Error evaluating entity extraction for chunk {chunk.id}: {e}") + logger.error( + f"Error evaluating entity extraction for chunk {chunk.id}: {e}" + ) return { "chunk_id": chunk.id, "chunk_content": chunk.content[:200] if chunk.content else "", @@ -274,7 +217,7 @@ async def _evaluate_entity_extraction( "accuracy_reasoning": f"Evaluation failed: {str(e)}", "completeness_reasoning": "", "precision_reasoning": "", - "issues": [f"Evaluation error: {str(e)}"] + "issues": [f"Evaluation error: {str(e)}"], } async def _evaluate_relation_extraction( @@ -284,7 +227,9 @@ async def _evaluate_relation_extraction( try: prompt = RELATION_EVALUATION_PROMPT.format( chunk_content=chunk.content, - extracted_relations=json.dumps(extracted_relations, ensure_ascii=False, indent=2) + extracted_relations=json.dumps( + extracted_relations, ensure_ascii=False, indent=2 + ), ) response = await self.llm_client.generate_answer(prompt) @@ -294,11 +239,13 @@ async def _evaluate_relation_extraction( evaluation_result = json.loads(response) except json.JSONDecodeError: # Try to extract JSON from markdown code blocks or other formats - json_match = re.search(r'\{.*\}', response, re.DOTALL) + json_match = re.search(r"\{.*\}", response, re.DOTALL) if json_match: evaluation_result = json.loads(json_match.group(0)) else: - logger.warning(f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}") + logger.warning( + f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}" + ) # Return default evaluation evaluation_result = { "accuracy": 0.0, @@ -308,7 +255,7 @@ async def _evaluate_relation_extraction( "accuracy_reasoning": "Failed to parse LLM response", "completeness_reasoning": "", "precision_reasoning": "", - "issues": ["LLM response parsing failed"] + "issues": ["LLM response parsing failed"], } # Validate and calculate overall_score if not provided @@ -316,16 +263,20 @@ async def _evaluate_relation_extraction( accuracy = float(evaluation_result.get("accuracy", 0.0)) completeness = float(evaluation_result.get("completeness", 0.0)) precision = float(evaluation_result.get("precision", 0.0)) - evaluation_result["overall_score"] = 0.4 * accuracy + 0.4 * completeness + 0.2 * precision + evaluation_result["overall_score"] = ( + 0.4 * accuracy + 0.4 * completeness + 0.2 * precision + ) return { "chunk_id": chunk.id, "chunk_content": chunk.content[:200] if chunk.content else "", "extracted_relations_count": len(extracted_relations), - **evaluation_result + **evaluation_result, } except Exception as e: - logger.error(f"Error evaluating relation extraction for chunk {chunk.id}: {e}") + logger.error( + f"Error evaluating relation extraction for chunk {chunk.id}: {e}" + ) return { "chunk_id": chunk.id, "chunk_content": chunk.content[:200] if chunk.content else "", @@ -337,47 +288,58 @@ async def _evaluate_relation_extraction( "accuracy_reasoning": f"Evaluation failed: {str(e)}", "completeness_reasoning": "", "precision_reasoning": "", - "issues": [f"Evaluation error: {str(e)}"] + "issues": [f"Evaluation error: {str(e)}"], } def _aggregate_evaluation_results( self, entity_evaluations: List[Dict], relation_evaluations: List[Dict] ) -> Dict[str, Any]: """Aggregate evaluation results from all chunks.""" + def calculate_stats(scores: List[float]) -> Dict[str, float]: if not scores: - return { - "mean": 0.0, - "median": 0.0, - "min": 0.0, - "max": 0.0, - "std": 0.0 - } + return {"mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0} sorted_scores = sorted(scores) n = len(scores) mean = sum(scores) / n - median = sorted_scores[n // 2] if n % 2 == 1 else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2 + median = ( + sorted_scores[n // 2] + if n % 2 == 1 + else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2 + ) variance = sum((x - mean) ** 2 for x in scores) / n - std = variance ** 0.5 + std = variance**0.5 return { "mean": mean, "median": median, "min": min(scores), "max": max(scores), - "std": std + "std": std, } # Extract scores - entity_overall_scores = [e.get("overall_score", 0.0) for e in entity_evaluations] + entity_overall_scores = [ + e.get("overall_score", 0.0) for e in entity_evaluations + ] entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations] - entity_completeness_scores = [e.get("completeness", 0.0) for e in entity_evaluations] + entity_completeness_scores = [ + e.get("completeness", 0.0) for e in entity_evaluations + ] entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations] - relation_overall_scores = [r.get("overall_score", 0.0) for r in relation_evaluations] - relation_accuracy_scores = [r.get("accuracy", 0.0) for r in relation_evaluations] - relation_completeness_scores = [r.get("completeness", 0.0) for r in relation_evaluations] - relation_precision_scores = [r.get("precision", 0.0) for r in relation_evaluations] + relation_overall_scores = [ + r.get("overall_score", 0.0) for r in relation_evaluations + ] + relation_accuracy_scores = [ + r.get("accuracy", 0.0) for r in relation_evaluations + ] + relation_completeness_scores = [ + r.get("completeness", 0.0) for r in relation_evaluations + ] + relation_precision_scores = [ + r.get("precision", 0.0) for r in relation_evaluations + ] return { "entity_accuracy": { @@ -386,7 +348,7 @@ def calculate_stats(scores: List[float]) -> Dict[str, float]: "completeness": calculate_stats(entity_completeness_scores), "precision": calculate_stats(entity_precision_scores), "total_chunks": len(entity_evaluations), - "detailed_results": entity_evaluations + "detailed_results": entity_evaluations, }, "relation_accuracy": { "overall_score": calculate_stats(relation_overall_scores), @@ -394,6 +356,6 @@ def calculate_stats(scores: List[float]) -> Dict[str, float]: "completeness": calculate_stats(relation_completeness_scores), "precision": calculate_stats(relation_precision_scores), "total_chunks": len(relation_evaluations), - "detailed_results": relation_evaluations - } + "detailed_results": relation_evaluations, + }, } diff --git a/graphgen/models/evaluator/kg/consistency_evaluator.py b/graphgen/models/evaluator/kg/consistency_evaluator.py index c67b6d32..a840abc6 100644 --- a/graphgen/models/evaluator/kg/consistency_evaluator.py +++ b/graphgen/models/evaluator/kg/consistency_evaluator.py @@ -8,106 +8,6 @@ from graphgen.utils import create_event_loop, logger -# LLM prompts for conflict detection -ENTITY_TYPE_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中被提取为不同的类型,是否存在语义冲突。 - -实体名称:{entity_name} - -在不同文本块中的类型提取结果: -{type_extractions} - -预设的实体类型列表(供参考): -concept, date, location, keyword, organization, person, event, work, nature, artificial, science, technology, mission, gene - -请判断这些类型是否存在语义冲突(即它们是否描述的是同一类事物,还是存在矛盾)。 -注意:如果类型只是同一概念的不同表述(如 concept 和 keyword),可能不算严重冲突。 - -请以 JSON 格式返回: -{{ - "has_conflict": , - "conflict_severity": <0-1之间的浮点数,0表示无冲突,1表示严重冲突>, - "conflict_reasoning": "<冲突判断的理由>", - "conflicting_types": ["<存在冲突的类型对>"], - "recommended_type": "<如果存在冲突,推荐的正确类型(必须是预设类型之一)>" -}} -""" - -ENTITY_DESCRIPTION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中的描述是否存在语义冲突。 - -实体名称:{entity_name} - -在不同文本块中的描述: -{descriptions} - -请判断这些描述是否存在语义冲突(即它们是否描述的是同一个实体,还是存在矛盾的信息)。 - -请以 JSON 格式返回: -{{ - "has_conflict": , - "conflict_severity": <0-1之间的浮点数>, - "conflict_reasoning": "<冲突判断的理由>", - "conflicting_descriptions": ["<存在冲突的描述对>"], - "conflict_details": "<具体的冲突内容>" -}} -""" - -RELATION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一对实体在不同文本块中的关系描述是否存在语义冲突。 - -实体对:{source_entity} -> {target_entity} - -在不同文本块中的关系描述: -{relation_descriptions} - -请判断这些关系描述是否存在语义冲突。 - -请以 JSON 格式返回: -{{ - "has_conflict": , - "conflict_severity": <0-1之间的浮点数>, - "conflict_reasoning": "<冲突判断的理由>", - "conflicting_relations": ["<存在冲突的关系描述对>"] -}} -""" - -ENTITY_EXTRACTION_PROMPT = """从以下文本块中提取指定实体的类型和描述。 - -**重要**:你只需要提取指定的实体,不要提取其他实体。 - -实体名称:{entity_name} - -文本块: -{chunk_content} - -请从文本块中找到并提取**仅此实体**(实体名称:{entity_name})的以下信息: - -1. entity_type: 实体类型,必须是以下预设类型之一(小写): - - concept: 概念 - - date: 日期 - - location: 地点 - - keyword: 关键词 - - organization: 组织 - - person: 人物 - - event: 事件 - - work: 作品/工作 - - nature: 自然 - - artificial: 人工 - - science: 科学 - - technology: 技术 - - mission: 任务 - - gene: 基因 - - 如果无法确定类型,请使用 "concept" 作为默认值。 - -2. description: 实体描述(简要描述该实体在文本中的作用和特征) - -请以 JSON 格式返回: -{{ - "entity_type": "<实体类型(必须是上述预设类型之一)>", - "description": "<实体描述>" -}} -""" - - class ConsistencyEvaluator: """Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge. @@ -149,7 +49,9 @@ async def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: entities_with_multiple_sources.append((node_id, node_data, source_ids)) if not entities_with_multiple_sources: - logger.info("No entities with multiple sources found, skipping consistency check") + logger.info( + "No entities with multiple sources found, skipping consistency check" + ) return { "conflict_rate": 0.0, "conflict_entities_count": 0, @@ -157,7 +59,9 @@ async def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: "conflicts": [], } - logger.info(f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources") + logger.info( + f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources" + ) # Evaluate entities concurrently semaphore = asyncio.Semaphore(self.max_concurrent) @@ -166,7 +70,10 @@ async def evaluate_entity(entity_info): async with semaphore: return await self._evaluate_entity_consistency(entity_info) - tasks = [evaluate_entity(entity_info) for entity_info in entities_with_multiple_sources] + tasks = [ + evaluate_entity(entity_info) + for entity_info in entities_with_multiple_sources + ] results = await asyncio.gather(*tasks, return_exceptions=True) # Aggregate results @@ -175,7 +82,9 @@ async def evaluate_entity(entity_info): for i, result in enumerate(results): if isinstance(result, Exception): - logger.error(f"Failed to evaluate entity {entities_with_multiple_sources[i][0]}: {result}") + logger.error( + f"Failed to evaluate entity {entities_with_multiple_sources[i][0]}: {result}" + ) continue entity_id, entity_conflicts = result @@ -199,8 +108,9 @@ async def evaluate_entity(entity_info): def _clean_entity_id(self, entity_id: str) -> str: """Clean entity ID by removing surrounding quotes.""" clean_id = entity_id.strip() - if (clean_id.startswith('"') and clean_id.endswith('"')) or \ - (clean_id.startswith("'") and clean_id.endswith("'")): + if (clean_id.startswith('"') and clean_id.endswith('"')) or ( + clean_id.startswith("'") and clean_id.endswith("'") + ): clean_id = clean_id[1:-1].strip() return clean_id @@ -237,14 +147,16 @@ async def _evaluate_entity_consistency( entity_id, type_extractions ) if type_conflict and type_conflict.get("has_conflict", False): - conflicts.append({ - "entity_id": clean_entity_id, - "conflict_type": "entity_type", - "conflict_severity": type_conflict.get("conflict_severity", 0.0), - "conflict_reasoning": type_conflict.get("conflict_reasoning", ""), - "conflicting_values": type_conflict.get("conflicting_types", []), - "recommended_value": type_conflict.get("recommended_type", ""), - }) + conflicts.append( + { + "entity_id": clean_entity_id, + "conflict_type": "entity_type", + "conflict_severity": type_conflict.get("conflict_severity", 0.0), + "conflict_reasoning": type_conflict.get("conflict_reasoning", ""), + "conflicting_values": type_conflict.get("conflicting_types", []), + "recommended_value": type_conflict.get("recommended_type", ""), + } + ) # Check entity description consistency descriptions = { @@ -255,14 +167,18 @@ async def _evaluate_entity_consistency( entity_id, descriptions ) if desc_conflict and desc_conflict.get("has_conflict", False): - conflicts.append({ - "entity_id": clean_entity_id, - "conflict_type": "description", - "conflict_severity": desc_conflict.get("conflict_severity", 0.0), - "conflict_reasoning": desc_conflict.get("conflict_reasoning", ""), - "conflicting_values": desc_conflict.get("conflicting_descriptions", []), - "conflict_details": desc_conflict.get("conflict_details", ""), - }) + conflicts.append( + { + "entity_id": clean_entity_id, + "conflict_type": "description", + "conflict_severity": desc_conflict.get("conflict_severity", 0.0), + "conflict_reasoning": desc_conflict.get("conflict_reasoning", ""), + "conflicting_values": desc_conflict.get( + "conflicting_descriptions", [] + ), + "conflict_details": desc_conflict.get("conflict_details", ""), + } + ) return entity_id, conflicts @@ -290,7 +206,9 @@ async def _extract_entity_from_chunk( prompt = ENTITY_EXTRACTION_PROMPT.format( entity_name=clean_entity_id, - chunk_content=chunk.content[:2000] if chunk.content else "" # Limit content length + chunk_content=chunk.content[:2000] + if chunk.content + else "", # Limit content length ) response = await self.llm_client.generate_answer(prompt) @@ -300,20 +218,33 @@ async def _extract_entity_from_chunk( extraction = json.loads(response) except json.JSONDecodeError: # Try to extract JSON from markdown code blocks - json_match = re.search(r'\{.*\}', response, re.DOTALL) + json_match = re.search(r"\{.*\}", response, re.DOTALL) if json_match: extraction = json.loads(json_match.group(0)) else: - logger.warning(f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}") + logger.warning( + f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}" + ) return {} # Normalize entity_type to lowercase and validate entity_type = extraction.get("entity_type", "").lower().strip() # Valid preset types valid_types = { - "concept", "date", "location", "keyword", "organization", - "person", "event", "work", "nature", "artificial", - "science", "technology", "mission", "gene" + "concept", + "date", + "location", + "keyword", + "organization", + "person", + "event", + "work", + "nature", + "artificial", + "science", + "technology", + "mission", + "gene", } # If entity_type is not in valid types, default to "concept" if entity_type not in valid_types: @@ -329,7 +260,9 @@ async def _extract_entity_from_chunk( "description": extraction.get("description", ""), } except Exception as e: - logger.error(f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}") + logger.error( + f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}" + ) return {} async def _check_entity_type_consistency( @@ -341,13 +274,14 @@ async def _check_entity_type_consistency( return {"has_conflict": False} try: - type_list = [f"Chunk {chunk_id}: {entity_type}" - for chunk_id, entity_type in type_extractions.items() - if entity_type] + type_list = [ + f"Chunk {chunk_id}: {entity_type}" + for chunk_id, entity_type in type_extractions.items() + if entity_type + ] prompt = ENTITY_TYPE_CONFLICT_PROMPT.format( - entity_name=entity_id, - type_extractions="\n".join(type_list) + entity_name=entity_id, type_extractions="\n".join(type_list) ) response = await self.llm_client.generate_answer(prompt) @@ -356,11 +290,13 @@ async def _check_entity_type_consistency( try: result = json.loads(response) except json.JSONDecodeError: - json_match = re.search(r'\{.*\}', response, re.DOTALL) + json_match = re.search(r"\{.*\}", response, re.DOTALL) if json_match: result = json.loads(json_match.group(0)) else: - logger.warning(f"Failed to parse conflict detection response for {entity_id}") + logger.warning( + f"Failed to parse conflict detection response for {entity_id}" + ) return {"has_conflict": False} return result @@ -382,12 +318,13 @@ async def _check_entity_description_consistency( return {"has_conflict": False} try: - desc_list = [f"Chunk {chunk_id}: {description}" - for chunk_id, description in valid_descriptions.items()] + desc_list = [ + f"Chunk {chunk_id}: {description}" + for chunk_id, description in valid_descriptions.items() + ] prompt = ENTITY_DESCRIPTION_CONFLICT_PROMPT.format( - entity_name=entity_id, - descriptions="\n".join(desc_list) + entity_name=entity_id, descriptions="\n".join(desc_list) ) response = await self.llm_client.generate_answer(prompt) @@ -396,11 +333,13 @@ async def _check_entity_description_consistency( try: result = json.loads(response) except json.JSONDecodeError: - json_match = re.search(r'\{.*\}', response, re.DOTALL) + json_match = re.search(r"\{.*\}", response, re.DOTALL) if json_match: result = json.loads(json_match.group(0)) else: - logger.warning(f"Failed to parse conflict detection response for {entity_id}") + logger.warning( + f"Failed to parse conflict detection response for {entity_id}" + ) return {"has_conflict": False} return result @@ -416,14 +355,16 @@ async def _check_relation_consistency( return {"has_conflict": False} try: - rel_list = [f"Chunk {chunk_id}: {relation}" - for chunk_id, relation in relation_extractions.items() - if relation] + rel_list = [ + f"Chunk {chunk_id}: {relation}" + for chunk_id, relation in relation_extractions.items() + if relation + ] prompt = RELATION_CONFLICT_PROMPT.format( source_entity=src_id, target_entity=dst_id, - relation_descriptions="\n".join(rel_list) + relation_descriptions="\n".join(rel_list), ) response = await self.llm_client.generate_answer(prompt) @@ -432,14 +373,18 @@ async def _check_relation_consistency( try: result = json.loads(response) except json.JSONDecodeError: - json_match = re.search(r'\{.*\}', response, re.DOTALL) + json_match = re.search(r"\{.*\}", response, re.DOTALL) if json_match: result = json.loads(json_match.group(0)) else: - logger.warning(f"Failed to parse relation conflict response for {src_id}->{dst_id}") + logger.warning( + f"Failed to parse relation conflict response for {src_id}->{dst_id}" + ) return {"has_conflict": False} return result except Exception as e: - logger.error(f"Error checking relation consistency for {src_id}->{dst_id}: {e}") + logger.error( + f"Error checking relation consistency for {src_id}->{dst_id}: {e}" + ) return {"has_conflict": False} diff --git a/graphgen/models/evaluator/kg/kg_quality_evaluator.py b/graphgen/models/evaluator/kg/kg_quality_evaluator.py deleted file mode 100644 index 0a1e4e9a..00000000 --- a/graphgen/models/evaluator/kg/kg_quality_evaluator.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -Knowledge Graph Quality Evaluator - -This module provides comprehensive quality evaluation for knowledge graphs, -1. accuracy assessment (entity/relation/triple validation), -2. consistency assessment (attribute conflict detection), and structural -3. robustness assessment (noise ratio, connectivity, degree distribution). -""" - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper -from graphgen.models.evaluator.kg import ( - AccuracyEvaluator, - ConsistencyEvaluator, - StructureEvaluator, -) -from graphgen.utils import CURRENT_LOGGER_VAR, logger - - -@dataclass -class KGQualityEvaluator: - """Knowledge Graph Quality Evaluator.""" - - working_dir: str = "cache" - graph_backend: str = "kuzu" - kv_backend: str = "rocksdb" - llm_client: Optional[BaseLLMWrapper] = None - max_concurrent: int = 10 - - def __post_init__(self): - """Initialize storage and LLM client.""" - # Lazy import to avoid circular dependency - from graphgen.common import init_llm, init_storage - - self.graph_storage: BaseGraphStorage = init_storage( - backend=self.graph_backend, - working_dir=self.working_dir, - namespace="graph", - ) - self.chunk_storage: BaseKVStorage = init_storage( - backend=self.kv_backend, - working_dir=self.working_dir, - namespace="chunk", - ) - - if self.llm_client is None: - self.llm_client = init_llm("trainee") - - def evaluate_all(self) -> Dict[str, Any]: - """Run all evaluation metrics and return comprehensive report.""" - CURRENT_LOGGER_VAR.get() - results = {} - - try: - logger.info("Starting accuracy evaluation...") - results["accuracy"] = self.evaluate_accuracy() - except Exception as e: - logger.error(f"Accuracy evaluation failed: {e}") - results["accuracy"] = {"error": str(e)} - - # Consistency evaluation - try: - logger.info("Starting consistency evaluation...") - consistency_evaluator = ConsistencyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - max_concurrent=self.max_concurrent, - ) - results["consistency"] = consistency_evaluator.evaluate() - except Exception as e: - logger.error(f"Consistency evaluation failed: {e}") - results["consistency"] = {"error": str(e)} - - try: - logger.info("Starting structural robustness evaluation...") - structure_evaluator = StructureEvaluator(graph_storage=self.graph_storage) - results["structure"] = structure_evaluator.evaluate() - except Exception as e: - logger.error(f"Structural evaluation failed: {e}") - results["structure"] = {"error": str(e)} - - return results - - def evaluate_accuracy(self) -> Dict[str, Any]: - accuracy_evaluator = AccuracyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - max_concurrent=self.max_concurrent, - ) - return accuracy_evaluator.evaluate() - - def evaluate_consistency(self) -> Dict[str, Any]: - consistency_evaluator = ConsistencyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - max_concurrent=self.max_concurrent, - ) - return consistency_evaluator.evaluate() - - def evaluate_structure(self) -> Dict[str, Any]: - structure_evaluator = StructureEvaluator(graph_storage=self.graph_storage) - return structure_evaluator.evaluate() diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index 49d45703..6d2fe89a 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -11,10 +11,9 @@ class EvaluateService(BaseOperator): """ def __init__(self, working_dir: str = "cache", metrics: list[str] = None): - # optional 传入 graph super().__init__(working_dir=working_dir, op_name="evaluate_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") - self.metrics = metrics or [] + self.metrics = metrics self.evaluators = { "xxx": "xxxEvaluator" @@ -24,6 +23,10 @@ def __init__(self, working_dir: str = "cache", metrics: list[str] = None): xx, xx, xx ) + def _init_evaluators(self): + for metric in self.metrics: + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: items = batch.to_dict(orient="records") return pd.DataFrame(self.evaluate(items)) diff --git a/graphgen/templates/__init__.py b/graphgen/templates/__init__.py index 0940e910..cbfa4e17 100644 --- a/graphgen/templates/__init__.py +++ b/graphgen/templates/__init__.py @@ -1,5 +1,6 @@ from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT +from .evaluation import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT from .extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT from .generation import ( AGGREGATED_GENERATION_PROMPT, diff --git a/graphgen/templates/evaluation/__init__.py b/graphgen/templates/evaluation/__init__.py new file mode 100644 index 00000000..93761e85 --- /dev/null +++ b/graphgen/templates/evaluation/__init__.py @@ -0,0 +1 @@ +from .kg import ACCURACY_EVALUATION_PROMPT diff --git a/graphgen/templates/evaluation/kg/__init__.py b/graphgen/templates/evaluation/kg/__init__.py new file mode 100644 index 00000000..9c500d1f --- /dev/null +++ b/graphgen/templates/evaluation/kg/__init__.py @@ -0,0 +1 @@ +from .accuracy_evaluation import ACCURACY_EVALUATION_PROMPT diff --git a/graphgen/templates/evaluation/kg/accuracy_evaluation.py b/graphgen/templates/evaluation/kg/accuracy_evaluation.py new file mode 100644 index 00000000..f98b8b0f --- /dev/null +++ b/graphgen/templates/evaluation/kg/accuracy_evaluation.py @@ -0,0 +1,156 @@ +ENTITY_EVALUATION_PROMPT_ZH = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的实体列表,评估实体提取的质量。 + +评估维度: +1. ACCURACY (准确性, 权重: 40%): 提取的实体是否正确,是否有误提取或错误识别 +2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要实体 +3. PRECISION (精确性, 权重: 20%): 提取的实体是否精确,命名是否准确 + +评分标准(每个维度 0-1 分): +- EXCELLENT (0.8-1.0): 高质量提取 +- GOOD (0.6-0.79): 良好质量,有少量问题 +- ACCEPTABLE (0.4-0.59): 可接受,有明显问题 +- POOR (0.0-0.39): 质量差,需要改进 + +综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +请评估以下内容: + +原始文本块: +{chunk_content} + +提取的实体列表: +{extracted_entities} + +请以 JSON 格式返回评估结果: +{{ + "accuracy": <0-1之间的浮点数>, + "completeness": <0-1之间的浮点数>, + "precision": <0-1之间的浮点数>, + "overall_score": <综合评分>, + "accuracy_reasoning": "<准确性评估理由>", + "completeness_reasoning": "<完整性评估理由,包括遗漏的重要实体>", + "precision_reasoning": "<精确性评估理由>", + "issues": ["<发现的问题列表>"] +}} +""" + +ENTITY_EVALUATION_PROMPT_EN = """You are a Knowledge Graph Quality Assessment Expert. \ +Your task is to evaluate the quality of entity extraction from a given text block and extracted entity list. + +Evaluation Dimensions: +1. ACCURACY (Weight: 40%): Whether the extracted entities are correct, and if there are any false extractions or misidentifications +2. COMPLETENESS (Weight: 40%): Whether important entities from the text are missing +3. PRECISION (Weight: 20%): Whether the extracted entities are precise and accurately named + +Scoring Criteria (0-1 scale for each dimension): +- EXCELLENT (0.8-1.0): High-quality extraction +- GOOD (0.6-0.79): Good quality with minor issues +- ACCEPTABLE (0.4-0.59): Acceptable with noticeable issues +- POOR (0.0-0.39): Poor quality, needs improvement + +Overall Score = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +Please evaluate the following: + +Original Text Block: +{chunk_content} + +Extracted Entity List: +{extracted_entities} + +Please return the evaluation result in JSON format: +{{ + "accuracy": , + "completeness": , + "precision": , + "overall_score": , + "accuracy_reasoning": "", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": [""] +}} +""" + +RELATION_EVALUATION_PROMPT_ZH = """你是一个知识图谱质量评估专家。你的任务是从给定的文本块和提取的关系列表,评估关系抽取的质量。 + +评估维度: +1. ACCURACY (准确性, 权重: 40%): 提取的关系是否正确,关系描述是否准确 +2. COMPLETENESS (完整性, 权重: 40%): 是否遗漏了文本中的重要关系 +3. PRECISION (精确性, 权重: 20%): 关系描述是否精确,是否过于宽泛 + +评分标准(每个维度 0-1 分): +- EXCELLENT (0.8-1.0): 高质量提取 +- GOOD (0.6-0.79): 良好质量,有少量问题 +- ACCEPTABLE (0.4-0.59): 可接受,有明显问题 +- POOR (0.0-0.39): 质量差,需要改进 + +综合评分 = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +请评估以下内容: + +原始文本块: +{chunk_content} + +提取的关系列表: +{extracted_relations} + +请以 JSON 格式返回评估结果: +{{ + "accuracy": <0-1之间的浮点数>, + "completeness": <0-1之间的浮点数>, + "precision": <0-1之间的浮点数>, + "overall_score": <综合评分>, + "accuracy_reasoning": "<准确性评估理由>", + "completeness_reasoning": "<完整性评估理由,包括遗漏的重要关系>", + "precision_reasoning": "<精确性评估理由>", + "issues": ["<发现的问题列表>"] +}} +""" + +RELATION_EVALUATION_PROMPT_EN = """You are a Knowledge Graph Quality Assessment Expert. \ +Your task is to evaluate the quality of relation extraction from a given text block and extracted relation list. + +Evaluation Dimensions: +1. ACCURACY (Weight: 40%): Whether the extracted relations are correct and the relation descriptions are accurate +2. COMPLETENESS (Weight: 40%): Whether important relations from the text are missing +3. PRECISION (Weight: 20%): Whether the relation descriptions are precise and not overly broad + +Scoring Criteria (0-1 scale for each dimension): +- EXCELLENT (0.8-1.0): High-quality extraction +- GOOD (0.6-0.79): Good quality with minor issues +- ACCEPTABLE (0.4-0.59): Acceptable with noticeable issues +- POOR (0.0-0.39): Poor quality, needs improvement + +Overall Score = 0.4 × Accuracy + 0.4 × Completeness + 0.2 × Precision + +Please evaluate the following: + +Original Text Block: +{chunk_content} + +Extracted Relation List: +{extracted_relations} + +Please return the evaluation result in JSON format: +{{ + "accuracy": , + "completeness": , + "precision": , + "overall_score": , + "accuracy_reasoning": "", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": [""] +}} +""" + +ACCURACY_EVALUATION_PROMPT = { + "zh": { + "ENTITY": ENTITY_EVALUATION_PROMPT_ZH, + "RELATION": RELATION_EVALUATION_PROMPT_ZH, + }, + "en": { + "ENTITY": ENTITY_EVALUATION_PROMPT_EN, + "RELATION": RELATION_EVALUATION_PROMPT_EN, + }, +} diff --git a/graphgen/templates/evaluation/kg/consistency_evaluation.py b/graphgen/templates/evaluation/kg/consistency_evaluation.py new file mode 100644 index 00000000..b8cf2f8d --- /dev/null +++ b/graphgen/templates/evaluation/kg/consistency_evaluation.py @@ -0,0 +1,97 @@ +ENTITY_TYPE_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中被提取为不同的类型,是否存在语义冲突。 + +实体名称:{entity_name} + +在不同文本块中的类型提取结果: +{type_extractions} + +预设的实体类型列表(供参考): +concept, date, location, keyword, organization, person, event, work, nature, artificial, science, technology, mission, gene + +请判断这些类型是否存在语义冲突(即它们是否描述的是同一类事物,还是存在矛盾)。 +注意:如果类型只是同一概念的不同表述(如 concept 和 keyword),可能不算严重冲突。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数,0表示无冲突,1表示严重冲突>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_types": ["<存在冲突的类型对>"], + "recommended_type": "<如果存在冲突,推荐的正确类型(必须是预设类型之一)>" +}} +""" + +ENTITY_DESCRIPTION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中的描述是否存在语义冲突。 + +实体名称:{entity_name} + +在不同文本块中的描述: +{descriptions} + +请判断这些描述是否存在语义冲突(即它们是否描述的是同一个实体,还是存在矛盾的信息)。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_descriptions": ["<存在冲突的描述对>"], + "conflict_details": "<具体的冲突内容>" +}} +""" + +RELATION_CONFLICT_PROMPT = """你是一个知识图谱一致性评估专家。你的任务是判断同一对实体在不同文本块中的关系描述是否存在语义冲突。 + +实体对:{source_entity} -> {target_entity} + +在不同文本块中的关系描述: +{relation_descriptions} + +请判断这些关系描述是否存在语义冲突。 + +请以 JSON 格式返回: +{{ + "has_conflict": , + "conflict_severity": <0-1之间的浮点数>, + "conflict_reasoning": "<冲突判断的理由>", + "conflicting_relations": ["<存在冲突的关系描述对>"] +}} +""" + +ENTITY_EXTRACTION_PROMPT = """从以下文本块中提取指定实体的类型和描述。 + +**重要**:你只需要提取指定的实体,不要提取其他实体。 + +实体名称:{entity_name} + +文本块: +{chunk_content} + +请从文本块中找到并提取**仅此实体**(实体名称:{entity_name})的以下信息: + +1. entity_type: 实体类型,必须是以下预设类型之一(小写): + - concept: 概念 + - date: 日期 + - location: 地点 + - keyword: 关键词 + - organization: 组织 + - person: 人物 + - event: 事件 + - work: 作品/工作 + - nature: 自然 + - artificial: 人工 + - science: 科学 + - technology: 技术 + - mission: 任务 + - gene: 基因 + + 如果无法确定类型,请使用 "concept" 作为默认值。 + +2. description: 实体描述(简要描述该实体在文本中的作用和特征) + +请以 JSON 格式返回: +{{ + "entity_type": "<实体类型(必须是上述预设类型之一)>", + "description": "<实体描述>" +}} +""" From 77bb00dbc3d3e040dff24da642ea6b3bf41d0670 Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Thu, 25 Dec 2025 21:53:56 +0800 Subject: [PATCH 12/29] refactor: refactor base_evaluator --- graphgen/bases/base_evaluator.py | 50 +++----------------------------- 1 file changed, 4 insertions(+), 46 deletions(-) diff --git a/graphgen/bases/base_evaluator.py b/graphgen/bases/base_evaluator.py index e24cfa43..b17ea935 100644 --- a/graphgen/bases/base_evaluator.py +++ b/graphgen/bases/base_evaluator.py @@ -1,52 +1,10 @@ -import asyncio - -from tqdm.asyncio import tqdm as tqdm_async - +from abc import ABC, abstractmethod from graphgen.bases.datatypes import QAPair -from graphgen.utils import create_event_loop -class BaseEvaluator: - def __init__(self, max_concurrent: int = 100): - self.max_concurrent = max_concurrent - self.results: list[float] = None - - def evaluate(self, pairs: list[QAPair]) -> list[float]: +class BaseEvaluator(ABC): + @abstractmethod + def evaluate(self, pair: QAPair) -> float: """ Evaluate the text and return a score. """ - return create_event_loop().run_until_complete(self.async_evaluate(pairs)) - - async def async_evaluate(self, pairs: list[QAPair]) -> list[float]: - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def evaluate_with_semaphore(pair): - async with semaphore: # 获取Semaphore - return await self.evaluate_single(pair) - - results = [] - for result in tqdm_async( - asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]), - total=len(pairs), - ): - results.append(await result) - return results - - async def evaluate_single(self, pair: QAPair) -> float: - raise NotImplementedError() - - def get_average_score(self, pairs: list[QAPair]) -> float: - """ - Get the average score of a batch of texts. - """ - results = self.evaluate(pairs) - self.results = results - return sum(self.results) / len(pairs) - - def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]: - """ - Get the min and max score of a batch of texts. - """ - if self.results is None: - self.get_average_score(pairs) - return min(self.results), max(self.results) From 19510d93cfb9a4f6bcc4c271a0f0bf0eda8dd02e Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Thu, 25 Dec 2025 21:59:30 +0800 Subject: [PATCH 13/29] refator: refactor LengthEvaluator --- .../models/evaluator/qa/length_evaluator.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/graphgen/models/evaluator/qa/length_evaluator.py b/graphgen/models/evaluator/qa/length_evaluator.py index 74716f70..485a8155 100644 --- a/graphgen/models/evaluator/qa/length_evaluator.py +++ b/graphgen/models/evaluator/qa/length_evaluator.py @@ -1,19 +1,16 @@ from graphgen.bases.base_evaluator import BaseEvaluator from graphgen.bases.datatypes import QAPair from graphgen.models.tokenizer import Tokenizer -from graphgen.utils import create_event_loop class LengthEvaluator(BaseEvaluator): - def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100): - super().__init__(max_concurrent) - self.tokenizer_name = tokenizer_name - self.tokenizer = Tokenizer(model_name=self.tokenizer_name) + def __init__(self, tokenizer: Tokenizer): + self.tokenizer = tokenizer - async def evaluate_single(self, pair: QAPair) -> float: - loop = create_event_loop() - return await loop.run_in_executor(None, self._calculate_length, pair.answer) - - def _calculate_length(self, text: str) -> float: - tokens = self.tokenizer.encode(text) + def evaluate(self, pair: QAPair) -> float: + """ + Evaluate the length of the qa pair. + """ + content = pair.question + pair.answer + tokens = self.tokenizer.encode(content) return len(tokens) From 028b043181ced14c47f90c3b2777a4703ffca1ca Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Thu, 25 Dec 2025 22:07:28 +0800 Subject: [PATCH 14/29] refactor: refactor MTLDEvaluator --- .../models/evaluator/qa/mtld_evaluator.py | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/graphgen/models/evaluator/qa/mtld_evaluator.py b/graphgen/models/evaluator/qa/mtld_evaluator.py index 8503ea4e..0156289f 100644 --- a/graphgen/models/evaluator/qa/mtld_evaluator.py +++ b/graphgen/models/evaluator/qa/mtld_evaluator.py @@ -2,37 +2,33 @@ from graphgen.bases.base_evaluator import BaseEvaluator from graphgen.bases.datatypes import QAPair -from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language - -nltk_helper = NLTKHelper() +from graphgen.utils import NLTKHelper, detect_main_language class MTLDEvaluator(BaseEvaluator): """ - 衡量文本词汇多样性的指标 + Metrics for measuring the lexical diversity of text. """ - def __init__(self, max_concurrent: int = 100): - super().__init__(max_concurrent) - self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english")) - self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese")) - - async def evaluate_single(self, pair: QAPair) -> float: - loop = create_event_loop() - return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer) + def __init__(self, threshold: float = 0.72): + self.nltk_helper = NLTKHelper() + self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("english")) + self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("chinese")) + self.threshold = threshold - def _calculate_mtld_score(self, text: str, threshold=0.72) -> float: + def evaluate(self, pair: QAPair) -> float: """ - 计算MTLD (向前和向后的平均值) + Calculate the MTLD (Mean Token Length Diversity) score for a given text. min is 1.0 higher is better """ + text = pair.answer if not text or not text.strip(): return 0.0 lang = detect_main_language(text) - tokens = nltk_helper.word_tokenize(text, lang) + tokens = self.nltk_helper.word_tokenize(text, lang) stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en filtered_tokens = [word for word in tokens if word not in stopwords] @@ -41,13 +37,13 @@ def _calculate_mtld_score(self, text: str, threshold=0.72) -> float: if not filtered_tokens: return 0 - # 计算向前的MTLD - forward_factors = self._compute_factors(filtered_tokens, threshold) + # Compute forward factors + forward_factors = self._compute_factors(filtered_tokens, self.threshold) - # 计算向后的MTLD - backward_factors = self._compute_factors(filtered_tokens[::-1], threshold) + # Compute backward factors + backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold) - # 取平均值 + # Compute average factors return (forward_factors + backward_factors) / 2 @staticmethod @@ -66,7 +62,7 @@ def _compute_factors(tokens: list, threshold: float) -> float: current_segment = [] unique_words = set() - # 处理最后一个不完整片段 + # handle last segment if current_segment: ttr = len(unique_words) / len(current_segment) if ttr <= threshold: From c161358886a77d7036ad302c86dff27fea784d20 Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Thu, 25 Dec 2025 22:39:30 +0800 Subject: [PATCH 15/29] refactor: refactor NLTKHelper --- graphgen/utils/help_nltk.py | 65 +++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/graphgen/utils/help_nltk.py b/graphgen/utils/help_nltk.py index 2d2610ba..ab70236f 100644 --- a/graphgen/utils/help_nltk.py +++ b/graphgen/utils/help_nltk.py @@ -1,39 +1,54 @@ +from functools import lru_cache import os -from typing import Dict, List, Optional +from typing import Dict, List, Final, Optional import nltk import jieba -resource_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources") - - class NLTKHelper: - _stopwords: Dict[str, Optional[List[str]]] = { - "english": None, - "chinese": None, + """ + NLTK helper class + """ + + SUPPORTED_LANGUAGES: Final[Dict[str, str]] = { + "en": "english", + "zh": "chinese" + } + _NLTK_PACKAGES: Final[Dict[str, str]] = { + "stopwords": "corpora", + "punkt_tab": "tokenizers" } - def __init__(self): + def __init__(self, nltk_data_path: Optional[str] = None): + self._nltk_path = nltk_data_path or os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "resources", + "nltk_data" + ) + nltk.data.path.append(self._nltk_path) jieba.initialize() + self._ensure_nltk_data("stopwords") + self._ensure_nltk_data("punkt_tab") + + def _ensure_nltk_data(self, package_name: str) -> None: + """ + ensure nltk data is downloaded + """ + try: + nltk.data.find(f"{self._NLTK_PACKAGES[package_name]}/{package_name}") + except LookupError: + nltk.download(package_name, download_dir=self._nltk_path, quiet=True) + + @lru_cache(maxsize=2) def get_stopwords(self, lang: str) -> List[str]: - nltk.data.path.append(os.path.join(resource_path, "nltk_data")) - if self._stopwords[lang] is None: - try: - nltk.data.find("corpora/stopwords") - except LookupError: - nltk.download("stopwords", download_dir=os.path.join(resource_path, "nltk_data")) - - self._stopwords[lang] = nltk.corpus.stopwords.words(lang) - return self._stopwords[lang] - - @staticmethod - def word_tokenize(text: str, lang: str) -> List[str]: + if lang not in self.SUPPORTED_LANGUAGES: + raise ValueError(f"Language {lang} is not supported.") + return nltk.corpus.stopwords.words(self.SUPPORTED_LANGUAGES[lang]) + + def word_tokenize(self, text: str, lang: str) -> List[str]: + if lang not in self.SUPPORTED_LANGUAGES: + raise ValueError(f"Language {lang} is not supported.") if lang == "zh": return jieba.lcut(text) - nltk.data.path.append(os.path.join(resource_path, "nltk_data")) - try: - nltk.data.find("tokenizers/punkt_tab") - except LookupError: - nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data")) return nltk.word_tokenize(text) From 58ede2eda593be46a5512440b50a77dcc5916f62 Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Thu, 25 Dec 2025 22:54:54 +0800 Subject: [PATCH 16/29] refactor: refactor RewardEvaluator --- graphgen/bases/__init__.py | 1 + graphgen/bases/base_evaluator.py | 2 +- .../models/evaluator/qa/length_evaluator.py | 3 +- .../models/evaluator/qa/mtld_evaluator.py | 3 +- .../models/evaluator/qa/reward_evaluator.py | 145 ++++++------------ graphgen/models/evaluator/qa/uni_evaluator.py | 9 +- 6 files changed, 61 insertions(+), 102 deletions(-) diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index 41136974..0727b3fa 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -9,4 +9,5 @@ from .base_splitter import BaseSplitter from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace from .base_tokenizer import BaseTokenizer +from .base_evaluator import BaseEvaluator from .datatypes import Chunk, Config, Node, QAPair, Token diff --git a/graphgen/bases/base_evaluator.py b/graphgen/bases/base_evaluator.py index b17ea935..3cc5df18 100644 --- a/graphgen/bases/base_evaluator.py +++ b/graphgen/bases/base_evaluator.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from graphgen.bases.datatypes import QAPair +from .datatypes import QAPair class BaseEvaluator(ABC): diff --git a/graphgen/models/evaluator/qa/length_evaluator.py b/graphgen/models/evaluator/qa/length_evaluator.py index 485a8155..3af66380 100644 --- a/graphgen/models/evaluator/qa/length_evaluator.py +++ b/graphgen/models/evaluator/qa/length_evaluator.py @@ -1,5 +1,4 @@ -from graphgen.bases.base_evaluator import BaseEvaluator -from graphgen.bases.datatypes import QAPair +from graphgen.bases import BaseEvaluator, QAPair from graphgen.models.tokenizer import Tokenizer diff --git a/graphgen/models/evaluator/qa/mtld_evaluator.py b/graphgen/models/evaluator/qa/mtld_evaluator.py index 0156289f..3cd43b75 100644 --- a/graphgen/models/evaluator/qa/mtld_evaluator.py +++ b/graphgen/models/evaluator/qa/mtld_evaluator.py @@ -1,7 +1,6 @@ from typing import Set -from graphgen.bases.base_evaluator import BaseEvaluator -from graphgen.bases.datatypes import QAPair +from graphgen.bases import BaseEvaluator, QAPair from graphgen.utils import NLTKHelper, detect_main_language diff --git a/graphgen/models/evaluator/qa/reward_evaluator.py b/graphgen/models/evaluator/qa/reward_evaluator.py index 4d2c2fb9..31955336 100644 --- a/graphgen/models/evaluator/qa/reward_evaluator.py +++ b/graphgen/models/evaluator/qa/reward_evaluator.py @@ -1,107 +1,64 @@ -from dataclasses import dataclass +from typing import Optional +from graphgen.bases import BaseEvaluator, QAPair -from tqdm import tqdm -from graphgen.bases.datatypes import QAPair - - -@dataclass -class RewardEvaluator: +class RewardEvaluator(BaseEvaluator): """ - Reward Model Evaluator. - OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好 + Reward Model Evaluator for single QAPair evaluation. """ - reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2" - max_length: int = 2560 - results: list[float] = None - - def __post_init__(self): - import torch - - self.num_gpus = torch.cuda.device_count() + def __init__( + self, + reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2", + max_length: int = 2560, + device: Optional[str] = None, + ): + """ + Initialize the reward evaluator. + + Args: + reward_name: Model name or path on HuggingFace Hub + max_length: Maximum token length for the model + device: Device to run the model on. If None, auto-detect CUDA/CPU. + """ + self.reward_name = reward_name + self.max_length = max_length - @staticmethod - def process_chunk(rank, pairs, reward_name, max_length, return_dict): import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer - device = f"cuda:{rank}" - torch.cuda.set_device(rank) - - rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name) - tokenizer = AutoTokenizer.from_pretrained(reward_name) - rank_model.to(device) - rank_model.eval() - - results = [] - with torch.no_grad(): - for pair in tqdm(pairs): - inputs = tokenizer( - pair.question, - pair.answer, - return_tensors="pt", - max_length=max_length, - truncation=True, - ) - inputs = {k: v.to(device) for k, v in inputs.items()} - score = rank_model(**inputs).logits[0].item() - results.append(score) - - return_dict[rank] = results + # Set device (auto-detect if not specified) + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - def evaluate(self, pairs: list[QAPair]) -> list[float]: - import torch.multiprocessing as mp - - chunk_size = len(pairs) // self.num_gpus - chunks = [] - for i in range(self.num_gpus): - start = i * chunk_size - end = start + chunk_size - if i == self.num_gpus - 1: - end = len(pairs) - chunks.append(pairs[start:end]) - - # multi-process - manager = mp.Manager() - return_dict = manager.dict() - processes = [] - - for rank, chunk in enumerate(chunks): - p = mp.Process( - target=self.process_chunk, - args=(rank, chunk, self.reward_name, self.max_length, return_dict), - ) - p.start() - processes.append(p) - - for p in processes: - p.join() - - # 合并结果 - results = [] - for rank in range(len(chunks)): - results.extend(return_dict[rank]) - - for p in processes: - if p.is_alive(): - p.terminate() - p.join() - - return results - - def get_average_score(self, pairs: list[QAPair]) -> float: - """ - Get the average score of a batch of texts. - """ - results = self.evaluate(pairs) - self.results = results - return sum(self.results) / len(pairs) + try: + self.tokenizer = AutoTokenizer.from_pretrained(reward_name) + self.model = AutoModelForSequenceClassification.from_pretrained(reward_name) + self.model.to(self.device) + self.model.eval() + except Exception as e: + raise RuntimeError(f"Failed to load reward model '{reward_name}': {e}") from e - def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]: + def evaluate(self, pair: QAPair) -> float: """ - Get the min and max score of a batch of texts. + Evaluate a single question-answer pair using the reward model. + + Args: + pair: QAPair containing question and answer strings + + Returns: + Score as a float """ - if self.results is None: - self.get_average_score(pairs) - return min(self.results), max(self.results) + # Tokenize + inputs = self.tokenizer( + pair.question, + pair.answer, + return_tensors="pt", + max_length=self.max_length, + truncation=True, + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get score + score = self.model(**inputs).logits[0].item() + + return score diff --git a/graphgen/models/evaluator/qa/uni_evaluator.py b/graphgen/models/evaluator/qa/uni_evaluator.py index 20fa3517..d172f4d6 100644 --- a/graphgen/models/evaluator/qa/uni_evaluator.py +++ b/graphgen/models/evaluator/qa/uni_evaluator.py @@ -1,10 +1,10 @@ # https://github.com/maszhongming/UniEval/tree/main -from dataclasses import dataclass, field +from dataclasses import field from tqdm import tqdm -from graphgen.bases.datatypes import QAPair +from graphgen.bases import BaseEvaluator, QAPair def _add_questions(dimension: str, question: str, answer: str): @@ -32,8 +32,11 @@ def _add_questions(dimension: str, question: str, answer: str): return cur_input -@dataclass + class UniEvaluator: + """ + UniEvaluator class + """ model_name: str = "MingZhong/unieval-sum" dimensions: list = field( default_factory=lambda: ["naturalness", "coherence", "understandability"] From f3a03916ba364c0d02bd2309fa9a1afec34b063d Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Thu, 25 Dec 2025 23:12:32 +0800 Subject: [PATCH 17/29] refactor: refactor UniEvaluator --- .../models/evaluator/qa/reward_evaluator.py | 4 +- graphgen/models/evaluator/qa/uni_evaluator.py | 257 ++++++------------ 2 files changed, 91 insertions(+), 170 deletions(-) diff --git a/graphgen/models/evaluator/qa/reward_evaluator.py b/graphgen/models/evaluator/qa/reward_evaluator.py index 31955336..a7fcbc22 100644 --- a/graphgen/models/evaluator/qa/reward_evaluator.py +++ b/graphgen/models/evaluator/qa/reward_evaluator.py @@ -26,6 +26,7 @@ def __init__( import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer + self.torch = torch # Set device (auto-detect if not specified) self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") @@ -59,6 +60,7 @@ def evaluate(self, pair: QAPair) -> float: inputs = {k: v.to(self.device) for k, v in inputs.items()} # Get score - score = self.model(**inputs).logits[0].item() + with self.torch.no_grad(): + score = self.model(**inputs).logits[0].item() return score diff --git a/graphgen/models/evaluator/qa/uni_evaluator.py b/graphgen/models/evaluator/qa/uni_evaluator.py index d172f4d6..9dc7ad2c 100644 --- a/graphgen/models/evaluator/qa/uni_evaluator.py +++ b/graphgen/models/evaluator/qa/uni_evaluator.py @@ -1,186 +1,105 @@ # https://github.com/maszhongming/UniEval/tree/main - -from dataclasses import field - -from tqdm import tqdm - +from typing import Optional, List from graphgen.bases import BaseEvaluator, QAPair -def _add_questions(dimension: str, question: str, answer: str): - if dimension == "naturalness": - cur_input = ( - "question: Is this a natural response in the dialogue? response: " - + answer - ) - elif dimension == "coherence": - cur_input = ( - "question: Is this a coherent response given the dialogue history? response: " - + answer - + " dialogue history: " - + question - ) - elif dimension == "understandability": - cur_input = ( - "question: Is this an understandable response in the dialogue? response: " - + answer - ) - else: - raise NotImplementedError( - "The input format for this dimension is still undefined. Please customize it first." - ) - return cur_input - - - -class UniEvaluator: +class UniEvaluator(BaseEvaluator): """ - UniEvaluator class + UniEvaluator for single QAPair evaluation across quality dimensions. + + Dimensions: naturalness, coherence, understandability + + Usage: + evaluator = UniEvaluator() + pair = QAPair(question="...", answer="...") + scores = evaluator.evaluate(pair) + # {"naturalness": 0.85, "coherence": 0.92, "understandability": 0.88} """ - model_name: str = "MingZhong/unieval-sum" - dimensions: list = field( - default_factory=lambda: ["naturalness", "coherence", "understandability"] - ) - max_length: int = 2560 - results: dict = None - - def __post_init__(self): - import torch - self.num_gpus = torch.cuda.device_count() - self.results = {} + DEFAULT_MODEL: str = "MingZhong/unieval-sum" + DEFAULT_DIMS: List[str] = ["naturalness", "coherence", "understandability"] + DEFAULT_MAX_LENGTH: int = 2560 - @staticmethod - def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict): + def __init__( + self, + model_name: Optional[str] = None, + max_length: Optional[int] = None, + device: Optional[str] = None, + ): + """ + Args: + model_name: HuggingFace model name/path + max_length: Tokenizer max sequence length + device: 'cuda', 'cpu', or None for auto-detect + """ import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + self.torch = torch - device = f"cuda:{rank}" - torch.cuda.set_device(rank) - - rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model_name) - rank_model.to(device) - rank_model.eval() - - softmax = torch.nn.Softmax(dim=1) - - pos_id = tokenizer("Yes")["input_ids"][0] - neg_id = tokenizer("No")["input_ids"][0] - - results = [] - with torch.no_grad(): - for pair in tqdm(pairs): - text = _add_questions(dimension, pair.question, pair.answer) + self.model_name = model_name or self.DEFAULT_MODEL + self.max_length = max_length or self.DEFAULT_MAX_LENGTH + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - tgt = "No" + # Load model & tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name) + self.model.to(self.device) + self.model.eval() - encoded_src = tokenizer( - text, - max_length=max_length, - truncation=True, - padding=True, - return_tensors="pt", - ) - encoded_tgt = tokenizer( - tgt, - max_length=max_length, - truncation=True, - padding=True, - return_tensors="pt", - ) + # Pre-compute Yes/No token IDs + self._yes_id = self.tokenizer("Yes")["input_ids"][0] + self._no_id = self.tokenizer("No")["input_ids"][0] - src_tokens = encoded_src["input_ids"].to(device) - src_mask = encoded_src["attention_mask"].to(device) - - tgt_tokens = encoded_tgt["input_ids"].to(device)[:, 0].unsqueeze(-1) - - output = rank_model( + @staticmethod + def _build_input_text(dimension: str, question: str, answer: str) -> str: + """Construct input text for specified dimension.""" + if dimension == "naturalness": + return f"question: Is this a natural response? response: {answer}" + elif dimension == "coherence": + return f"question: Is this a coherent response? response: {answer} history: {question}" + elif dimension == "understandability": + return f"question: Is this an understandable response? response: {answer}" + raise NotImplementedError(f"Unsupported dimension '{dimension}'") + + def evaluate( + self, + pair: QAPair, + dimensions: Optional[List[str]] = None, + ) -> dict[str, float]: + """Evaluate a single QAPair across specified dimensions.""" + dimensions = dimensions or self.DEFAULT_DIMS + + # Validate dimensions + invalid = set(dimensions) - set(self.DEFAULT_DIMS) + if invalid: + raise ValueError(f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}") + + results = {} + no_token = self.torch.tensor([[self._no_id]], device=self.device) + + for dim in dimensions: + # Tokenize input + src = self.tokenizer( + self._build_input_text(dim, pair.question, pair.answer), + max_length=self.max_length, + truncation=True, + return_tensors="pt", + ) + src_tokens = src["input_ids"].to(self.device) + src_mask = src["attention_mask"].to(self.device) + + # Score + with self.torch.no_grad(): + logits = self.model( input_ids=src_tokens, attention_mask=src_mask, - labels=tgt_tokens, + labels=no_token, use_cache=False, - ) - - logits = output.logits.view(-1, rank_model.config.vocab_size) - - pos_score = softmax(logits)[:, pos_id] # Yes - neg_score = softmax(logits)[:, neg_id] - score = pos_score / (pos_score + neg_score) - - results.append(score.item()) - - return_dict[rank] = results - - def evaluate(self, pairs: list[QAPair]) -> list[dict]: - import torch.multiprocessing as mp - - final_results = [] - for dimension in self.dimensions: - chunk_size = len(pairs) // self.num_gpus - chunks = [] - for i in range(self.num_gpus): - start = i * chunk_size - end = start + chunk_size - if i == self.num_gpus - 1: - end = len(pairs) - chunks.append(pairs[start:end]) - - # multi-process - manager = mp.Manager() - return_dict = manager.dict() - processes = [] - - for rank, chunk in enumerate(chunks): - p = mp.Process( - target=self.process_chunk, - args=( - rank, - chunk, - self.model_name, - self.max_length, - dimension, - return_dict, - ), - ) - p.start() - processes.append(p) - - for p in processes: - p.join() - - # 合并结果 - results = [] - for rank in range(len(chunks)): - results.extend(return_dict[rank]) - - for p in processes: - if p.is_alive(): - p.terminate() - p.join() - - final_results.append({dimension: results}) - return final_results - - def get_average_score(self, pairs: list[QAPair]) -> dict: - """ - Get the average score of a batch of texts. - """ - results = self.evaluate(pairs) - final_results = {} - for result in results: - for key, value in result.items(): - final_results[key] = sum(value) / len(value) - self.results[key] = value - return final_results - - def get_min_max_score(self, pairs: list[QAPair]) -> dict: - """ - Get the min and max score of a batch of texts. - """ - if self.results is None: - self.get_average_score(pairs) - final_results = {} - for key, value in self.results.items(): - final_results[key] = min(value), max(value) - return final_results + ).logits[:, 0, :] # [1, vocab_size] + + probs = self.torch.softmax(logits, dim=-1)[0] + score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id]) + + results[dim] = score.item() + + return results From 2a3f09fc4e0725b1c3903a09dadddc51cb397608 Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Thu, 25 Dec 2025 23:16:51 +0800 Subject: [PATCH 18/29] refactor: refactor evaluator structure --- graphgen/models/evaluator/__init__.py | 4 ++-- graphgen/models/evaluator/qa/__init__.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py index 83a48aaa..b6121648 100644 --- a/graphgen/models/evaluator/__init__.py +++ b/graphgen/models/evaluator/__init__.py @@ -1,2 +1,2 @@ -from graphgen.models.evaluator.kg.kg_quality_evaluator import KGQualityEvaluator -from graphgen.models.evaluator.qa.uni_evaluator import UniEvaluator +from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator +from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator diff --git a/graphgen/models/evaluator/qa/__init__.py b/graphgen/models/evaluator/qa/__init__.py index e69de29b..a9b445b4 100644 --- a/graphgen/models/evaluator/qa/__init__.py +++ b/graphgen/models/evaluator/qa/__init__.py @@ -0,0 +1,4 @@ +from .length_evaluator import LengthEvaluator +from .mtld_evaluator import MTLDEvaluator +from .reward_evaluator import RewardEvaluator +from .uni_evaluator import UniEvaluator From a4d7993cd19b5b2552a465e555315c8bb39de77c Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Fri, 26 Dec 2025 00:20:39 +0800 Subject: [PATCH 19/29] refactor: change evaluation methods in acc and consistency to sync --- .../models/evaluator/kg/accuracy_evaluator.py | 65 ++++++++--------- .../evaluator/kg/consistency_evaluator.py | 72 ++++++++----------- 2 files changed, 61 insertions(+), 76 deletions(-) diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py index f9d2e405..09bbedfc 100644 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -6,7 +6,7 @@ from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper from graphgen.bases.datatypes import Chunk from graphgen.templates import ACCURACY_EVALUATION_PROMPT -from graphgen.utils import create_event_loop, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class AccuracyEvaluator: @@ -43,10 +43,7 @@ def evaluate(self) -> Dict[str, Any]: logger.info(f"Found {len(chunks)} chunks to evaluate") # 2. Evaluate each chunk - loop = create_event_loop() - entity_evaluations, relation_evaluations = loop.run_until_complete( - self._evaluate_all_chunks(chunks) - ) + entity_evaluations, relation_evaluations = self._evaluate_all_chunks(chunks) # 3. Aggregate results return self._aggregate_evaluation_results( @@ -112,54 +109,47 @@ def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]: return relations - async def _evaluate_all_chunks( + def _evaluate_all_chunks( self, chunks: List[Chunk] ) -> tuple[List[Dict], List[Dict]]: - """Evaluate all chunks concurrently.""" - semaphore = asyncio.Semaphore(self.max_concurrent) + """Evaluate all chunks sequentially.""" + entity_evaluations = [] + relation_evaluations = [] - async def evaluate_chunk(chunk: Chunk): - async with semaphore: + for chunk in chunks: + try: entities = self._get_extracted_entities_for_chunk(chunk.id) relations = self._get_extracted_relations_for_chunk(chunk.id) - entity_eval = await self._evaluate_entity_extraction(chunk, entities) - relation_eval = await self._evaluate_relation_extraction( - chunk, relations - ) - - return entity_eval, relation_eval + entity_eval = self._evaluate_entity_extraction(chunk, entities) + relation_eval = self._evaluate_relation_extraction(chunk, relations) - tasks = [evaluate_chunk(chunk) for chunk in chunks] - results = await asyncio.gather(*tasks, return_exceptions=True) - - entity_evaluations = [] - relation_evaluations = [] - - for i, result in enumerate(results): - if isinstance(result, Exception): - logger.error(f"Failed to evaluate chunk {chunks[i].id}: {result}") + entity_evaluations.append(entity_eval) + relation_evaluations.append(relation_eval) + except Exception as e: + logger.error(f"Failed to evaluate chunk {chunk.id}: {e}") continue - entity_eval, relation_eval = result - entity_evaluations.append(entity_eval) - relation_evaluations.append(relation_eval) - return entity_evaluations, relation_evaluations - async def _evaluate_entity_extraction( + def _evaluate_entity_extraction( self, chunk: Chunk, extracted_entities: List[Dict] ) -> Dict[str, Any]: """Use LLM to evaluate entity extraction quality.""" try: - prompt = ENTITY_EVALUATION_PROMPT.format( + lang = detect_main_language(chunk.content) + prompt_template = ACCURACY_EVALUATION_PROMPT.get(lang, {}).get("ENTITY") + if not prompt_template: + prompt_template = ACCURACY_EVALUATION_PROMPT.get("en", {}).get("ENTITY") + + prompt = prompt_template.format( chunk_content=chunk.content, extracted_entities=json.dumps( extracted_entities, ensure_ascii=False, indent=2 ), ) - response = await self.llm_client.generate_answer(prompt) + response = asyncio.run(self.llm_client.generate_answer(prompt)) # Try to parse JSON response try: @@ -220,19 +210,24 @@ async def _evaluate_entity_extraction( "issues": [f"Evaluation error: {str(e)}"], } - async def _evaluate_relation_extraction( + def _evaluate_relation_extraction( self, chunk: Chunk, extracted_relations: List[Dict] ) -> Dict[str, Any]: """Use LLM to evaluate relation extraction quality.""" try: - prompt = RELATION_EVALUATION_PROMPT.format( + lang = detect_main_language(chunk.content) + prompt_template = ACCURACY_EVALUATION_PROMPT.get(lang, {}).get("RELATION") + if not prompt_template: + prompt_template = ACCURACY_EVALUATION_PROMPT.get("en", {}).get("RELATION") + + prompt = prompt_template.format( chunk_content=chunk.content, extracted_relations=json.dumps( extracted_relations, ensure_ascii=False, indent=2 ), ) - response = await self.llm_client.generate_answer(prompt) + response = asyncio.run(self.llm_client.generate_answer(prompt)) # Try to parse JSON response try: diff --git a/graphgen/models/evaluator/kg/consistency_evaluator.py b/graphgen/models/evaluator/kg/consistency_evaluator.py index a840abc6..069e7591 100644 --- a/graphgen/models/evaluator/kg/consistency_evaluator.py +++ b/graphgen/models/evaluator/kg/consistency_evaluator.py @@ -5,7 +5,13 @@ from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper from graphgen.bases.datatypes import Chunk -from graphgen.utils import create_event_loop, logger +from graphgen.templates.evaluation.kg.consistency_evaluation import ( + ENTITY_DESCRIPTION_CONFLICT_PROMPT, + ENTITY_EXTRACTION_PROMPT, + ENTITY_TYPE_CONFLICT_PROMPT, + RELATION_CONFLICT_PROMPT, +) +from graphgen.utils import logger class ConsistencyEvaluator: @@ -20,12 +26,10 @@ def __init__( graph_storage: BaseGraphStorage, chunk_storage: BaseKVStorage, llm_client: BaseLLMWrapper, - max_concurrent: int = 10, ): self.graph_storage = graph_storage self.chunk_storage = chunk_storage self.llm_client = llm_client - self.max_concurrent = max_concurrent def evaluate(self) -> Dict[str, Any]: """Evaluate consistency by detecting semantic conflicts.""" @@ -33,11 +37,10 @@ def evaluate(self) -> Dict[str, Any]: if not all_nodes: return {"error": "Empty graph"} - loop = create_event_loop() - return loop.run_until_complete(self._evaluate_consistency(all_nodes)) + return self._evaluate_consistency(all_nodes) - async def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: - """Async evaluation of consistency.""" + def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: + """Evaluate consistency by detecting semantic conflicts.""" # Filter entities with multiple source chunks entities_with_multiple_sources = [] for node_id, node_data in all_nodes: @@ -63,35 +66,22 @@ async def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources" ) - # Evaluate entities concurrently - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def evaluate_entity(entity_info): - async with semaphore: - return await self._evaluate_entity_consistency(entity_info) - - tasks = [ - evaluate_entity(entity_info) - for entity_info in entities_with_multiple_sources - ] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Aggregate results + # Evaluate entities sequentially conflicts = [] conflict_entities = set() - for i, result in enumerate(results): - if isinstance(result, Exception): + for entity_info in entities_with_multiple_sources: + try: + entity_id, entity_conflicts = self._evaluate_entity_consistency(entity_info) + if entity_conflicts: + conflicts.extend(entity_conflicts) + conflict_entities.add(entity_id) + except Exception as e: logger.error( - f"Failed to evaluate entity {entities_with_multiple_sources[i][0]}: {result}" + f"Failed to evaluate entity {entity_info[0]}: {e}" ) continue - entity_id, entity_conflicts = result - if entity_conflicts: - conflicts.extend(entity_conflicts) - conflict_entities.add(entity_id) - total_entities = len(all_nodes) conflict_rate = ( len(conflict_entities) / total_entities if total_entities > 0 else 0 @@ -114,7 +104,7 @@ def _clean_entity_id(self, entity_id: str) -> str: clean_id = clean_id[1:-1].strip() return clean_id - async def _evaluate_entity_consistency( + def _evaluate_entity_consistency( self, entity_info: tuple ) -> tuple[str, List[Dict]]: """Evaluate consistency for a single entity.""" @@ -131,7 +121,7 @@ async def _evaluate_entity_consistency( # Extract entity attributes from each chunk entity_extractions = {} for chunk in chunks: - extraction = await self._extract_entity_from_chunk(entity_id, chunk) + extraction = self._extract_entity_from_chunk(entity_id, chunk) if extraction: entity_extractions[chunk.id] = extraction @@ -143,7 +133,7 @@ async def _evaluate_entity_consistency( chunk_id: ext.get("entity_type", "") for chunk_id, ext in entity_extractions.items() } - type_conflict = await self._check_entity_type_consistency( + type_conflict = self._check_entity_type_consistency( entity_id, type_extractions ) if type_conflict and type_conflict.get("has_conflict", False): @@ -163,7 +153,7 @@ async def _evaluate_entity_consistency( chunk_id: ext.get("description", "") for chunk_id, ext in entity_extractions.items() } - desc_conflict = await self._check_entity_description_consistency( + desc_conflict = self._check_entity_description_consistency( entity_id, descriptions ) if desc_conflict and desc_conflict.get("has_conflict", False): @@ -196,7 +186,7 @@ def _get_entity_chunks(self, source_ids: List[str]) -> List[Chunk]: continue return chunks - async def _extract_entity_from_chunk( + def _extract_entity_from_chunk( self, entity_id: str, chunk: Chunk ) -> Dict[str, str]: """Extract entity attributes from a chunk using LLM.""" @@ -211,7 +201,7 @@ async def _extract_entity_from_chunk( else "", # Limit content length ) - response = await self.llm_client.generate_answer(prompt) + response = asyncio.run(self.llm_client.generate_answer(prompt)) # Try to parse JSON response try: @@ -265,7 +255,7 @@ async def _extract_entity_from_chunk( ) return {} - async def _check_entity_type_consistency( + def _check_entity_type_consistency( self, entity_id: str, type_extractions: Dict[str, str] ) -> Dict[str, Any]: """Check entity type consistency using LLM.""" @@ -284,7 +274,7 @@ async def _check_entity_type_consistency( entity_name=entity_id, type_extractions="\n".join(type_list) ) - response = await self.llm_client.generate_answer(prompt) + response = asyncio.run(self.llm_client.generate_answer(prompt)) # Parse JSON response try: @@ -304,7 +294,7 @@ async def _check_entity_type_consistency( logger.error(f"Error checking type consistency for {entity_id}: {e}") return {"has_conflict": False} - async def _check_entity_description_consistency( + def _check_entity_description_consistency( self, entity_id: str, descriptions: Dict[str, str] ) -> Dict[str, Any]: """Check entity description consistency using LLM.""" @@ -327,7 +317,7 @@ async def _check_entity_description_consistency( entity_name=entity_id, descriptions="\n".join(desc_list) ) - response = await self.llm_client.generate_answer(prompt) + response = asyncio.run(self.llm_client.generate_answer(prompt)) # Parse JSON response try: @@ -347,7 +337,7 @@ async def _check_entity_description_consistency( logger.error(f"Error checking description consistency for {entity_id}: {e}") return {"has_conflict": False} - async def _check_relation_consistency( + def _check_relation_consistency( self, src_id: str, dst_id: str, relation_extractions: Dict[str, str] ) -> Dict[str, Any]: """Check relation consistency using LLM.""" @@ -367,7 +357,7 @@ async def _check_relation_consistency( relation_descriptions="\n".join(rel_list), ) - response = await self.llm_client.generate_answer(prompt) + response = asyncio.run(self.llm_client.generate_answer(prompt)) # Parse JSON response try: From 3ae232160d5b6c5c746bd3ddde1544139f6d50ec Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Fri, 26 Dec 2025 00:21:28 +0800 Subject: [PATCH 20/29] refactor: streamline evaluation functions for accuracy, consistency, and structure --- graphgen/operators/evaluate/evaluate_kg.py | 354 ++++----------------- 1 file changed, 59 insertions(+), 295 deletions(-) diff --git a/graphgen/operators/evaluate/evaluate_kg.py b/graphgen/operators/evaluate/evaluate_kg.py index 4d3a62c8..a58617fe 100644 --- a/graphgen/operators/evaluate/evaluate_kg.py +++ b/graphgen/operators/evaluate/evaluate_kg.py @@ -1,307 +1,71 @@ -import argparse -import json -from pathlib import Path +from typing import Any, Dict from dotenv import load_dotenv from graphgen.models import KGQualityEvaluator -from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger +from graphgen.utils import logger # Load environment variables load_dotenv() -def _run_evaluation(evaluator, args): - """Run the evaluation based on arguments.""" - if args.accuracy_only: - logger.info("Running accuracy evaluation only...") - return {"accuracy": evaluator.evaluate_accuracy()} - if args.consistency_only: - logger.info("Running consistency evaluation only...") - return {"consistency": evaluator.evaluate_consistency()} - if args.structure_only: - logger.info("Running structural robustness evaluation only...") - return {"structure": evaluator.evaluate_structure()} +def evaluate_accuracy(evaluator: KGQualityEvaluator) -> Dict[str, Any]: + """Evaluate accuracy of entity and relation extraction. + + Args: + evaluator: KGQualityEvaluator instance + + Returns: + Dictionary containing entity_accuracy and relation_accuracy metrics. + """ + logger.info("Running accuracy evaluation...") + results = evaluator.evaluate_accuracy() + logger.info("Accuracy evaluation completed") + return results + + +def evaluate_consistency(evaluator: KGQualityEvaluator) -> Dict[str, Any]: + """Evaluate consistency by detecting semantic conflicts. + + Args: + evaluator: KGQualityEvaluator instance + + Returns: + Dictionary containing consistency metrics including conflict_rate and conflicts. + """ + logger.info("Running consistency evaluation...") + results = evaluator.evaluate_consistency() + logger.info("Consistency evaluation completed") + return results + + +def evaluate_structure(evaluator: KGQualityEvaluator) -> Dict[str, Any]: + """Evaluate structural robustness of the graph. + + Args: + evaluator: KGQualityEvaluator instance + + Returns: + Dictionary containing structural metrics including noise_ratio, largest_cc_ratio, etc. + """ + logger.info("Running structural robustness evaluation...") + results = evaluator.evaluate_structure() + logger.info("Structural robustness evaluation completed") + return results + + +def evaluate_all(evaluator: KGQualityEvaluator) -> Dict[str, Any]: + """Run all evaluations (accuracy, consistency, structure). + + Args: + evaluator: KGQualityEvaluator instance + + Returns: + Dictionary containing all evaluation results with keys: accuracy, consistency, structure. + """ logger.info("Running all evaluations...") - return evaluator.evaluate_all() + results = evaluator.evaluate_all() + logger.info("All evaluations completed") + return results -def _print_accuracy_summary(acc): - """Print accuracy evaluation summary.""" - if "error" not in acc: - print("\n[Accuracy]") - if "entity_accuracy" in acc: - e = acc["entity_accuracy"] - overall = e.get("overall_score", {}) - accuracy = e.get("accuracy", {}) - completeness = e.get("completeness", {}) - precision = e.get("precision", {}) - - print(" Entity Extraction Quality:") - print( - f" Overall Score: {overall.get('mean', 0):.3f} (mean), " - f"{overall.get('median', 0):.3f} (median)" - ) - print( - f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " - f"{accuracy.get('median', 0):.3f} (median)" - ) - print( - f" Completeness: {completeness.get('mean', 0):.3f} (mean), " - f"{completeness.get('median', 0):.3f} (median)" - ) - print( - f" Precision: {precision.get('mean', 0):.3f} (mean), " - f"{precision.get('median', 0):.3f} (median)" - ) - print(f" Total Chunks Evaluated: {e.get('total_chunks', 0)}") - - if "relation_accuracy" in acc: - r = acc["relation_accuracy"] - overall = r.get("overall_score", {}) - accuracy = r.get("accuracy", {}) - completeness = r.get("completeness", {}) - precision = r.get("precision", {}) - - print(" Relation Extraction Quality:") - print( - f" Overall Score: {overall.get('mean', 0):.3f} (mean), " - f"{overall.get('median', 0):.3f} (median)" - ) - print( - f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), " - f"{accuracy.get('median', 0):.3f} (median)" - ) - print( - f" Completeness: {completeness.get('mean', 0):.3f} (mean), " - f"{completeness.get('median', 0):.3f} (median)" - ) - print( - f" Precision: {precision.get('mean', 0):.3f} (mean), " - f"{precision.get('median', 0):.3f} (median)" - ) - print(f" Total Chunks Evaluated: {r.get('total_chunks', 0)}") - else: - print(f"\n[Accuracy] Error: {acc['error']}") - - -def _print_consistency_summary(cons): - """Print consistency evaluation summary.""" - if "error" not in cons: - print("\n[Consistency]") - print(f" Conflict Rate: {cons.get('conflict_rate', 0):.3f}") - print( - f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / " - f"{cons.get('total_entities', 0)}" - ) - entities_checked = cons.get("entities_checked", 0) - if entities_checked > 0: - print( - f" Entities Checked: {entities_checked} (entities with multiple sources)" - ) - conflicts = cons.get("conflicts", []) - if conflicts: - print(f" Total Conflicts Found: {len(conflicts)}") - # Show sample conflicts - sample_conflicts = conflicts[:3] - for conflict in sample_conflicts: - print( - f" - {conflict.get('entity_id', 'N/A')}: {conflict.get('conflict_type', 'N/A')} " - f"(severity: {conflict.get('conflict_severity', 0):.2f})" - ) - else: - print(f"\n[Consistency] Error: {cons['error']}") - - -def _print_structure_summary(struct): - """Print structural robustness evaluation summary.""" - if "error" not in struct: - print("\n[Structural Robustness]") - print(f" Total Nodes: {struct.get('total_nodes', 0)}") - print(f" Total Edges: {struct.get('total_edges', 0)}") - - thresholds = struct.get("thresholds", {}) - - # Noise Ratio - noise_check = thresholds.get("noise_ratio", {}) - noise_threshold = noise_check.get("threshold", "N/A") - noise_pass = noise_check.get("pass", False) - print( - f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} " - f"({'✓' if noise_pass else '✗'} < {noise_threshold})" - ) - - # Largest CC Ratio - lcc_check = thresholds.get("largest_cc_ratio", {}) - lcc_threshold = lcc_check.get("threshold", "N/A") - lcc_pass = lcc_check.get("pass", False) - print( - f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} " - f"({'✓' if lcc_pass else '✗'} > {lcc_threshold})" - ) - - # Avg Degree - avg_degree_check = thresholds.get("avg_degree", {}) - avg_degree_threshold = avg_degree_check.get("threshold", "N/A") - avg_degree_pass = avg_degree_check.get("pass", False) - # Format threshold for display (handle tuple case) - if isinstance(avg_degree_threshold, tuple): - threshold_str = f"{avg_degree_threshold[0]}-{avg_degree_threshold[1]}" - else: - threshold_str = str(avg_degree_threshold) - print( - f" Avg Degree: {struct.get('avg_degree', 0):.2f} " - f"({'✓' if avg_degree_pass else '✗'} {threshold_str})" - ) - - # Power Law R² - if struct.get("powerlaw_r2") is not None: - powerlaw_check = thresholds.get("powerlaw_r2", {}) - powerlaw_threshold = powerlaw_check.get("threshold", "N/A") - powerlaw_pass = powerlaw_check.get("pass", False) - print( - f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} " - f"({'✓' if powerlaw_pass else '✗'} > {powerlaw_threshold})" - ) - else: - print(f"\n[Structural Robustness] Error: {struct['error']}") - - -def _print_summary(results): - """Print evaluation summary.""" - print("\n" + "=" * 60) - print("KG Quality Evaluation Summary") - print("=" * 60) - - if "accuracy" in results: - _print_accuracy_summary(results["accuracy"]) - if "consistency" in results: - _print_consistency_summary(results["consistency"]) - if "structure" in results: - _print_structure_summary(results["structure"]) - - print("\n" + "=" * 60) - - -def main(): - """Main function to run KG quality evaluation.""" - parser = argparse.ArgumentParser( - description="Evaluate knowledge graph quality", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Basic evaluation - python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache - - # Custom output - python -m graphgen.operators.evaluate_kg.evaluate_kg \\ - --working_dir cache \\ - --output cache/kg_evaluation.json - - # Specify backends - python -m graphgen.operators.evaluate_kg.evaluate_kg \\ - --working_dir cache \\ - --graph_backend networkx \\ - --kv_backend json_kv - """, - ) - - parser.add_argument( - "--working_dir", - type=str, - default="cache", - help="Working directory containing graph and chunk storage (default: cache)", - ) - parser.add_argument( - "--graph_backend", - type=str, - default="kuzu", - choices=["kuzu", "networkx"], - help="Graph storage backend (default: kuzu)", - ) - parser.add_argument( - "--kv_backend", - type=str, - default="rocksdb", - choices=["rocksdb", "json_kv"], - help="KV storage backend (default: rocksdb)", - ) - parser.add_argument( - "--max_concurrent", - type=int, - default=10, - help="Maximum concurrent LLM requests (default: 10)", - ) - parser.add_argument( - "--output", - type=str, - default=None, - help="Output file path for evaluation results (default: working_dir/kg_evaluation.json)", - ) - parser.add_argument( - "--accuracy_only", - action="store_true", - help="Only run accuracy evaluation", - ) - parser.add_argument( - "--consistency_only", - action="store_true", - help="Only run consistency evaluation", - ) - parser.add_argument( - "--structure_only", - action="store_true", - help="Only run structural robustness evaluation", - ) - - args = parser.parse_args() - - # Set up logging - log_dir = Path(args.working_dir) / "logs" - log_dir.mkdir(parents=True, exist_ok=True) - default_logger = set_logger(str(log_dir / "evaluate_kg.log"), name="evaluate_kg") - CURRENT_LOGGER_VAR.set(default_logger) - - # Determine output path - if args.output is None: - output_path = Path(args.working_dir) / "kg_evaluation.json" - else: - output_path = Path(args.output) - - logger.info("Starting KG quality evaluation...") - logger.info(f"Working directory: {args.working_dir}") - logger.info(f"Graph backend: {args.graph_backend}") - logger.info(f"KV backend: {args.kv_backend}") - - try: - evaluator = KGQualityEvaluator( - working_dir=args.working_dir, - graph_backend=args.graph_backend, - kv_backend=args.kv_backend, - max_concurrent=args.max_concurrent, - ) - except Exception as e: - logger.error(f"Failed to initialize evaluator: {e}") - raise - - # Run evaluation - try: - results = _run_evaluation(evaluator, args) - - # Save results - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "w", encoding="utf-8") as f: - json.dump(results, f, indent=2, ensure_ascii=False) - - logger.info(f"Evaluation completed. Results saved to: {output_path}") - - # Print summary - _print_summary(results) - - except Exception as e: - logger.error(f"Evaluation failed: {e}", exc_info=True) - raise - - -if __name__ == "__main__": - main() From 86fa173d991c2d92db4c4a04c991f8deac2f4f24 Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Fri, 26 Dec 2025 00:35:59 +0800 Subject: [PATCH 21/29] wip: perf evaluate_service --- examples/evaluate/evaluate_kg/evaluate_kg.sh | 2 +- examples/evaluate/evaluate_qa/evaluate.sh | 2 -- examples/evaluate/evaluate_qa/evaluate_qa.sh | 2 ++ .../evaluate_qa/qa_evaluation_config.yaml | 6 ++-- graphgen/models/__init__.py | 1 - .../models/evaluator/kg/accuracy_evaluator.py | 7 +++-- .../models/evaluator/qa/length_evaluator.py | 6 ++-- graphgen/operators/__init__.py | 3 ++ graphgen/operators/evaluate/__init__.py | 1 + .../operators/evaluate/evaluate_service.py | 30 ++++++++++++------- graphgen/templates/evaluation/__init__.py | 2 +- graphgen/templates/evaluation/kg/__init__.py | 1 + .../evaluation/kg/consistency_evaluation.py | 6 ++++ graphgen/utils/help_nltk.py | 9 ++++++ 14 files changed, 56 insertions(+), 22 deletions(-) delete mode 100644 examples/evaluate/evaluate_qa/evaluate.sh create mode 100644 examples/evaluate/evaluate_qa/evaluate_qa.sh diff --git a/examples/evaluate/evaluate_kg/evaluate_kg.sh b/examples/evaluate/evaluate_kg/evaluate_kg.sh index ac40b0f6..2bf2f37e 100644 --- a/examples/evaluate/evaluate_kg/evaluate_kg.sh +++ b/examples/evaluate/evaluate_kg/evaluate_kg.sh @@ -1,2 +1,2 @@ python3 -m graphgen.run \ ---config_file examples/evaluate/evaluate_kg/evaluate_kg_config.yaml \ No newline at end of file +--config_file examples/evaluate/evaluate_kg/kg_evaluation_config.yaml \ No newline at end of file diff --git a/examples/evaluate/evaluate_qa/evaluate.sh b/examples/evaluate/evaluate_qa/evaluate.sh deleted file mode 100644 index 8c637d1f..00000000 --- a/examples/evaluate/evaluate_qa/evaluate.sh +++ /dev/null @@ -1,2 +0,0 @@ -python3 -m graphgen.run \ ---config_file examples/evaluate/evaluate_qa/evaluate_qa_config.yaml \ No newline at end of file diff --git a/examples/evaluate/evaluate_qa/evaluate_qa.sh b/examples/evaluate/evaluate_qa/evaluate_qa.sh new file mode 100644 index 00000000..5bfe392c --- /dev/null +++ b/examples/evaluate/evaluate_qa/evaluate_qa.sh @@ -0,0 +1,2 @@ +python3 -m graphgen.run \ +--config_file examples/evaluate/evaluate_qa/qa_evaluation_config.yaml \ No newline at end of file diff --git a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml index 45e9d3a7..3e875143 100644 --- a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml +++ b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml @@ -92,5 +92,7 @@ nodes: metrics: - qa_length - qa_mtld - - qa_reward_score - - qa_uni_score + # - qa_reward_score + # - qa_uni_score + mtld_params: + threshold: 0.7 diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 127a4314..86a02bb9 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,5 +1,4 @@ from .evaluator import ( - KGQualityEvaluator, LengthEvaluator, MTLDEvaluator, RewardEvaluator, diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py index f9d2e405..245b229c 100644 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -152,7 +152,9 @@ async def _evaluate_entity_extraction( ) -> Dict[str, Any]: """Use LLM to evaluate entity extraction quality.""" try: - prompt = ENTITY_EVALUATION_PROMPT.format( + lang = detect_main_language(chunk.content) + + prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format( chunk_content=chunk.content, extracted_entities=json.dumps( extracted_entities, ensure_ascii=False, indent=2 @@ -225,7 +227,8 @@ async def _evaluate_relation_extraction( ) -> Dict[str, Any]: """Use LLM to evaluate relation extraction quality.""" try: - prompt = RELATION_EVALUATION_PROMPT.format( + lang = detect_main_language(chunk.content) + prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format( chunk_content=chunk.content, extracted_relations=json.dumps( extracted_relations, ensure_ascii=False, indent=2 diff --git a/graphgen/models/evaluator/qa/length_evaluator.py b/graphgen/models/evaluator/qa/length_evaluator.py index 3af66380..72719577 100644 --- a/graphgen/models/evaluator/qa/length_evaluator.py +++ b/graphgen/models/evaluator/qa/length_evaluator.py @@ -1,10 +1,12 @@ + +import os from graphgen.bases import BaseEvaluator, QAPair from graphgen.models.tokenizer import Tokenizer class LengthEvaluator(BaseEvaluator): - def __init__(self, tokenizer: Tokenizer): - self.tokenizer = tokenizer + def __init__(self): + self.tokenizer: Tokenizer = Tokenizer(os.environ["TOKENIZER_MODEL"] or "cl100k_base") def evaluate(self, pair: QAPair) -> float: """ diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 5bb1261a..21a7e554 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -7,6 +7,8 @@ from .quiz import QuizService from .read import read from .search import SearchService +from .evaluate import EvaluateService + operators = { "read": read, @@ -18,4 +20,5 @@ "search": SearchService, "partition": PartitionService, "generate": GenerateService, + "evaluate": EvaluateService, } diff --git a/graphgen/operators/evaluate/__init__.py b/graphgen/operators/evaluate/__init__.py index e69de29b..aa862ee1 100644 --- a/graphgen/operators/evaluate/__init__.py +++ b/graphgen/operators/evaluate/__init__.py @@ -0,0 +1 @@ +from .evaluate_service import EvaluateService diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index 6d2fe89a..da5d2c8a 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -10,27 +10,35 @@ class EvaluateService(BaseOperator): 2. QA Quality Evaluation """ - def __init__(self, working_dir: str = "cache", metrics: list[str] = None): + def __init__(self, working_dir: str = "cache", metrics: list[str] = None, **kwargs): super().__init__(working_dir=working_dir, op_name="evaluate_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.metrics = metrics - - self.evaluators = { - "xxx": "xxxEvaluator" - } - - self.graph_storage = init_storage( - xx, xx, xx - ) + self.kwargs = kwargs + self.evaluators = {} def _init_evaluators(self): for metric in self.metrics: - + if metric == "qa_length": + from graphgen.models import LengthEvaluator + self.evaluators[metric] = LengthEvaluator() + elif metric == "qa_mtld": + from graphgen.models import MTLDEvaluator + self.evaluators[metric] = MTLDEvaluator(self.kwargs.get("mtld_params", {})) + elif metric == "qa_reward_score": + from graphgen.models import RewardEvaluator + self.evaluators[metric] = RewardEvaluator(self.kwargs.get("reward_params", {})) + elif metric == "qa_uni_score": + from graphgen.models import UniEvaluator + self.evaluators[metric] = UniEvaluator(self.kwargs.get("uni_params", {})) + else: + raise ValueError(f"Unknown metric: {metric}") def process(self, batch: pd.DataFrame) -> pd.DataFrame: items = batch.to_dict(orient="records") return pd.DataFrame(self.evaluate(items)) def evaluate(self, items: list[dict]) -> list[dict]: - # 用evaluators 评估 items + print(items) pass + diff --git a/graphgen/templates/evaluation/__init__.py b/graphgen/templates/evaluation/__init__.py index 93761e85..7c2676a5 100644 --- a/graphgen/templates/evaluation/__init__.py +++ b/graphgen/templates/evaluation/__init__.py @@ -1 +1 @@ -from .kg import ACCURACY_EVALUATION_PROMPT +from .kg import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT diff --git a/graphgen/templates/evaluation/kg/__init__.py b/graphgen/templates/evaluation/kg/__init__.py index 9c500d1f..db8edce6 100644 --- a/graphgen/templates/evaluation/kg/__init__.py +++ b/graphgen/templates/evaluation/kg/__init__.py @@ -1 +1,2 @@ from .accuracy_evaluation import ACCURACY_EVALUATION_PROMPT +from .consistency_evaluation import CONSISTENCY_EVALUATION_PROMPT diff --git a/graphgen/templates/evaluation/kg/consistency_evaluation.py b/graphgen/templates/evaluation/kg/consistency_evaluation.py index b8cf2f8d..b540e528 100644 --- a/graphgen/templates/evaluation/kg/consistency_evaluation.py +++ b/graphgen/templates/evaluation/kg/consistency_evaluation.py @@ -95,3 +95,9 @@ "description": "<实体描述>" }} """ + +CONSISTENCY_EVALUATION_PROMPT = { + "en": "", + "zh": "" +} + diff --git a/graphgen/utils/help_nltk.py b/graphgen/utils/help_nltk.py index ab70236f..07d39ef1 100644 --- a/graphgen/utils/help_nltk.py +++ b/graphgen/utils/help_nltk.py @@ -1,7 +1,16 @@ from functools import lru_cache import os from typing import Dict, List, Final, Optional +import warnings import nltk + +warnings.filterwarnings( + "ignore", + category=UserWarning, + module="jieba\._compat" +) + + import jieba class NLTKHelper: From 06fc6e3a12102e444ddaaa7028784cfc5eed1ea6 Mon Sep 17 00:00:00 2001 From: chenzihong <522023320011@smail.nju.edu.cn> Date: Fri, 26 Dec 2025 01:42:25 +0800 Subject: [PATCH 22/29] perf: perf evaluate_service --- .../operators/evaluate/evaluate_service.py | 52 +++++++++++++++++-- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index da5d2c8a..82825263 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -1,7 +1,9 @@ +from typing import Any import pandas as pd -from graphgen.bases import BaseLLMWrapper, BaseOperator +from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair from graphgen.common import init_llm +from graphgen.utils import run_concurrent class EvaluateService(BaseOperator): @@ -38,7 +40,49 @@ def process(self, batch: pd.DataFrame) -> pd.DataFrame: items = batch.to_dict(orient="records") return pd.DataFrame(self.evaluate(items)) - def evaluate(self, items: list[dict]) -> list[dict]: - print(items) - pass + async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: + try: + qa_pair = QAPair( + question=str(item.get("question", "")), + answer=str(item.get("answer", "")) + ) + if not qa_pair.question or not qa_pair.answer: + self.logger.error("Empty question or answer, skipping.") + return {} + except Exception as e: + self.logger.error( + "Error in QAPair creation: %s", + str(e) + ) + return {} + for metric, evaluator in self.evaluators.items(): + try: + score = evaluator.evaluate(qa_pair) + if isinstance(score, dict): + for sub_metric, sub_score in score.items(): + item[f"{metric}_{sub_metric}"] = float(sub_score) + else: + item[metric] = float(score) + except Exception as e: + self.logger.error( + "Error in %s evaluation: %s", + metric, + str(e) + ) + item[metric] = None + + def evaluate(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not items: + return [] + + results = run_concurrent( + self._process_single, + items, + desc="Evaluating items", + unit="item", + ) + + results = [item for item in results if item] + + return results From f9d6dc38603cc9cbe9a98cc16b22feaee4979b6f Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 26 Dec 2025 12:02:49 +0800 Subject: [PATCH 23/29] fix: fix output node --- graphgen/engine.py | 2 + .../models/evaluator/qa/mtld_evaluator.py | 4 +- .../operators/evaluate/evaluate_service.py | 51 ++++++++++++++----- graphgen/run.py | 9 ++-- 4 files changed, 47 insertions(+), 19 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index 26bcff58..7b871a61 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -271,6 +271,8 @@ def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]: for node in sorted_nodes: self._execute_node(node, initial_ds) + if getattr(node, "save_output", False): + self.datasets[node.id] = self.datasets[node.id].materialize() output_nodes = [n for n in sorted_nodes if getattr(n, "save_output", False)] return {node.id: self.datasets[node.id] for node in output_nodes} diff --git a/graphgen/models/evaluator/qa/mtld_evaluator.py b/graphgen/models/evaluator/qa/mtld_evaluator.py index 3cd43b75..e4e18d32 100644 --- a/graphgen/models/evaluator/qa/mtld_evaluator.py +++ b/graphgen/models/evaluator/qa/mtld_evaluator.py @@ -11,8 +11,8 @@ class MTLDEvaluator(BaseEvaluator): def __init__(self, threshold: float = 0.72): self.nltk_helper = NLTKHelper() - self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("english")) - self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("chinese")) + self.stopwords_en: Set[str] = set(self.nltk_helper.get_stopwords("en")) + self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh")) self.threshold = threshold def evaluate(self, pair: QAPair) -> float: diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index 82825263..649e44f5 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -1,4 +1,5 @@ from typing import Any + import pandas as pd from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair @@ -18,21 +19,32 @@ def __init__(self, working_dir: str = "cache", metrics: list[str] = None, **kwar self.metrics = metrics self.kwargs = kwargs self.evaluators = {} + self._init_evaluators() def _init_evaluators(self): for metric in self.metrics: if metric == "qa_length": from graphgen.models import LengthEvaluator + self.evaluators[metric] = LengthEvaluator() elif metric == "qa_mtld": from graphgen.models import MTLDEvaluator - self.evaluators[metric] = MTLDEvaluator(self.kwargs.get("mtld_params", {})) + + self.evaluators[metric] = MTLDEvaluator( + **self.kwargs.get("mtld_params", {}) + ) elif metric == "qa_reward_score": from graphgen.models import RewardEvaluator - self.evaluators[metric] = RewardEvaluator(self.kwargs.get("reward_params", {})) + + self.evaluators[metric] = RewardEvaluator( + **self.kwargs.get("reward_params", {}) + ) elif metric == "qa_uni_score": from graphgen.models import UniEvaluator - self.evaluators[metric] = UniEvaluator(self.kwargs.get("uni_params", {})) + + self.evaluators[metric] = UniEvaluator( + **self.kwargs.get("uni_params", {}) + ) else: raise ValueError(f"Unknown metric: {metric}") @@ -44,16 +56,13 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: try: qa_pair = QAPair( question=str(item.get("question", "")), - answer=str(item.get("answer", "")) + answer=str(item.get("answer", "")), ) if not qa_pair.question or not qa_pair.answer: self.logger.error("Empty question or answer, skipping.") return {} except Exception as e: - self.logger.error( - "Error in QAPair creation: %s", - str(e) - ) + self.logger.error("Error in QAPair creation: %s", str(e)) return {} for metric, evaluator in self.evaluators.items(): @@ -65,17 +74,33 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: else: item[metric] = float(score) except Exception as e: - self.logger.error( - "Error in %s evaluation: %s", - metric, - str(e) - ) + self.logger.error("Error in %s evaluation: %s", metric, str(e)) item[metric] = None + return item + + @staticmethod + def transform_messages_format(items: list[dict]) -> list[dict]: + """ + Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...] + """ + transformed = [] + for item in items: + messages = item.get("messages", []) + question = next( + (m["content"] for m in messages if m.get("role") == "user"), "" + ) + answer = next( + (m["content"] for m in messages if m.get("role") == "assistant"), "" + ) + + transformed.append({"question": question, "answer": answer}) + return transformed def evaluate(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: if not items: return [] + items = self.transform_messages_format(items) results = run_concurrent( self._process_single, items, diff --git a/graphgen/run.py b/graphgen/run.py index a1b65364..d3d47cd3 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -91,10 +91,11 @@ def main(): results = engine.execute(ds) for node_id, dataset in results.items(): - output_path = os.path.join(output_path, f"{node_id}") - os.makedirs(output_path, exist_ok=True) + logger.info("Saving results for node %s", node_id) + node_output_path = os.path.join(output_path, f"{node_id}") + os.makedirs(node_output_path, exist_ok=True) dataset.write_json( - output_path, + node_output_path, filename_provider=NodeFilenameProvider(node_id), pandas_json_args_fn=lambda: { "force_ascii": False, @@ -102,7 +103,7 @@ def main(): "lines": True, }, ) - logger.info("Node %s results saved to %s", node_id, output_path) + logger.info("Node %s results saved to %s", node_id, node_output_path) save_config(os.path.join(output_path, "config.yaml"), config) logger.info("GraphGen completed successfully. Data saved to %s", output_path) From 4d022fb461b8d182f5d447908243effdee2d1c2b Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Fri, 26 Dec 2025 11:46:23 +0800 Subject: [PATCH 24/29] merge --- graphgen/operators/__init__.py | 1 + graphgen/operators/evaluate/__init__.py | 2 ++ graphgen/operators/evaluate/evaluate_service.py | 4 ---- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 21a7e554..1fa47c51 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,5 +1,6 @@ from .build_kg import BuildKGService from .chunk import ChunkService +from .evaluate import EvaluateService from .extract import ExtractService from .generate import GenerateService from .judge import JudgeService diff --git a/graphgen/operators/evaluate/__init__.py b/graphgen/operators/evaluate/__init__.py index aa862ee1..060c68d6 100644 --- a/graphgen/operators/evaluate/__init__.py +++ b/graphgen/operators/evaluate/__init__.py @@ -1 +1,3 @@ from .evaluate_service import EvaluateService + +__all__ = ["EvaluateService"] diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index 649e44f5..c76c4271 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -48,10 +48,6 @@ def _init_evaluators(self): else: raise ValueError(f"Unknown metric: {metric}") - def process(self, batch: pd.DataFrame) -> pd.DataFrame: - items = batch.to_dict(orient="records") - return pd.DataFrame(self.evaluate(items)) - async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: try: qa_pair = QAPair( From 084cb084db5fcf81b644bcd19b062b5eda240c6b Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Fri, 26 Dec 2025 12:19:10 +0800 Subject: [PATCH 25/29] feat: add KGQualityEvaluator and integrate into EvaluateService for KG evaluations --- graphgen/models/__init__.py | 1 + graphgen/models/evaluator/__init__.py | 7 +- graphgen/models/evaluator/kg/__init__.py | 2 + .../evaluator/kg/kg_quality_evaluator.py | 79 ++++++++++ .../operators/evaluate/evaluate_service.py | 142 +++++++++++++++--- 5 files changed, 213 insertions(+), 18 deletions(-) create mode 100644 graphgen/models/evaluator/kg/kg_quality_evaluator.py diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 86a02bb9..127a4314 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,4 +1,5 @@ from .evaluator import ( + KGQualityEvaluator, LengthEvaluator, MTLDEvaluator, RewardEvaluator, diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py index b6121648..79192237 100644 --- a/graphgen/models/evaluator/__init__.py +++ b/graphgen/models/evaluator/__init__.py @@ -1,2 +1,7 @@ from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator -from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator +from .kg import ( + AccuracyEvaluator, + ConsistencyEvaluator, + KGQualityEvaluator, + StructureEvaluator, +) diff --git a/graphgen/models/evaluator/kg/__init__.py b/graphgen/models/evaluator/kg/__init__.py index 375cbc50..dc83b6d3 100644 --- a/graphgen/models/evaluator/kg/__init__.py +++ b/graphgen/models/evaluator/kg/__init__.py @@ -9,10 +9,12 @@ from .accuracy_evaluator import AccuracyEvaluator from .consistency_evaluator import ConsistencyEvaluator +from .kg_quality_evaluator import KGQualityEvaluator from .structure_evaluator import StructureEvaluator __all__ = [ "AccuracyEvaluator", "ConsistencyEvaluator", + "KGQualityEvaluator", "StructureEvaluator", ] diff --git a/graphgen/models/evaluator/kg/kg_quality_evaluator.py b/graphgen/models/evaluator/kg/kg_quality_evaluator.py new file mode 100644 index 00000000..3b49b070 --- /dev/null +++ b/graphgen/models/evaluator/kg/kg_quality_evaluator.py @@ -0,0 +1,79 @@ +from typing import Any, Dict + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.common import init_llm, init_storage +from graphgen.models.evaluator.kg.accuracy_evaluator import AccuracyEvaluator +from graphgen.models.evaluator.kg.consistency_evaluator import ConsistencyEvaluator +from graphgen.models.evaluator.kg.structure_evaluator import StructureEvaluator +from graphgen.utils import logger + + +class KGQualityEvaluator: + def __init__( + self, + working_dir: str = "cache", + graph_backend: str = "kuzu", + kv_backend: str = "rocksdb", + **kwargs + ): + # Initialize storage + self.graph_storage: BaseGraphStorage = init_storage( + backend=graph_backend, working_dir=working_dir, namespace="graph" + ) + self.chunk_storage: BaseKVStorage = init_storage( + backend=kv_backend, working_dir=working_dir, namespace="chunk" + ) + + # Initialize LLM client + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + + # Initialize individual evaluators + self.accuracy_evaluator = AccuracyEvaluator( + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + ) + + self.consistency_evaluator = ConsistencyEvaluator( + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + ) + + # Structure evaluator doesn't need chunk_storage or llm_client + structure_params = kwargs.get("structure_params", {}) + self.structure_evaluator = StructureEvaluator( + graph_storage=self.graph_storage, + **structure_params + ) + + logger.info("KGQualityEvaluator initialized") + + def evaluate_accuracy(self) -> Dict[str, Any]: + logger.info("Running accuracy evaluation...") + results = self.accuracy_evaluator.evaluate() + logger.info("Accuracy evaluation completed") + return results + + def evaluate_consistency(self) -> Dict[str, Any]: + logger.info("Running consistency evaluation...") + results = self.consistency_evaluator.evaluate() + logger.info("Consistency evaluation completed") + return results + + def evaluate_structure(self) -> Dict[str, Any]: + logger.info("Running structural robustness evaluation...") + results = self.structure_evaluator.evaluate() + logger.info("Structural robustness evaluation completed") + return results + + def evaluate_all(self) -> Dict[str, Any]: + logger.info("Running all KG evaluations...") + results = { + "accuracy": self.evaluate_accuracy(), + "consistency": self.evaluate_consistency(), + "structure": self.evaluate_structure(), + } + logger.info("All KG evaluations completed") + return results + diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index c76c4271..cd4f1c78 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -1,10 +1,11 @@ -from typing import Any +from typing import Any, Dict, List, Union import pandas as pd from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair from graphgen.common import init_llm -from graphgen.utils import run_concurrent +from graphgen.models import KGQualityEvaluator +from graphgen.utils import logger, run_concurrent class EvaluateService(BaseOperator): @@ -13,40 +14,67 @@ class EvaluateService(BaseOperator): 2. QA Quality Evaluation """ - def __init__(self, working_dir: str = "cache", metrics: list[str] = None, **kwargs): + def __init__( + self, + working_dir: str = "cache", + metrics: list[str] = None, + graph_backend: str = "kuzu", + kv_backend: str = "rocksdb", + **kwargs + ): super().__init__(working_dir=working_dir, op_name="evaluate_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") - self.metrics = metrics + self.metrics = metrics or [] self.kwargs = kwargs - self.evaluators = {} + self.graph_backend = graph_backend + self.kv_backend = kv_backend + + # Separate QA and KG metrics + self.qa_metrics = [m for m in self.metrics if m.startswith("qa_")] + self.kg_metrics = [m for m in self.metrics if m.startswith("kg_")] + + # Initialize evaluators + self.qa_evaluators = {} + self.kg_evaluator = None + self._init_evaluators() def _init_evaluators(self): - for metric in self.metrics: + """Initialize QA and KG evaluators based on metrics.""" + # Initialize QA evaluators + for metric in self.qa_metrics: if metric == "qa_length": from graphgen.models import LengthEvaluator - self.evaluators[metric] = LengthEvaluator() + self.qa_evaluators[metric] = LengthEvaluator() elif metric == "qa_mtld": from graphgen.models import MTLDEvaluator - - self.evaluators[metric] = MTLDEvaluator( + self.qa_evaluators[metric] = MTLDEvaluator( **self.kwargs.get("mtld_params", {}) ) elif metric == "qa_reward_score": from graphgen.models import RewardEvaluator - - self.evaluators[metric] = RewardEvaluator( + self.qa_evaluators[metric] = RewardEvaluator( **self.kwargs.get("reward_params", {}) ) elif metric == "qa_uni_score": from graphgen.models import UniEvaluator - - self.evaluators[metric] = UniEvaluator( + self.qa_evaluators[metric] = UniEvaluator( **self.kwargs.get("uni_params", {}) ) else: - raise ValueError(f"Unknown metric: {metric}") + raise ValueError(f"Unknown QA metric: {metric}") + + # Initialize KG evaluator if KG metrics are specified + if self.kg_metrics: + kg_params = self.kwargs.get("kg_params", {}) + self.kg_evaluator = KGQualityEvaluator( + working_dir=self.working_dir, + graph_backend=self.graph_backend, + kv_backend=self.kv_backend, + **kg_params + ) + logger.info("KG evaluator initialized") async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: try: @@ -61,7 +89,7 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: self.logger.error("Error in QAPair creation: %s", str(e)) return {} - for metric, evaluator in self.evaluators.items(): + for metric, evaluator in self.qa_evaluators.items(): try: score = evaluator.evaluate(qa_pair) if isinstance(score, dict): @@ -92,18 +120,98 @@ def transform_messages_format(items: list[dict]) -> list[dict]: transformed.append({"question": question, "answer": answer}) return transformed - def evaluate(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: + def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: if not items: return [] + if not self.qa_evaluators: + logger.warning("No QA evaluators initialized, skipping QA evaluation") + return [] + items = self.transform_messages_format(items) results = run_concurrent( self._process_single, items, - desc="Evaluating items", + desc="Evaluating QA items", unit="item", ) results = [item for item in results if item] + return results + def _evaluate_kg(self) -> Dict[str, Any]: + if not self.kg_evaluator: + logger.warning("No KG evaluator initialized, skipping KG evaluation") + return {} + + results = {} + + # Map metric names to evaluation functions + kg_metric_map = { + "kg_accuracy": self.kg_evaluator.evaluate_accuracy, + "kg_consistency": self.kg_evaluator.evaluate_consistency, + "kg_structure": self.kg_evaluator.evaluate_structure, + } + + # Run KG evaluations based on metrics + for metric in self.kg_metrics: + if metric in kg_metric_map: + logger.info("Running %s evaluation...", metric) + metric_key = metric.replace("kg_", "") # Remove "kg_" prefix + try: + results[metric_key] = kg_metric_map[metric]() + except Exception as e: + logger.error("Error in %s evaluation: %s", metric, str(e)) + results[metric_key] = {"error": str(e)} + else: + logger.warning("Unknown KG metric: %s, skipping", metric) + + # If no valid metrics were found, run all evaluations + if not results: + logger.info("No valid KG metrics found, running all evaluations") + results = self.kg_evaluator.evaluate_all() + return results + + def evaluate( + self, items: list[dict[str, Any]] = None + ) -> Union[List[Dict[str, Any]], Dict[str, Any]]: + # Determine evaluation type + has_qa_metrics = len(self.qa_metrics) > 0 + has_kg_metrics = len(self.kg_metrics) > 0 + + # If items provided and QA metrics exist, do QA evaluation + if items is not None and has_qa_metrics: + return self._evaluate_qa(items) + + # If KG metrics exist, do KG evaluation + if has_kg_metrics: + return self._evaluate_kg() + + # If no metrics specified, try to infer from context + if items is not None: + logger.warning("No QA metrics specified but items provided, skipping evaluation") + return [] + else: + logger.warning("No metrics specified, skipping evaluation") + return {} + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + has_qa_metrics = len(self.qa_metrics) > 0 + has_kg_metrics = len(self.kg_metrics) > 0 + + # QA evaluation: process batch items + if has_qa_metrics: + items = batch.to_dict(orient="records") + results = self._evaluate_qa(items) + return pd.DataFrame(results) + + # KG evaluation: evaluate from storage + if has_kg_metrics: + results = self._evaluate_kg() + # Convert dict to DataFrame (single row) + return pd.DataFrame([results]) + + # No metrics specified + logger.warning("No metrics specified, returning empty DataFrame") + return pd.DataFrame() From 98968e668df77551d51637e3000e3a5957be1cd4 Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Fri, 26 Dec 2025 15:45:26 +0800 Subject: [PATCH 26/29] refactor: remove KGQualityEvaluator and restructure KG evaluation integration --- graphgen/models/__init__.py | 1 - graphgen/models/evaluator/__init__.py | 1 - graphgen/models/evaluator/kg/README.md | 5 +- graphgen/models/evaluator/kg/__init__.py | 2 - .../evaluator/kg/kg_quality_evaluator.py | 79 -------------- graphgen/operators/evaluate/evaluate_kg.py | 100 ++++++++++-------- .../operators/evaluate/evaluate_service.py | 35 +++--- 7 files changed, 80 insertions(+), 143 deletions(-) delete mode 100644 graphgen/models/evaluator/kg/kg_quality_evaluator.py diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 127a4314..86a02bb9 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,5 +1,4 @@ from .evaluator import ( - KGQualityEvaluator, LengthEvaluator, MTLDEvaluator, RewardEvaluator, diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py index 79192237..4562a048 100644 --- a/graphgen/models/evaluator/__init__.py +++ b/graphgen/models/evaluator/__init__.py @@ -2,6 +2,5 @@ from .kg import ( AccuracyEvaluator, ConsistencyEvaluator, - KGQualityEvaluator, StructureEvaluator, ) diff --git a/graphgen/models/evaluator/kg/README.md b/graphgen/models/evaluator/kg/README.md index 833e9ad6..10e26f6b 100644 --- a/graphgen/models/evaluator/kg/README.md +++ b/graphgen/models/evaluator/kg/README.md @@ -4,12 +4,13 @@ This module provides comprehensive quality evaluation for knowledge graphs built ## Module Structure -The evaluation functionality has been split into modular components: +The evaluation functionality is organized into modular components: - **`accuracy_evaluator.py`**: Entity/relation extraction quality evaluation using LLM-as-a-Judge - **`consistency_evaluator.py`**: Attribute value conflict detection - **`structure_evaluator.py`**: Graph structural robustness metrics -- **`kg_quality_evaluator.py`**: Main evaluator class that integrates all modules + +The evaluation components are integrated in `graphgen/operators/evaluate/evaluate_kg.py`, which provides functions to create and use these evaluators. ## Features diff --git a/graphgen/models/evaluator/kg/__init__.py b/graphgen/models/evaluator/kg/__init__.py index dc83b6d3..375cbc50 100644 --- a/graphgen/models/evaluator/kg/__init__.py +++ b/graphgen/models/evaluator/kg/__init__.py @@ -9,12 +9,10 @@ from .accuracy_evaluator import AccuracyEvaluator from .consistency_evaluator import ConsistencyEvaluator -from .kg_quality_evaluator import KGQualityEvaluator from .structure_evaluator import StructureEvaluator __all__ = [ "AccuracyEvaluator", "ConsistencyEvaluator", - "KGQualityEvaluator", "StructureEvaluator", ] diff --git a/graphgen/models/evaluator/kg/kg_quality_evaluator.py b/graphgen/models/evaluator/kg/kg_quality_evaluator.py deleted file mode 100644 index 3b49b070..00000000 --- a/graphgen/models/evaluator/kg/kg_quality_evaluator.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any, Dict - -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper -from graphgen.common import init_llm, init_storage -from graphgen.models.evaluator.kg.accuracy_evaluator import AccuracyEvaluator -from graphgen.models.evaluator.kg.consistency_evaluator import ConsistencyEvaluator -from graphgen.models.evaluator.kg.structure_evaluator import StructureEvaluator -from graphgen.utils import logger - - -class KGQualityEvaluator: - def __init__( - self, - working_dir: str = "cache", - graph_backend: str = "kuzu", - kv_backend: str = "rocksdb", - **kwargs - ): - # Initialize storage - self.graph_storage: BaseGraphStorage = init_storage( - backend=graph_backend, working_dir=working_dir, namespace="graph" - ) - self.chunk_storage: BaseKVStorage = init_storage( - backend=kv_backend, working_dir=working_dir, namespace="chunk" - ) - - # Initialize LLM client - self.llm_client: BaseLLMWrapper = init_llm("synthesizer") - - # Initialize individual evaluators - self.accuracy_evaluator = AccuracyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - ) - - self.consistency_evaluator = ConsistencyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - ) - - # Structure evaluator doesn't need chunk_storage or llm_client - structure_params = kwargs.get("structure_params", {}) - self.structure_evaluator = StructureEvaluator( - graph_storage=self.graph_storage, - **structure_params - ) - - logger.info("KGQualityEvaluator initialized") - - def evaluate_accuracy(self) -> Dict[str, Any]: - logger.info("Running accuracy evaluation...") - results = self.accuracy_evaluator.evaluate() - logger.info("Accuracy evaluation completed") - return results - - def evaluate_consistency(self) -> Dict[str, Any]: - logger.info("Running consistency evaluation...") - results = self.consistency_evaluator.evaluate() - logger.info("Consistency evaluation completed") - return results - - def evaluate_structure(self) -> Dict[str, Any]: - logger.info("Running structural robustness evaluation...") - results = self.structure_evaluator.evaluate() - logger.info("Structural robustness evaluation completed") - return results - - def evaluate_all(self) -> Dict[str, Any]: - logger.info("Running all KG evaluations...") - results = { - "accuracy": self.evaluate_accuracy(), - "consistency": self.evaluate_consistency(), - "structure": self.evaluate_structure(), - } - logger.info("All KG evaluations completed") - return results - diff --git a/graphgen/operators/evaluate/evaluate_kg.py b/graphgen/operators/evaluate/evaluate_kg.py index a58617fe..5b520ea9 100644 --- a/graphgen/operators/evaluate/evaluate_kg.py +++ b/graphgen/operators/evaluate/evaluate_kg.py @@ -2,70 +2,86 @@ from dotenv import load_dotenv -from graphgen.models import KGQualityEvaluator +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.common import init_llm, init_storage +from graphgen.models.evaluator.kg.accuracy_evaluator import AccuracyEvaluator +from graphgen.models.evaluator.kg.consistency_evaluator import ConsistencyEvaluator +from graphgen.models.evaluator.kg.structure_evaluator import StructureEvaluator from graphgen.utils import logger # Load environment variables load_dotenv() -def evaluate_accuracy(evaluator: KGQualityEvaluator) -> Dict[str, Any]: - """Evaluate accuracy of entity and relation extraction. - - Args: - evaluator: KGQualityEvaluator instance +class KGEvaluators: + def __init__( + self, + working_dir: str = "cache", + graph_backend: str = "kuzu", + kv_backend: str = "rocksdb", + **kwargs + ): + # Initialize storage + self.graph_storage: BaseGraphStorage = init_storage( + backend=graph_backend, working_dir=working_dir, namespace="graph" + ) + self.chunk_storage: BaseKVStorage = init_storage( + backend=kv_backend, working_dir=working_dir, namespace="chunk" + ) - Returns: - Dictionary containing entity_accuracy and relation_accuracy metrics. - """ + # Initialize LLM client + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + + # Initialize individual evaluators + self.accuracy_evaluator = AccuracyEvaluator( + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + ) + + self.consistency_evaluator = ConsistencyEvaluator( + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, + ) + + # Structure evaluator doesn't need chunk_storage or llm_client + structure_params = kwargs.get("structure_params", {}) + self.structure_evaluator = StructureEvaluator( + graph_storage=self.graph_storage, + **structure_params + ) + + logger.info("KG evaluators initialized") + + +def evaluate_accuracy(evaluators: KGEvaluators) -> Dict[str, Any]: logger.info("Running accuracy evaluation...") - results = evaluator.evaluate_accuracy() + results = evaluators.accuracy_evaluator.evaluate() logger.info("Accuracy evaluation completed") return results -def evaluate_consistency(evaluator: KGQualityEvaluator) -> Dict[str, Any]: - """Evaluate consistency by detecting semantic conflicts. - - Args: - evaluator: KGQualityEvaluator instance - - Returns: - Dictionary containing consistency metrics including conflict_rate and conflicts. - """ +def evaluate_consistency(evaluators: KGEvaluators) -> Dict[str, Any]: logger.info("Running consistency evaluation...") - results = evaluator.evaluate_consistency() + results = evaluators.consistency_evaluator.evaluate() logger.info("Consistency evaluation completed") return results -def evaluate_structure(evaluator: KGQualityEvaluator) -> Dict[str, Any]: - """Evaluate structural robustness of the graph. - - Args: - evaluator: KGQualityEvaluator instance - - Returns: - Dictionary containing structural metrics including noise_ratio, largest_cc_ratio, etc. - """ +def evaluate_structure(evaluators: KGEvaluators) -> Dict[str, Any]: logger.info("Running structural robustness evaluation...") - results = evaluator.evaluate_structure() + results = evaluators.structure_evaluator.evaluate() logger.info("Structural robustness evaluation completed") return results -def evaluate_all(evaluator: KGQualityEvaluator) -> Dict[str, Any]: - """Run all evaluations (accuracy, consistency, structure). - - Args: - evaluator: KGQualityEvaluator instance - - Returns: - Dictionary containing all evaluation results with keys: accuracy, consistency, structure. - """ +def evaluate_all(evaluators: KGEvaluators) -> Dict[str, Any]: logger.info("Running all evaluations...") - results = evaluator.evaluate_all() + results = { + "accuracy": evaluate_accuracy(evaluators), + "consistency": evaluate_consistency(evaluators), + "structure": evaluate_structure(evaluators), + } logger.info("All evaluations completed") return results - - diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index cd4f1c78..c9d3c24f 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -1,10 +1,15 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import pandas as pd -from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair -from graphgen.common import init_llm -from graphgen.models import KGQualityEvaluator +from graphgen.bases import BaseOperator, QAPair +from graphgen.operators.evaluate.evaluate_kg import ( + KGEvaluators, + evaluate_accuracy, + evaluate_all, + evaluate_consistency, + evaluate_structure, +) from graphgen.utils import logger, run_concurrent @@ -23,7 +28,6 @@ def __init__( **kwargs ): super().__init__(working_dir=working_dir, op_name="evaluate_service") - self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.metrics = metrics or [] self.kwargs = kwargs self.graph_backend = graph_backend @@ -35,7 +39,7 @@ def __init__( # Initialize evaluators self.qa_evaluators = {} - self.kg_evaluator = None + self.kg_evaluators: Optional[KGEvaluators] = None self._init_evaluators() @@ -65,16 +69,15 @@ def _init_evaluators(self): else: raise ValueError(f"Unknown QA metric: {metric}") - # Initialize KG evaluator if KG metrics are specified + # Initialize KG evaluators if KG metrics are specified if self.kg_metrics: kg_params = self.kwargs.get("kg_params", {}) - self.kg_evaluator = KGQualityEvaluator( + self.kg_evaluators = KGEvaluators( working_dir=self.working_dir, graph_backend=self.graph_backend, kv_backend=self.kv_backend, **kg_params ) - logger.info("KG evaluator initialized") async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: try: @@ -140,17 +143,17 @@ def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: return results def _evaluate_kg(self) -> Dict[str, Any]: - if not self.kg_evaluator: - logger.warning("No KG evaluator initialized, skipping KG evaluation") + if not self.kg_evaluators: + logger.warning("No KG evaluators initialized, skipping KG evaluation") return {} results = {} # Map metric names to evaluation functions kg_metric_map = { - "kg_accuracy": self.kg_evaluator.evaluate_accuracy, - "kg_consistency": self.kg_evaluator.evaluate_consistency, - "kg_structure": self.kg_evaluator.evaluate_structure, + "kg_accuracy": evaluate_accuracy, + "kg_consistency": evaluate_consistency, + "kg_structure": evaluate_structure, } # Run KG evaluations based on metrics @@ -159,7 +162,7 @@ def _evaluate_kg(self) -> Dict[str, Any]: logger.info("Running %s evaluation...", metric) metric_key = metric.replace("kg_", "") # Remove "kg_" prefix try: - results[metric_key] = kg_metric_map[metric]() + results[metric_key] = kg_metric_map[metric](self.kg_evaluators) except Exception as e: logger.error("Error in %s evaluation: %s", metric, str(e)) results[metric_key] = {"error": str(e)} @@ -169,7 +172,7 @@ def _evaluate_kg(self) -> Dict[str, Any]: # If no valid metrics were found, run all evaluations if not results: logger.info("No valid KG metrics found, running all evaluations") - results = self.kg_evaluator.evaluate_all() + results = evaluate_all(self.kg_evaluators) return results From 71ebba29623abab15e77747f88345d6618e90407 Mon Sep 17 00:00:00 2001 From: CHERRY-ui8 <2693275288@qq.com> Date: Fri, 26 Dec 2025 15:57:23 +0800 Subject: [PATCH 27/29] pylints --- .../models/evaluator/qa/length_evaluator.py | 5 +-- graphgen/models/evaluator/qa/uni_evaluator.py | 4 +-- graphgen/operators/__init__.py | 1 - graphgen/operators/evaluate/evaluate_kg.py | 10 +++--- .../operators/evaluate/evaluate_service.py | 33 +++++++++---------- .../evaluation/kg/consistency_evaluation.py | 1 - graphgen/utils/help_nltk.py | 8 ++--- 7 files changed, 29 insertions(+), 33 deletions(-) diff --git a/graphgen/models/evaluator/qa/length_evaluator.py b/graphgen/models/evaluator/qa/length_evaluator.py index 72719577..266edfb6 100644 --- a/graphgen/models/evaluator/qa/length_evaluator.py +++ b/graphgen/models/evaluator/qa/length_evaluator.py @@ -5,8 +5,9 @@ class LengthEvaluator(BaseEvaluator): - def __init__(self): - self.tokenizer: Tokenizer = Tokenizer(os.environ["TOKENIZER_MODEL"] or "cl100k_base") + def __init__(self, tokenizer_name: str = None): + tokenizer_model = tokenizer_name or os.environ.get("TOKENIZER_MODEL", "cl100k_base") + self.tokenizer: Tokenizer = Tokenizer(tokenizer_model) def evaluate(self, pair: QAPair) -> float: """ diff --git a/graphgen/models/evaluator/qa/uni_evaluator.py b/graphgen/models/evaluator/qa/uni_evaluator.py index 9dc7ad2c..38406512 100644 --- a/graphgen/models/evaluator/qa/uni_evaluator.py +++ b/graphgen/models/evaluator/qa/uni_evaluator.py @@ -55,9 +55,9 @@ def _build_input_text(dimension: str, question: str, answer: str) -> str: """Construct input text for specified dimension.""" if dimension == "naturalness": return f"question: Is this a natural response? response: {answer}" - elif dimension == "coherence": + if dimension == "coherence": return f"question: Is this a coherent response? response: {answer} history: {question}" - elif dimension == "understandability": + if dimension == "understandability": return f"question: Is this an understandable response? response: {answer}" raise NotImplementedError(f"Unsupported dimension '{dimension}'") diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 1fa47c51..ab840cc5 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -8,7 +8,6 @@ from .quiz import QuizService from .read import read from .search import SearchService -from .evaluate import EvaluateService operators = { diff --git a/graphgen/operators/evaluate/evaluate_kg.py b/graphgen/operators/evaluate/evaluate_kg.py index 5b520ea9..fe3032dd 100644 --- a/graphgen/operators/evaluate/evaluate_kg.py +++ b/graphgen/operators/evaluate/evaluate_kg.py @@ -28,30 +28,30 @@ def __init__( self.chunk_storage: BaseKVStorage = init_storage( backend=kv_backend, working_dir=working_dir, namespace="chunk" ) - + # Initialize LLM client self.llm_client: BaseLLMWrapper = init_llm("synthesizer") - + # Initialize individual evaluators self.accuracy_evaluator = AccuracyEvaluator( graph_storage=self.graph_storage, chunk_storage=self.chunk_storage, llm_client=self.llm_client, ) - + self.consistency_evaluator = ConsistencyEvaluator( graph_storage=self.graph_storage, chunk_storage=self.chunk_storage, llm_client=self.llm_client, ) - + # Structure evaluator doesn't need chunk_storage or llm_client structure_params = kwargs.get("structure_params", {}) self.structure_evaluator = StructureEvaluator( graph_storage=self.graph_storage, **structure_params ) - + logger.info("KG evaluators initialized") diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index c9d3c24f..a5009b58 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -32,15 +32,15 @@ def __init__( self.kwargs = kwargs self.graph_backend = graph_backend self.kv_backend = kv_backend - + # Separate QA and KG metrics self.qa_metrics = [m for m in self.metrics if m.startswith("qa_")] self.kg_metrics = [m for m in self.metrics if m.startswith("kg_")] - + # Initialize evaluators self.qa_evaluators = {} self.kg_evaluators: Optional[KGEvaluators] = None - + self._init_evaluators() def _init_evaluators(self): @@ -68,7 +68,7 @@ def _init_evaluators(self): ) else: raise ValueError(f"Unknown QA metric: {metric}") - + # Initialize KG evaluators if KG metrics are specified if self.kg_metrics: kg_params = self.kwargs.get("kg_params", {}) @@ -148,14 +148,14 @@ def _evaluate_kg(self) -> Dict[str, Any]: return {} results = {} - + # Map metric names to evaluation functions kg_metric_map = { "kg_accuracy": evaluate_accuracy, "kg_consistency": evaluate_consistency, "kg_structure": evaluate_structure, } - + # Run KG evaluations based on metrics for metric in self.kg_metrics: if metric in kg_metric_map: @@ -168,12 +168,12 @@ def _evaluate_kg(self) -> Dict[str, Any]: results[metric_key] = {"error": str(e)} else: logger.warning("Unknown KG metric: %s, skipping", metric) - + # If no valid metrics were found, run all evaluations if not results: logger.info("No valid KG metrics found, running all evaluations") results = evaluate_all(self.kg_evaluators) - + return results def evaluate( @@ -182,39 +182,38 @@ def evaluate( # Determine evaluation type has_qa_metrics = len(self.qa_metrics) > 0 has_kg_metrics = len(self.kg_metrics) > 0 - + # If items provided and QA metrics exist, do QA evaluation if items is not None and has_qa_metrics: return self._evaluate_qa(items) - + # If KG metrics exist, do KG evaluation if has_kg_metrics: return self._evaluate_kg() - + # If no metrics specified, try to infer from context if items is not None: logger.warning("No QA metrics specified but items provided, skipping evaluation") return [] - else: - logger.warning("No metrics specified, skipping evaluation") - return {} + logger.warning("No metrics specified, skipping evaluation") + return {} def process(self, batch: pd.DataFrame) -> pd.DataFrame: has_qa_metrics = len(self.qa_metrics) > 0 has_kg_metrics = len(self.kg_metrics) > 0 - + # QA evaluation: process batch items if has_qa_metrics: items = batch.to_dict(orient="records") results = self._evaluate_qa(items) return pd.DataFrame(results) - + # KG evaluation: evaluate from storage if has_kg_metrics: results = self._evaluate_kg() # Convert dict to DataFrame (single row) return pd.DataFrame([results]) - + # No metrics specified logger.warning("No metrics specified, returning empty DataFrame") return pd.DataFrame() diff --git a/graphgen/templates/evaluation/kg/consistency_evaluation.py b/graphgen/templates/evaluation/kg/consistency_evaluation.py index b540e528..1600ef94 100644 --- a/graphgen/templates/evaluation/kg/consistency_evaluation.py +++ b/graphgen/templates/evaluation/kg/consistency_evaluation.py @@ -100,4 +100,3 @@ "en": "", "zh": "" } - diff --git a/graphgen/utils/help_nltk.py b/graphgen/utils/help_nltk.py index 07d39ef1..c7d5e301 100644 --- a/graphgen/utils/help_nltk.py +++ b/graphgen/utils/help_nltk.py @@ -3,16 +3,14 @@ from typing import Dict, List, Final, Optional import warnings import nltk +import jieba warnings.filterwarnings( - "ignore", + "ignore", category=UserWarning, - module="jieba\._compat" + module=r"jieba\._compat" ) - -import jieba - class NLTKHelper: """ NLTK helper class From f6cce9b5a0246c919ca1c0f039428279b67715ed Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 26 Dec 2025 16:04:15 +0800 Subject: [PATCH 28/29] feat: add kg_structure evaluation --- .../evaluate_kg/kg_evaluation_config.yaml | 6 +- .../evaluate_qa/qa_evaluation_config.yaml | 4 +- graphgen/bases/base_storage.py | 47 ++++- graphgen/common/init_storage.py | 44 ++++- graphgen/models/__init__.py | 4 +- graphgen/models/evaluator/__init__.py | 7 +- graphgen/models/evaluator/kg/__init__.py | 2 - .../evaluator/kg/kg_quality_evaluator.py | 79 -------- .../evaluator/kg/structure_evaluator.py | 139 ++++---------- graphgen/models/storage/graph/kuzu_storage.py | 91 ++++++++- .../models/storage/graph/networkx_storage.py | 27 ++- graphgen/operators/evaluate/evaluate_kg.py | 71 ------- graphgen/operators/evaluate/evaluate_qa.py | 177 ----------------- .../operators/evaluate/evaluate_service.py | 178 +++++++----------- requirements.txt | 3 +- 15 files changed, 312 insertions(+), 567 deletions(-) delete mode 100644 graphgen/models/evaluator/kg/kg_quality_evaluator.py delete mode 100644 graphgen/operators/evaluate/evaluate_kg.py delete mode 100644 graphgen/operators/evaluate/evaluate_qa.py diff --git a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml index 57c6f307..41d88c44 100644 --- a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml +++ b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml @@ -27,7 +27,7 @@ nodes: op_name: build_kg type: map_batch dependencies: - - chunk_documents + - chunk execution_params: replicas: 1 batch_size: 128 @@ -40,6 +40,6 @@ nodes: - build_kg params: metrics: - - kg_accuracy - - kg_consistency - kg_structure +# - kg_accuracy +# - kg_consistency diff --git a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml index 3e875143..459f9fad 100644 --- a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml +++ b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml @@ -92,7 +92,7 @@ nodes: metrics: - qa_length - qa_mtld - # - qa_reward_score - # - qa_uni_score + - qa_reward_score + - qa_uni_score mtld_params: threshold: 0.7 diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index ff7d2d1a..e72c5869 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -1,5 +1,6 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Generic, TypeVar, Union +from typing import Dict, Generic, List, Set, TypeVar, Union T = TypeVar("T") @@ -45,52 +46,90 @@ def reload(self): raise NotImplementedError -class BaseGraphStorage(StorageNameSpace): +class BaseGraphStorage(StorageNameSpace, ABC): + @abstractmethod + def is_directed(self) -> bool: + pass + + @abstractmethod def has_node(self, node_id: str) -> bool: raise NotImplementedError + @abstractmethod def has_edge(self, source_node_id: str, target_node_id: str) -> bool: raise NotImplementedError + @abstractmethod def node_degree(self, node_id: str) -> int: raise NotImplementedError - def edge_degree(self, src_id: str, tgt_id: str) -> int: - raise NotImplementedError + @abstractmethod + def get_all_node_degrees(self) -> Dict[str, int]: + pass + def get_isolated_nodes(self) -> List[str]: + return [ + node_id + for node_id, degree in self.get_all_node_degrees().items() + if degree == 0 + ] + + @abstractmethod def get_node(self, node_id: str) -> Union[dict, None]: raise NotImplementedError + @abstractmethod def update_node(self, node_id: str, node_data: dict[str, str]): raise NotImplementedError + @abstractmethod def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]: raise NotImplementedError + @abstractmethod + def get_node_count(self) -> int: + pass + + @abstractmethod def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]: raise NotImplementedError + @abstractmethod def update_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): raise NotImplementedError + @abstractmethod def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]: raise NotImplementedError + @abstractmethod + def get_edge_count(self) -> int: + pass + + @abstractmethod def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]: raise NotImplementedError + @abstractmethod def upsert_node(self, node_id: str, node_data: dict[str, str]): raise NotImplementedError + @abstractmethod def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] ): raise NotImplementedError + @abstractmethod def delete_node(self, node_id: str): raise NotImplementedError + @abstractmethod def reload(self): raise NotImplementedError + + @abstractmethod + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + raise NotImplementedError diff --git a/graphgen/common/init_storage.py b/graphgen/common/init_storage.py index 56528e7a..aaffb630 100644 --- a/graphgen/common/init_storage.py +++ b/graphgen/common/init_storage.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Union +from typing import Any, Dict, List, Set, Union import ray @@ -68,6 +68,21 @@ def __init__(self, backend: str, working_dir: str, namespace: str): def index_done_callback(self): return self.graph.index_done_callback() + def is_directed(self) -> bool: + return self.graph.is_directed() + + def get_all_node_degrees(self) -> Dict[str, int]: + return self.graph.get_all_node_degrees() + + def get_node_count(self) -> int: + return self.graph.get_node_count() + + def get_edge_count(self) -> int: + return self.graph.get_edge_count() + + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + return self.graph.get_connected_components(undirected) + def has_node(self, node_id: str) -> bool: return self.graph.has_node(node_id) @@ -165,6 +180,21 @@ def __init__(self, actor_handle: ray.actor.ActorHandle): def index_done_callback(self): return ray.get(self.actor.index_done_callback.remote()) + def is_directed(self) -> bool: + return ray.get(self.actor.is_directed.remote()) + + def get_all_node_degrees(self) -> Dict[str, int]: + return ray.get(self.actor.get_all_node_degrees.remote()) + + def get_node_count(self) -> int: + return ray.get(self.actor.get_node_count.remote()) + + def get_edge_count(self) -> int: + return ray.get(self.actor.get_edge_count.remote()) + + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + return ray.get(self.actor.get_connected_components.remote(undirected)) + def has_node(self, node_id: str) -> bool: return ray.get(self.actor.has_node.remote(node_id)) @@ -239,10 +269,14 @@ def create_storage(backend: str, working_dir: str, namespace: str): try: actor_handle = ray.get_actor(actor_name) except ValueError: - actor_handle = ray.remote(actor_class).options( - name=actor_name, - get_if_exists=True, - ).remote(backend, working_dir, namespace) + actor_handle = ( + ray.remote(actor_class) + .options( + name=actor_name, + get_if_exists=True, + ) + .remote(backend, working_dir, namespace) + ) ray.get(actor_handle.ready.remote()) return proxy_class(actor_handle) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 127a4314..43d38bed 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,8 +1,10 @@ from .evaluator import ( - KGQualityEvaluator, + AccuracyEvaluator, + ConsistencyEvaluator, LengthEvaluator, MTLDEvaluator, RewardEvaluator, + StructureEvaluator, UniEvaluator, ) from .generator import ( diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py index 79192237..6091aeb5 100644 --- a/graphgen/models/evaluator/__init__.py +++ b/graphgen/models/evaluator/__init__.py @@ -1,7 +1,2 @@ +from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator -from .kg import ( - AccuracyEvaluator, - ConsistencyEvaluator, - KGQualityEvaluator, - StructureEvaluator, -) diff --git a/graphgen/models/evaluator/kg/__init__.py b/graphgen/models/evaluator/kg/__init__.py index dc83b6d3..375cbc50 100644 --- a/graphgen/models/evaluator/kg/__init__.py +++ b/graphgen/models/evaluator/kg/__init__.py @@ -9,12 +9,10 @@ from .accuracy_evaluator import AccuracyEvaluator from .consistency_evaluator import ConsistencyEvaluator -from .kg_quality_evaluator import KGQualityEvaluator from .structure_evaluator import StructureEvaluator __all__ = [ "AccuracyEvaluator", "ConsistencyEvaluator", - "KGQualityEvaluator", "StructureEvaluator", ] diff --git a/graphgen/models/evaluator/kg/kg_quality_evaluator.py b/graphgen/models/evaluator/kg/kg_quality_evaluator.py deleted file mode 100644 index 3b49b070..00000000 --- a/graphgen/models/evaluator/kg/kg_quality_evaluator.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any, Dict - -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper -from graphgen.common import init_llm, init_storage -from graphgen.models.evaluator.kg.accuracy_evaluator import AccuracyEvaluator -from graphgen.models.evaluator.kg.consistency_evaluator import ConsistencyEvaluator -from graphgen.models.evaluator.kg.structure_evaluator import StructureEvaluator -from graphgen.utils import logger - - -class KGQualityEvaluator: - def __init__( - self, - working_dir: str = "cache", - graph_backend: str = "kuzu", - kv_backend: str = "rocksdb", - **kwargs - ): - # Initialize storage - self.graph_storage: BaseGraphStorage = init_storage( - backend=graph_backend, working_dir=working_dir, namespace="graph" - ) - self.chunk_storage: BaseKVStorage = init_storage( - backend=kv_backend, working_dir=working_dir, namespace="chunk" - ) - - # Initialize LLM client - self.llm_client: BaseLLMWrapper = init_llm("synthesizer") - - # Initialize individual evaluators - self.accuracy_evaluator = AccuracyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - ) - - self.consistency_evaluator = ConsistencyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - ) - - # Structure evaluator doesn't need chunk_storage or llm_client - structure_params = kwargs.get("structure_params", {}) - self.structure_evaluator = StructureEvaluator( - graph_storage=self.graph_storage, - **structure_params - ) - - logger.info("KGQualityEvaluator initialized") - - def evaluate_accuracy(self) -> Dict[str, Any]: - logger.info("Running accuracy evaluation...") - results = self.accuracy_evaluator.evaluate() - logger.info("Accuracy evaluation completed") - return results - - def evaluate_consistency(self) -> Dict[str, Any]: - logger.info("Running consistency evaluation...") - results = self.consistency_evaluator.evaluate() - logger.info("Consistency evaluation completed") - return results - - def evaluate_structure(self) -> Dict[str, Any]: - logger.info("Running structural robustness evaluation...") - results = self.structure_evaluator.evaluate() - logger.info("Structural robustness evaluation completed") - return results - - def evaluate_all(self) -> Dict[str, Any]: - logger.info("Running all KG evaluations...") - results = { - "accuracy": self.evaluate_accuracy(), - "consistency": self.evaluate_consistency(), - "structure": self.evaluate_structure(), - } - logger.info("All KG evaluations completed") - return results - diff --git a/graphgen/models/evaluator/kg/structure_evaluator.py b/graphgen/models/evaluator/kg/structure_evaluator.py index 24207c53..d9fa45a9 100644 --- a/graphgen/models/evaluator/kg/structure_evaluator.py +++ b/graphgen/models/evaluator/kg/structure_evaluator.py @@ -1,40 +1,12 @@ from typing import Any, Dict, Optional -import networkx as nx import numpy as np - -try: - from scipy import stats -except ImportError: - stats = None +from scipy import stats from graphgen.bases import BaseGraphStorage from graphgen.utils import logger -def _convert_to_networkx(graph_storage: BaseGraphStorage) -> nx.DiGraph: - """Convert graph storage to NetworkX graph.""" - G = nx.DiGraph() - - # Add nodes - nodes = graph_storage.get_all_nodes() or [] - for node_id, node_data in nodes: - if isinstance(node_data, dict): - G.add_node(node_id, **node_data) - else: - G.add_node(node_id) - - # Add edges - edges = graph_storage.get_all_edges() or [] - for src, dst, edge_data in edges: - if isinstance(edge_data, dict): - G.add_edge(src, dst, **edge_data) - else: - G.add_edge(src, dst) - - return G - - class StructureEvaluator: """Evaluates structural robustness of the graph.""" @@ -55,110 +27,69 @@ def __init__( self.powerlaw_r2_threshold = powerlaw_r2_threshold def evaluate(self) -> Dict[str, Any]: - # Convert graph to NetworkX - G = _convert_to_networkx(self.graph_storage) + """ + Evaluate the structural robustness of the graph. + :return: + """ + storage = self.graph_storage - if G.number_of_nodes() == 0: + total_nodes = storage.get_node_count() + if total_nodes == 0: return {"error": "Empty graph"} - # Calculate metrics - total_nodes = G.number_of_nodes() - total_edges = G.number_of_edges() + total_edges = storage.get_edge_count() + degree_map = storage.get_all_node_degrees() # Noise ratio: isolated nodes / total nodes - isolated_nodes = [n for n in G.nodes() if G.degree(n) == 0] - noise_ratio = len(isolated_nodes) / total_nodes if total_nodes > 0 else 0 + isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0] + noise_ratio = len(isolated_nodes) / total_nodes # Largest connected component - if G.is_directed(): - G_undirected = G.to_undirected() - else: - G_undirected = G - - connected_components = list(nx.connected_components(G_undirected)) - if connected_components: - largest_cc = max(connected_components, key=len) - largest_cc_ratio = ( - len(largest_cc) / total_nodes if total_nodes > 0 else 0 - ) - else: - largest_cc_ratio = 0 - - # Average node degree - if total_nodes > 0: - total_degree = sum(G.degree(n) for n in G.nodes()) - avg_degree = total_degree / total_nodes - else: - avg_degree = 0 - - # Power law distribution R² - powerlaw_r2 = self._calculate_powerlaw_r2(G) - - thresholds = { - "noise_ratio": { - "value": noise_ratio, - "threshold": self.noise_ratio_threshold, - "pass": noise_ratio < self.noise_ratio_threshold, - }, - "largest_cc_ratio": { - "value": largest_cc_ratio, - "threshold": self.largest_cc_ratio_threshold, - "pass": largest_cc_ratio > self.largest_cc_ratio_threshold, - }, - "avg_degree": { - "value": avg_degree, - "threshold": (self.avg_degree_min, self.avg_degree_max), - "pass": self.avg_degree_min <= avg_degree <= self.avg_degree_max, - }, - "powerlaw_r2": { - "value": powerlaw_r2, - "threshold": self.powerlaw_r2_threshold, - "pass": powerlaw_r2 > self.powerlaw_r2_threshold if powerlaw_r2 is not None else False, - }, - } + components = storage.get_connected_components(undirected=True) + largest_cc_ratio = ( + len(max(components, key=len)) / total_nodes if components else 0 + ) + + avg_degree = sum(degree_map.values()) / total_nodes + powerlaw_r2 = self._calculate_powerlaw_r2(degree_map) - return { + results = { "total_nodes": total_nodes, "total_edges": total_edges, - "isolated_nodes_count": len(isolated_nodes), "noise_ratio": noise_ratio, "largest_cc_ratio": largest_cc_ratio, "avg_degree": avg_degree, "powerlaw_r2": powerlaw_r2, - "thresholds": thresholds, + "is_robust": ( + noise_ratio < self.noise_ratio_threshold + and largest_cc_ratio > self.largest_cc_ratio_threshold + and self.avg_degree_min <= avg_degree <= self.avg_degree_max + and ( + powerlaw_r2 is not None and powerlaw_r2 > self.powerlaw_r2_threshold + ) + ), } - def _calculate_powerlaw_r2(self, G: "nx.Graph") -> Optional[float]: - """ - Calculate R² for power law distribution of node degrees. + return results - Returns: - R² value if calculation successful, None otherwise - """ - if stats is None: - logger.warning("scipy not available, skipping power law R² calculation") - return None + @staticmethod + def _calculate_powerlaw_r2(degree_map: Dict[str, int]) -> Optional[float]: + degrees = [deg for deg in degree_map.values() if deg > 0] - degrees = [G.degree(n) for n in G.nodes()] - if len(degrees) < 10: # Need sufficient data points + if len(degrees) < 10: logger.warning("Insufficient nodes for power law fitting") return None - # Filter out zero degrees for log fitting - non_zero_degrees = [d for d in degrees if d > 0] - if len(non_zero_degrees) < 5: - return None - try: # Fit power law: log(y) = a * log(x) + b - log_degrees = np.log(non_zero_degrees) + log_degrees = np.log(degrees) sorted_log_degrees = np.sort(log_degrees) x = np.arange(1, len(sorted_log_degrees) + 1) log_x = np.log(x) # Linear regression on log-log scale r_value, *_ = stats.linregress(log_x, sorted_log_degrees) - r2 = r_value ** 2 + r2 = r_value**2 return float(r2) except Exception as e: diff --git a/graphgen/models/storage/graph/kuzu_storage.py b/graphgen/models/storage/graph/kuzu_storage.py index db3e97ea..52b41519 100644 --- a/graphgen/models/storage/graph/kuzu_storage.py +++ b/graphgen/models/storage/graph/kuzu_storage.py @@ -1,7 +1,8 @@ import json import os +from collections import defaultdict from dataclasses import dataclass -from typing import Any +from typing import Any, Dict, List, Set try: import kuzu @@ -78,6 +79,94 @@ def _safe_json_loads(data_str: str) -> dict: print(f"Error decoding JSON: {e}") return {} + def is_directed(self) -> bool: + return True + + def get_all_node_degrees(self) -> Dict[str, int]: + query = """ + MATCH (n:Entity) + OPTIONAL MATCH (n)-[r]-() + RETURN n.id, count(r) as degree + """ + + result = self._conn.execute(query) + degree_map = {} + while result.has_next(): + row = result.get_next() + if row and len(row) >= 2: + node_id, degree = row[0], row[1] + degree_map[node_id] = int(degree) + + return degree_map + + def get_isolated_nodes(self) -> List[str]: + query = """ + MATCH (n:Entity) + WHERE NOT (n)--() + RETURN n.id + """ + + result = self._conn.execute(query) + return [row[0] for row in result if row] + + def get_node_count(self) -> int: + result = self._conn.execute("MATCH (n:Entity) RETURN count(n)") + return result.get_next()[0] + + def get_edge_count(self) -> int: + result = self._conn.execute("MATCH ()-[e:Relation]->() RETURN count(e)") + return result.get_next()[0] + + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + parent = {} + rank = {} + + def find(x: str) -> str: + if parent[x] != x: + parent[x] = find(parent[x]) + return parent[x] + + def union(x: str, y: str): + root_x, root_y = find(x), find(y) + if root_x == root_y: + return + if rank[root_x] < rank[root_y]: + parent[root_x] = root_y + elif rank[root_x] > rank[root_y]: + parent[root_y] = root_x + else: + parent[root_y] = root_x + rank[root_x] += 1 + + all_nodes = self.get_all_node_degrees().keys() + for node_id in all_nodes: + parent[node_id] = node_id + rank[node_id] = 0 + + query = ( + """ + MATCH (a:Entity)-[e:Relation]-(b:Entity) + RETURN DISTINCT a.id, b.id + """ + if undirected + else """ + MATCH (a:Entity)-[e:Relation]->(b:Entity) + RETURN DISTINCT a.id, b.id + """ + ) + + result = self._conn.execute(query) + for row in result: + if row and len(row) >= 2: + union(row[0], row[1]) + + components_dict = defaultdict(set) + for node_id in all_nodes: + root = find(node_id) + components_dict[root].add(node_id) + + return list(components_dict.values()) + def has_node(self, node_id: str) -> bool: result = self._conn.execute( "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id} diff --git a/graphgen/models/storage/graph/networkx_storage.py b/graphgen/models/storage/graph/networkx_storage.py index 7fb73b79..b043e9d2 100644 --- a/graphgen/models/storage/graph/networkx_storage.py +++ b/graphgen/models/storage/graph/networkx_storage.py @@ -1,7 +1,7 @@ import html import os from dataclasses import dataclass -from typing import Any, Optional, Union, cast +from typing import Any, Dict, List, Optional, Set, Union, cast import networkx as nx @@ -10,6 +10,31 @@ @dataclass class NetworkXStorage(BaseGraphStorage): + def is_directed(self) -> bool: + return self._graph.is_directed() + + def get_all_node_degrees(self) -> Dict[str, int]: + return { + str(node_id): int(self._graph.degree[node_id]) + for node_id in self._graph.nodes() + } + + def get_node_count(self) -> int: + return self._graph.number_of_nodes() + + def get_edge_count(self) -> int: + return self._graph.number_of_edges() + + def get_connected_components(self, undirected: bool = True) -> List[Set[str]]: + graph = self._graph + + if undirected and graph.is_directed(): + graph = graph.to_undirected() + + return [ + set(str(node) for node in comp) for comp in nx.connected_components(graph) + ] + @staticmethod def load_nx_graph(file_name) -> Optional[nx.Graph]: if os.path.exists(file_name): diff --git a/graphgen/operators/evaluate/evaluate_kg.py b/graphgen/operators/evaluate/evaluate_kg.py deleted file mode 100644 index a58617fe..00000000 --- a/graphgen/operators/evaluate/evaluate_kg.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import Any, Dict - -from dotenv import load_dotenv - -from graphgen.models import KGQualityEvaluator -from graphgen.utils import logger - -# Load environment variables -load_dotenv() - - -def evaluate_accuracy(evaluator: KGQualityEvaluator) -> Dict[str, Any]: - """Evaluate accuracy of entity and relation extraction. - - Args: - evaluator: KGQualityEvaluator instance - - Returns: - Dictionary containing entity_accuracy and relation_accuracy metrics. - """ - logger.info("Running accuracy evaluation...") - results = evaluator.evaluate_accuracy() - logger.info("Accuracy evaluation completed") - return results - - -def evaluate_consistency(evaluator: KGQualityEvaluator) -> Dict[str, Any]: - """Evaluate consistency by detecting semantic conflicts. - - Args: - evaluator: KGQualityEvaluator instance - - Returns: - Dictionary containing consistency metrics including conflict_rate and conflicts. - """ - logger.info("Running consistency evaluation...") - results = evaluator.evaluate_consistency() - logger.info("Consistency evaluation completed") - return results - - -def evaluate_structure(evaluator: KGQualityEvaluator) -> Dict[str, Any]: - """Evaluate structural robustness of the graph. - - Args: - evaluator: KGQualityEvaluator instance - - Returns: - Dictionary containing structural metrics including noise_ratio, largest_cc_ratio, etc. - """ - logger.info("Running structural robustness evaluation...") - results = evaluator.evaluate_structure() - logger.info("Structural robustness evaluation completed") - return results - - -def evaluate_all(evaluator: KGQualityEvaluator) -> Dict[str, Any]: - """Run all evaluations (accuracy, consistency, structure). - - Args: - evaluator: KGQualityEvaluator instance - - Returns: - Dictionary containing all evaluation results with keys: accuracy, consistency, structure. - """ - logger.info("Running all evaluations...") - results = evaluator.evaluate_all() - logger.info("All evaluations completed") - return results - - diff --git a/graphgen/operators/evaluate/evaluate_qa.py b/graphgen/operators/evaluate/evaluate_qa.py deleted file mode 100644 index fdbfbf82..00000000 --- a/graphgen/operators/evaluate/evaluate_qa.py +++ /dev/null @@ -1,177 +0,0 @@ -# TODO: this module needs refactoring to merge into GraphGen framework -"""Evaluate the quality of the generated text using various metrics""" - -import argparse -import json -import os - -import pandas as pd -from dotenv import load_dotenv - -from graphgen.bases.datatypes import QAPair -from graphgen.models import ( - LengthEvaluator, - MTLDEvaluator, - RewardEvaluator, - UniEvaluator, -) -from graphgen.utils import logger, set_logger - -sys_path = os.path.abspath(os.path.dirname(__file__)) -set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log")) - -load_dotenv() - - -def evaluate_length(corpus, tokenizer_name): - length_evaluator = LengthEvaluator(tokenizer_name=tokenizer_name) - logger.info("Length evaluator loaded") - scores = length_evaluator.get_average_score(corpus) - logger.info("Length scores: %s", scores) - return scores - - -def evaluate_mtld(corpus): - mtld_evaluator = MTLDEvaluator() - logger.info("MTLD evaluator loaded") - scores = mtld_evaluator.get_average_score(corpus) - logger.info("MTLD scores: %s", scores) - min_max_scores = mtld_evaluator.get_min_max_score(corpus) - logger.info("MTLD min max scores: %s", min_max_scores) - return scores, min_max_scores - - -def evaluate_reward(corpus, reward_model_names): - scores = [] - for reward_name in reward_model_names: - reward_evaluator = RewardEvaluator(reward_name=reward_name) - logger.info("Loaded reward model: %s", reward_name) - average_score = reward_evaluator.get_average_score(corpus) - logger.info("%s scores: %s", reward_name, average_score) - min_max_scores = reward_evaluator.get_min_max_score(corpus) - logger.info("%s min max scores: %s", reward_name, min_max_scores) - scores.append( - { - "reward_name": reward_name.split("/")[-1], - "score": average_score, - "min_max_scores": min_max_scores, - } - ) - del reward_evaluator - clean_gpu_cache() - return scores - - -def evaluate_uni(corpus, uni_model_name): - uni_evaluator = UniEvaluator(model_name=uni_model_name) - logger.info("Uni evaluator loaded with model %s", uni_model_name) - uni_scores = uni_evaluator.get_average_score(corpus) - for key, value in uni_scores.items(): - logger.info("Uni %s scores: %s", key, value) - min_max_scores = uni_evaluator.get_min_max_score(corpus) - for key, value in min_max_scores.items(): - logger.info("Uni %s min max scores: %s", key, value) - del uni_evaluator - clean_gpu_cache() - return ( - uni_scores["naturalness"], - uni_scores["coherence"], - uni_scores["understandability"], - min_max_scores["naturalness"], - min_max_scores["coherence"], - min_max_scores["understandability"], - ) - - -def clean_gpu_cache(): - import torch - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -if __name__ == "__main__": - import torch.multiprocessing as mp - - parser = argparse.ArgumentParser() - - parser.add_argument( - "--folder", type=str, default="cache/data", help="folder to load data" - ) - parser.add_argument( - "--output", type=str, default="cache/output", help="path to save output" - ) - - parser.add_argument( - "--tokenizer", type=str, default="cl100k_base", help="tokenizer name" - ) - parser.add_argument( - "--reward", - type=str, - default="OpenAssistant/reward-model-deberta-v3-large-v2", - help="Comma-separated list of reward models", - ) - parser.add_argument( - "--uni", type=str, default="MingZhong/unieval-sum", help="uni model name" - ) - - args = parser.parse_args() - - if not os.path.exists(args.folder): - raise ValueError(f"Folder {args.folder} does not exist") - - if not os.path.exists(args.output): - os.makedirs(args.output) - - reward_models = args.reward.split(",") - - results = [] - - logger.info("Data loaded from %s", args.folder) - mp.set_start_method("spawn") - - for file in os.listdir(args.folder): - if file.endswith(".json"): - logger.info("Processing %s", file) - with open(os.path.join(args.folder, file), "r", encoding="utf-8") as f: - data = json.load(f) - data = [ - QAPair(question=data[key]["question"], answer=data[key]["answer"]) - for key in data - ] - - length_scores = evaluate_length(data, args.tokenizer) - mtld_scores, min_max_mtld_scores = evaluate_mtld(data) - reward_scores = evaluate_reward(data, reward_models) - ( - uni_naturalness_scores, - uni_coherence_scores, - uni_understandability_scores, - min_max_uni_naturalness_scores, - min_max_uni_coherence_scores, - min_max_uni_understandability_scores, - ) = evaluate_uni(data, args.uni) - - result = { - "file": file, - "number": len(data), - "length": length_scores, - "mtld": mtld_scores, - "mtld_min_max": min_max_mtld_scores, - "uni_naturalness": uni_naturalness_scores, - "uni_coherence": uni_coherence_scores, - "uni_understandability": uni_understandability_scores, - "uni_naturalness_min_max": min_max_uni_naturalness_scores, - "uni_coherence_min_max": min_max_uni_coherence_scores, - "uni_understandability_min_max": min_max_uni_understandability_scores, - } - for reward_score in reward_scores: - result[reward_score["reward_name"]] = reward_score["score"] - result[f"{reward_score['reward_name']}_min_max"] = reward_score[ - "min_max_scores" - ] - - results.append(result) - - results = pd.DataFrame(results) - results.to_csv(os.path.join(args.output, "evaluation.csv"), index=False) diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index cd4f1c78..77a8d59a 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -1,10 +1,9 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict import pandas as pd from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair -from graphgen.common import init_llm -from graphgen.models import KGQualityEvaluator +from graphgen.common import init_llm, init_storage from graphgen.utils import logger, run_concurrent @@ -19,64 +18,71 @@ def __init__( working_dir: str = "cache", metrics: list[str] = None, graph_backend: str = "kuzu", - kv_backend: str = "rocksdb", - **kwargs + **kwargs, ): super().__init__(working_dir=working_dir, op_name="evaluate_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.metrics = metrics or [] self.kwargs = kwargs - self.graph_backend = graph_backend - self.kv_backend = kv_backend - - # Separate QA and KG metrics - self.qa_metrics = [m for m in self.metrics if m.startswith("qa_")] - self.kg_metrics = [m for m in self.metrics if m.startswith("kg_")] - + self.graph_backend = init_storage( + backend=graph_backend, working_dir=working_dir, namespace="graph" + ) + # Initialize evaluators self.qa_evaluators = {} - self.kg_evaluator = None - + self.kg_evaluators = {} self._init_evaluators() def _init_evaluators(self): """Initialize QA and KG evaluators based on metrics.""" - # Initialize QA evaluators - for metric in self.qa_metrics: + for metric in self.metrics: if metric == "qa_length": from graphgen.models import LengthEvaluator self.qa_evaluators[metric] = LengthEvaluator() elif metric == "qa_mtld": from graphgen.models import MTLDEvaluator + self.qa_evaluators[metric] = MTLDEvaluator( **self.kwargs.get("mtld_params", {}) ) elif metric == "qa_reward_score": from graphgen.models import RewardEvaluator + self.qa_evaluators[metric] = RewardEvaluator( **self.kwargs.get("reward_params", {}) ) elif metric == "qa_uni_score": from graphgen.models import UniEvaluator + self.qa_evaluators[metric] = UniEvaluator( **self.kwargs.get("uni_params", {}) ) + elif metric == "kg_accuracy": + from graphgen.models import AccuracyEvaluator + + self.kg_evaluators[metric] = AccuracyEvaluator( + graph_storage=self.graph_backend, + **self.kwargs.get("accuracy_params", {}), + ) + elif metric == "kg_consistency": + from graphgen.models import ConsistencyEvaluator + + self.kg_evaluators[metric] = ConsistencyEvaluator( + graph_storage=self.graph_backend, + **self.kwargs.get("consistency_params", {}), + ) + elif metric == "kg_structure": + from graphgen.models import StructureEvaluator + + self.kg_evaluators[metric] = StructureEvaluator( + graph_storage=self.graph_backend, + **self.kwargs.get("structure_params", {}), + ) else: raise ValueError(f"Unknown QA metric: {metric}") - - # Initialize KG evaluator if KG metrics are specified - if self.kg_metrics: - kg_params = self.kwargs.get("kg_params", {}) - self.kg_evaluator = KGQualityEvaluator( - working_dir=self.working_dir, - graph_backend=self.graph_backend, - kv_backend=self.kv_backend, - **kg_params - ) - logger.info("KG evaluator initialized") - async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: + async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]: try: qa_pair = QAPair( question=str(item.get("question", "")), @@ -102,35 +108,34 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]: item[metric] = None return item - @staticmethod - def transform_messages_format(items: list[dict]) -> list[dict]: - """ - Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...] - """ - transformed = [] - for item in items: - messages = item.get("messages", []) - question = next( - (m["content"] for m in messages if m.get("role") == "user"), "" - ) - answer = next( - (m["content"] for m in messages if m.get("role") == "assistant"), "" - ) + def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: + def transform_messages_format(items: list[dict]) -> list[dict]: + """ + Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...] + """ + transformed = [] + for item in items: + messages = item.get("messages", []) + question = next( + (m["content"] for m in messages if m.get("role") == "user"), "" + ) + answer = next( + (m["content"] for m in messages if m.get("role") == "assistant"), "" + ) - transformed.append({"question": question, "answer": answer}) - return transformed + transformed.append({"question": question, "answer": answer}) + return transformed - def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: if not items: return [] if not self.qa_evaluators: - logger.warning("No QA evaluators initialized, skipping QA evaluation") + self.logger.warning("No QA evaluators initialized, skipping QA evaluation") return [] - items = self.transform_messages_format(items) + items = transform_messages_format(items) results = run_concurrent( - self._process_single, + self._process_single_qa, items, desc="Evaluating QA items", unit="item", @@ -140,78 +145,31 @@ def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: return results def _evaluate_kg(self) -> Dict[str, Any]: - if not self.kg_evaluator: - logger.warning("No KG evaluator initialized, skipping KG evaluation") - return {} - results = {} - - # Map metric names to evaluation functions - kg_metric_map = { - "kg_accuracy": self.kg_evaluator.evaluate_accuracy, - "kg_consistency": self.kg_evaluator.evaluate_consistency, - "kg_structure": self.kg_evaluator.evaluate_structure, - } - - # Run KG evaluations based on metrics - for metric in self.kg_metrics: - if metric in kg_metric_map: - logger.info("Running %s evaluation...", metric) - metric_key = metric.replace("kg_", "") # Remove "kg_" prefix - try: - results[metric_key] = kg_metric_map[metric]() - except Exception as e: - logger.error("Error in %s evaluation: %s", metric, str(e)) - results[metric_key] = {"error": str(e)} - else: - logger.warning("Unknown KG metric: %s, skipping", metric) - - # If no valid metrics were found, run all evaluations - if not results: - logger.info("No valid KG metrics found, running all evaluations") - results = self.kg_evaluator.evaluate_all() - - return results - def evaluate( - self, items: list[dict[str, Any]] = None - ) -> Union[List[Dict[str, Any]], Dict[str, Any]]: - # Determine evaluation type - has_qa_metrics = len(self.qa_metrics) > 0 - has_kg_metrics = len(self.kg_metrics) > 0 - - # If items provided and QA metrics exist, do QA evaluation - if items is not None and has_qa_metrics: - return self._evaluate_qa(items) - - # If KG metrics exist, do KG evaluation - if has_kg_metrics: - return self._evaluate_kg() - - # If no metrics specified, try to infer from context - if items is not None: - logger.warning("No QA metrics specified but items provided, skipping evaluation") - return [] - else: - logger.warning("No metrics specified, skipping evaluation") - return {} + for metric, evaluator in self.kg_evaluators.items(): + try: + self.logger.info("Running %s evaluation...", metric) + score = evaluator.evaluate() + results[metric] = score + except Exception as e: + self.logger.error("Error in %s evaluation: %s", metric, str(e)) + results[metric] = {"error": str(e)} + return results def process(self, batch: pd.DataFrame) -> pd.DataFrame: - has_qa_metrics = len(self.qa_metrics) > 0 - has_kg_metrics = len(self.kg_metrics) > 0 - - # QA evaluation: process batch items - if has_qa_metrics: + # QA evaluation + if len(self.qa_evaluators) > 0: items = batch.to_dict(orient="records") results = self._evaluate_qa(items) return pd.DataFrame(results) - - # KG evaluation: evaluate from storage - if has_kg_metrics: + + # KG evaluation + if len(self.kg_evaluators) > 0: results = self._evaluate_kg() # Convert dict to DataFrame (single row) return pd.DataFrame([results]) - + # No metrics specified logger.warning("No metrics specified, returning empty DataFrame") return pd.DataFrame() diff --git a/requirements.txt b/requirements.txt index 119c95e0..b0eb3966 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,6 @@ aiohttp socksio pydantic ray==2.52.1 -kuzu pyarrow leidenalg @@ -32,9 +31,11 @@ python-louvain # storage rocksdict +kuzu # KG rdflib +scipy # Bioinformatics biopython From e10b3917ccc08f5e9ccfc8dfc78ac91dec517e4a Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 26 Dec 2025 16:24:38 +0800 Subject: [PATCH 29/29] feat: add kg_accuracy & kg_consistency metrics --- .../evaluate_kg/kg_evaluation_config.yaml | 4 ++-- .../models/evaluator/kg/accuracy_evaluator.py | 3 ++- graphgen/operators/evaluate/evaluate_kg.py | 0 .../operators/evaluate/evaluate_service.py | 18 ++++++++++++------ 4 files changed, 16 insertions(+), 9 deletions(-) delete mode 100644 graphgen/operators/evaluate/evaluate_kg.py diff --git a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml index 41d88c44..d86d01b1 100644 --- a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml +++ b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml @@ -41,5 +41,5 @@ nodes: params: metrics: - kg_structure -# - kg_accuracy -# - kg_consistency + - kg_accuracy + - kg_consistency diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py index a0d87f9b..9663b6f8 100644 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ b/graphgen/models/evaluator/kg/accuracy_evaluator.py @@ -279,8 +279,9 @@ def _evaluate_relation_extraction( "issues": [f"Evaluation error: {str(e)}"], } + @staticmethod def _aggregate_evaluation_results( - self, entity_evaluations: List[Dict], relation_evaluations: List[Dict] + entity_evaluations: List[Dict], relation_evaluations: List[Dict] ) -> Dict[str, Any]: """Aggregate evaluation results from all chunks.""" diff --git a/graphgen/operators/evaluate/evaluate_kg.py b/graphgen/operators/evaluate/evaluate_kg.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index 77a8d59a..b0875d7f 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -18,15 +18,19 @@ def __init__( working_dir: str = "cache", metrics: list[str] = None, graph_backend: str = "kuzu", + kv_backend: str = "rocksdb", **kwargs, ): super().__init__(working_dir=working_dir, op_name="evaluate_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.metrics = metrics or [] self.kwargs = kwargs - self.graph_backend = init_storage( + self.graph_storage = init_storage( backend=graph_backend, working_dir=working_dir, namespace="graph" ) + self.chunk_storage = init_storage( + backend=kv_backend, working_dir=working_dir, namespace="chunk" + ) # Initialize evaluators self.qa_evaluators = {} @@ -62,21 +66,23 @@ def _init_evaluators(self): from graphgen.models import AccuracyEvaluator self.kg_evaluators[metric] = AccuracyEvaluator( - graph_storage=self.graph_backend, - **self.kwargs.get("accuracy_params", {}), + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, ) elif metric == "kg_consistency": from graphgen.models import ConsistencyEvaluator self.kg_evaluators[metric] = ConsistencyEvaluator( - graph_storage=self.graph_backend, - **self.kwargs.get("consistency_params", {}), + graph_storage=self.graph_storage, + chunk_storage=self.chunk_storage, + llm_client=self.llm_client, ) elif metric == "kg_structure": from graphgen.models import StructureEvaluator self.kg_evaluators[metric] = StructureEvaluator( - graph_storage=self.graph_backend, + graph_storage=self.graph_storage, **self.kwargs.get("structure_params", {}), ) else: