diff --git a/app.py b/app.py
index 76994ec..ed59c26 100644
--- a/app.py
+++ b/app.py
@@ -1,8 +1,8 @@
+
import os
import tempfile
-
import streamlit as st
-
+from src.utils.pdf_annotator import highlight_contract_risks
from src.data_pipeline.pipeline import DataPipeline
from src.inference.predictor import (
build_xlsx_export,
@@ -21,6 +21,20 @@
render_risk_card,
render_risk_heatmap,
)
+def cleanup_files():
+ try:
+ if os.path.exists("marked_contract.pdf"):
+ os.remove("marked_contract.pdf")
+ if "original_pdf_path" in st.session_state and st.session_state.original_pdf_path:
+ old_path = st.session_state.original_pdf_path
+ if os.path.exists(old_path):
+ os.remove(old_path)
+ st.session_state.original_pdf_path = None
+
+ st.session_state.pdf_ready = False
+
+ except Exception as e:
+ print(f"Silent Cleanup: {e}")
st.set_page_config(
page_title="ContraLegal - Risk Dashboard",
@@ -305,9 +319,6 @@
unsafe_allow_html=True,
)
-# ---------------------------------------------------------------------------
-# Session state init
-# ---------------------------------------------------------------------------
if "analyzed_df" not in st.session_state:
st.session_state.analyzed_df = None
if "contract_risk" not in st.session_state:
@@ -337,9 +348,7 @@
if "rewrite_results" not in st.session_state:
st.session_state.rewrite_results = {}
-# ---------------------------------------------------------------------------
-# Sidebar — Navigation + Settings
-# ---------------------------------------------------------------------------
+
with st.sidebar:
st.markdown("### ContraLegal")
st.caption("Intelligent Contract Risk Analysis")
@@ -419,9 +428,7 @@
except Exception as e:
st.sidebar.warning(f"Could not connect to {llm_provider}: {e}")
-# ---------------------------------------------------------------------------
-# Model load
-# ---------------------------------------------------------------------------
+
vectorizer, model = load_model()
if vectorizer is None or model is None:
@@ -432,9 +439,7 @@
)
st.stop()
-# ---------------------------------------------------------------------------
-# Brand header + Input (always visible)
-# ---------------------------------------------------------------------------
+
st.markdown('
ContraLegal
', unsafe_allow_html=True)
st.markdown(
'Intelligent Contract Risk Analysis
',
@@ -488,9 +493,7 @@
st.markdown('
', unsafe_allow_html=True)
-# ---------------------------------------------------------------------------
-# Input parsing
-# ---------------------------------------------------------------------------
+
if analyze_btn_upload or analyze_btn_paste or analyze_btn_demo:
clauses = []
@@ -508,14 +511,15 @@
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
tmp.write(uploaded_file.getvalue())
tmp_path = tmp.name
+ st.session_state.original_pdf_path = tmp_path
+
try:
from src.data_pipeline.pdf_extractor import PDFExtractor
st.session_state.raw_contract_text = PDFExtractor().extract_text(tmp_path)
pipeline = DataPipeline()
clauses = pipeline.process_document(tmp_path)
- finally:
- if os.path.exists(tmp_path):
- os.remove(tmp_path)
+ except Exception as e:
+ st.error(f"Error processing PDF: {e}")
elif analyze_btn_paste and pasted_text.strip():
st.session_state.raw_contract_text = pasted_text.strip()
@@ -538,7 +542,8 @@
st.session_state.analyzed_df = r_df
st.session_state.contract_risk = c_risk
- st.session_state.summary_clauses = s_clauses
+ st.session_state.summary_clauses = s_clauses
+ st.session_state.pdf_ready = False
# Build RAG index for the new document
if ai_enabled and st.session_state.raw_contract_text:
@@ -573,9 +578,7 @@
st.session_state.current_view = "dashboard"
st.rerun()
-# ===========================================================================
-# VIEW: Risk Dashboard
-# ===========================================================================
+
if st.session_state.analyzed_df is not None and st.session_state.current_view == "dashboard":
results_df = st.session_state.analyzed_df
contract_risk = st.session_state.contract_risk
@@ -645,7 +648,7 @@
st.markdown(f"- {row['clause_text']}")
st.markdown('
', unsafe_allow_html=True)
- dl_col1, dl_col2 = st.columns(2)
+ dl_col1, dl_col2, dl_col3= st.columns(3)
with dl_col1:
csv = results_df.drop(columns=["keyword_matches"], errors="ignore").to_csv(index=False).encode("utf-8")
@@ -666,10 +669,49 @@
mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
use_container_width=True,
)
+ with dl_col3:
+ if "pdf_ready" not in st.session_state:
+ st.session_state.pdf_ready = False
+
+ if st.session_state.get("original_pdf_path") and os.path.exists(st.session_state.original_pdf_path):
+
+ output_pdf_path = "marked_contract.pdf"
+
+ risky_data = [
+ {"text": row["clause_text"], "risk": row["risk_label"]}
+ for _, row in results_df[
+ results_df["risk_label"].isin(["High Risk", "Medium Risk"])
+ ].iterrows()
+ ]
+
+ if not st.session_state.pdf_ready:
+ if st.button("Generate Marked-up PDF", use_container_width=True):
+ with st.spinner("Applying spatial highlights..."):
+ try:
+ highlight_contract_risks(
+ st.session_state.original_pdf_path,
+ output_pdf_path,
+ risky_data
+ )
+ st.session_state.pdf_ready = True
+ st.success("Highlights applied successfully!")
+ except Exception as e:
+ st.error("Could not apply highlights. PDF format may not be supported.")
+ print(e)
+
+ if st.session_state.pdf_ready:
+ with open(output_pdf_path, "rb") as f:
+ st.download_button(
+ label="Download Highlighted PDF",
+ data=f,
+ file_name="ContraLegal_Spatial_Analysis.pdf",
+ mime="application/pdf",
+ use_container_width=True,
+ on_click=cleanup_files
+ )
+ else:
+ st.caption("Upload a PDF to enable highlighted export.")
-# ===========================================================================
-# VIEW: AI Assistant
-# ===========================================================================
elif st.session_state.analyzed_df is not None and st.session_state.current_view == "assistant":
results_df = st.session_state.analyzed_df
@@ -681,7 +723,7 @@
unsafe_allow_html=True,
)
- # -- Clause Analyzer --
+ #Clause analyzer
st.markdown('
', unsafe_allow_html=True)
st.markdown("#### Clause Analyzer")
st.caption("Select a risky clause to get an AI explanation or a fairer rewrite.")
@@ -750,7 +792,7 @@
with st.expander("AI Suggested Rewrite", expanded=True):
st.markdown(st.session_state.rewrite_results[selected_idx])
- # -- Chat with Contract --
+ #Chat with Contract
if st.session_state.chat_chain is not None:
st.markdown('
', unsafe_allow_html=True)
st.markdown("#### Chat with Your Contract")
@@ -809,9 +851,7 @@
except Exception as e:
st.error(f"AI error: {e}")
-# ===========================================================================
-# No analysis yet
-# ===========================================================================
+#no ananlysis yet
elif st.session_state.analyzed_df is None:
st.markdown(
''
@@ -819,3 +859,5 @@
"
",
unsafe_allow_html=True,
)
+
+
diff --git a/requirements.txt b/requirements.txt
index 6e5c742..e95a297 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,17 +6,17 @@ pdfplumber>=0.10.0
spacy
en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
-# Machine Learning
+# Machine Learning & Data
scikit-learn>=1.3.0
numpy>=1.24.0
pandas>=2.0.0
# Deep Learning (Legal-BERT)
-torch>=2.0.0
+torch>=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu
transformers>=4.36.0
accelerate>=0.25.0
-# UI
+# UI & Viz
streamlit>=1.30.0
plotly>=5.18.0
matplotlib>=3.7.0
@@ -25,17 +25,17 @@ matplotlib>=3.7.0
regex
tqdm
openpyxl>=3.1.0
+python-dotenv
+huggingface_hub
# Generative AI & RAG
langchain>=0.3.0
+langchain-core>=0.3.0
+langchain-community>=0.3.0
+langchain-text-splitters
langchain-google-genai>=2.0.0
langchain-openai>=0.2.0
-langchain-community>=0.3.0
langchain-groq>=0.2.0
-langchain-huggingface>=1.0.0
+langchain-huggingface>=0.1.0
sentence-transformers>=3.0.0
-faiss-cpu>=1.8.0
-
-
-langchain-text-splitters
-python-dotenv
\ No newline at end of file
+faiss-cpu>=1.8.0
\ No newline at end of file
diff --git a/src/data_pipeline/clause_segment.py b/src/data_pipeline/clause_segment.py
index a37f4b6..1c38d0e 100644
--- a/src/data_pipeline/clause_segment.py
+++ b/src/data_pipeline/clause_segment.py
@@ -1,5 +1,5 @@
from typing import List
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_text_splitters import RecursiveCharacterTextSplitter
class ClauseSegmenter:
def __init__(self):
diff --git a/src/inference/llm_engine.py b/src/inference/llm_engine.py
index 9974f06..1535bd3 100644
--- a/src/inference/llm_engine.py
+++ b/src/inference/llm_engine.py
@@ -8,17 +8,15 @@
- Clause explainer: explains WHY a clause is risky
- Clause rewriter: suggests safer alternative language
"""
-
from __future__ import annotations
-
from typing import List, Tuple
-
import pandas as pd
-
import streamlit as st
-from langchain_classic.chains import ConversationalRetrievalChain
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain_huggingface import HuggingFaceEmbeddings
+
+# Fixed Imports
+from langchain.chains import ConversationalRetrievalChain # Fixed from langchain_classic
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.embeddings import HuggingFaceEmbeddings # Adjusted for stability
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import (
ChatPromptTemplate,