-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathprovenance.py
More file actions
170 lines (135 loc) · 6.08 KB
/
Copy pathprovenance.py
File metadata and controls
170 lines (135 loc) · 6.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""Provenance: stamp every node/edge with where, who, when, and why.
This module is the answer to "I see edge X → relation → Y in the graph,
how do I know it's real?". Every extractor — rule or LLM — calls
`attach_provenance()` after producing its raw `{nodes, edges}`. The result
satisfies the canonical `{nodes, edges}` graph contract **plus**
the strict graphanything provenance contract.
Fields added:
Node:
extractor_id str ← which extractor (e.g. 'markdown', 'llm-entity')
extractor_version str ← version of the extractor
extraction_time str ← ISO8601 UTC, second precision
source_hash str ← sha256(source_file content), first 16 chars
Edge:
everything above, PLUS (LLM-only):
rationale str ← ≤120 char explanation
evidence_span str ← verbatim source span the model cited
prompt_hash str ← sha256(prompt) first 16 chars (for audit)
The contract requires `confidence ∈ {EXTRACTED, INFERRED, AMBIGUOUS}` on
every edge. LLM-extracted edges should mark `confidence='INFERRED'`
unless they have both `rationale` and `evidence_span`, in which case
`EXTRACTED` is fine.
`assert_strict()` here enforces provenance presence; downstream consumers
accept both stamped and bare `{nodes, edges}` graphs.
"""
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
def _utc_now_iso() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
def _hash_file(path: Path | str, *, length: int = 16) -> str:
"""sha256 of file content, first `length` hex chars. Returns '' on miss."""
p = Path(path)
if not p.exists() or not p.is_file():
return ""
h = hashlib.sha256()
with p.open("rb") as f:
for chunk in iter(lambda: f.read(64 * 1024), b""):
h.update(chunk)
return h.hexdigest()[:length]
def hash_text(text: str, *, length: int = 16) -> str:
    """Return a truncated SHA-256 hex digest of `text` (used for prompt_hash).

    The string is encoded as UTF-8 before hashing; only the first `length`
    hex characters of the digest are returned.
    """
    encoded = text.encode("utf-8")
    full_digest = hashlib.sha256(encoded).hexdigest()
    return full_digest[:length]
@dataclass
class ProvenanceStamp:
    """Provenance shared by every node/edge from one extractor invocation.

    `extraction_time` and `source_hash` may be left empty; they are then
    filled in automatically right after construction.
    """

    extractor_id: str
    extractor_version: str
    source_file: str
    extraction_time: str = ""
    source_hash: str = ""
    prompt_hash: str = ""

    def __post_init__(self) -> None:
        """Default the timestamp to "now" and hash the source file if needed."""
        self.extraction_time = self.extraction_time or _utc_now_iso()
        # Only hash when a source file was named and no hash was supplied.
        if self.source_file and not self.source_hash:
            self.source_hash = _hash_file(self.source_file)

    def as_dict(self, *, include_prompt: bool = False) -> dict[str, Any]:
        """Return the stamp fields as a plain dict.

        `source_file` is never included. `prompt_hash` is included only when
        `include_prompt` is true and a prompt hash is actually set.
        """
        field_names = (
            "extractor_id",
            "extractor_version",
            "extraction_time",
            "source_hash",
        )
        stamp_fields = {name: getattr(self, name) for name in field_names}
        if self.prompt_hash and include_prompt:
            stamp_fields["prompt_hash"] = self.prompt_hash
        return stamp_fields
def attach_provenance(
    extraction: dict[str, Any],
    stamp: ProvenanceStamp,
    *,
    edge_extra: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Return a copy of `extraction` with `stamp` applied to all nodes/edges.

    Fields already present on a node or edge take precedence over the stamp.
    Edges are read from `edges`, falling back to `links`, and are always
    returned under `edges`. Edges additionally receive `prompt_hash` (when
    the stamp has one) and then `edge_extra`, which LLM extractors use for
    rationale/evidence_span. Every record gets `source_file` if it lacks one.

    The input mapping and its node/edge dicts are not modified; the result
    holds only the `nodes` and `edges` keys.
    """
    node_defaults = stamp.as_dict()

    # Edge defaults layer: stamp -> prompt_hash -> caller-supplied extras.
    edge_defaults = dict(node_defaults)
    if stamp.prompt_hash:
        edge_defaults["prompt_hash"] = stamp.prompt_hash
    edge_defaults.update(edge_extra or {})

    def _stamped(record: Any, defaults: dict[str, Any]) -> dict[str, Any]:
        # Defaults go first so the record's own fields override them.
        combined = {**defaults, **dict(record)}
        combined.setdefault("source_file", stamp.source_file)
        return combined

    raw_nodes = extraction.get("nodes") or []
    raw_edges = extraction.get("edges") or extraction.get("links") or []
    return {
        "nodes": [_stamped(node, node_defaults) for node in raw_nodes],
        "edges": [_stamped(edge, edge_defaults) for edge in raw_edges],
    }
# ---------------------------------------------------------------------------
# Strict validation (graphanything contract — superset of legacy validate.py)
# ---------------------------------------------------------------------------
# Provenance keys that must be present and non-empty on every node and edge.
REQUIRED_PROVENANCE = {"extractor_id", "extractor_version", "extraction_time"}


def validate_provenance(data: dict[str, Any], *, llm_only: bool = False) -> list[str]:
    """Return a list of missing-provenance errors. Empty list = pass.

    Every node and edge must carry a truthy value for each key in
    `REQUIRED_PROVENANCE`. Edges are read from `edges`, falling back to
    `links`. A missing or `None` `nodes`/`edges` entry is treated as empty.

    `llm_only=True` additionally requires `rationale` + `evidence_span` on
    each edge — appropriate for the LLM extractor's self-check, NOT for
    rule-based extractors (where evidence is the source line itself).

    Errors are reported nodes first, then edges, in input order; within one
    node/edge the required keys are checked in sorted order so the output is
    identical from run to run.
    """
    errors: list[str] = []
    # `or []` also covers an explicit None, same as the edges lookup below.
    nodes = data.get("nodes") or []
    edges = data.get("edges") or data.get("links") or []
    # A set of strings iterates in hash order, which differs between
    # interpreter runs; sort once so error ordering is deterministic.
    required = sorted(REQUIRED_PROVENANCE)

    for i, n in enumerate(nodes):
        for k in required:
            if not n.get(k):
                errors.append(f"node[{i}] (id={n.get('id', '?')!r}) missing provenance '{k}'")

    for i, e in enumerate(edges):
        for k in required:
            if not e.get(k):
                errors.append(f"edge[{i}] missing provenance '{k}'")
        if llm_only:
            if not e.get("rationale"):
                errors.append(f"edge[{i}] (LLM) missing 'rationale'")
            if not e.get("evidence_span"):
                errors.append(f"edge[{i}] (LLM) missing 'evidence_span'")
    return errors
def assert_strict(data: dict[str, Any], *, llm_only: bool = False) -> None:
    """Raise ValueError if `validate_provenance` reports any problem.

    The message gives the total issue count, then at most the first 15
    issues one per line, then a trailing count of how many were left out.
    `llm_only` is forwarded unchanged to `validate_provenance`.
    """
    problems = validate_provenance(data, llm_only=llm_only)
    if not problems:
        return

    message = f"Provenance check failed ({len(problems)} issues):\n"
    message += "\n".join(f" • {e}" for e in problems[:15])
    # Cap the listing so a badly broken graph does not flood the traceback.
    omitted = len(problems) - 15
    if omitted > 0:
        message += f"\n • ... ({omitted} more)"
    raise ValueError(message)