diff --git a/app.py b/app.py index 76994ec..ed59c26 100644 --- a/app.py +++ b/app.py @@ -1,8 +1,8 @@ + import os import tempfile - import streamlit as st - +from src.utils.pdf_annotator import highlight_contract_risks from src.data_pipeline.pipeline import DataPipeline from src.inference.predictor import ( build_xlsx_export, @@ -21,6 +21,20 @@ render_risk_card, render_risk_heatmap, ) +def cleanup_files(): + try: + if os.path.exists("marked_contract.pdf"): + os.remove("marked_contract.pdf") + if "original_pdf_path" in st.session_state and st.session_state.original_pdf_path: + old_path = st.session_state.original_pdf_path + if os.path.exists(old_path): + os.remove(old_path) + st.session_state.original_pdf_path = None + + st.session_state.pdf_ready = False + + except Exception as e: + print(f"Silent Cleanup: {e}") st.set_page_config( page_title="ContraLegal - Risk Dashboard", @@ -305,9 +319,6 @@ unsafe_allow_html=True, ) -# --------------------------------------------------------------------------- -# Session state init -# --------------------------------------------------------------------------- if "analyzed_df" not in st.session_state: st.session_state.analyzed_df = None if "contract_risk" not in st.session_state: @@ -337,9 +348,7 @@ if "rewrite_results" not in st.session_state: st.session_state.rewrite_results = {} -# --------------------------------------------------------------------------- -# Sidebar — Navigation + Settings -# --------------------------------------------------------------------------- + with st.sidebar: st.markdown("### ContraLegal") st.caption("Intelligent Contract Risk Analysis") @@ -419,9 +428,7 @@ except Exception as e: st.sidebar.warning(f"Could not connect to {llm_provider}: {e}") -# --------------------------------------------------------------------------- -# Model load -# --------------------------------------------------------------------------- + vectorizer, model = load_model() if vectorizer is None or model is None: @@ -432,9 +439,7 @@ ) st.stop() -# --------------------------------------------------------------------------- -# Brand header + Input (always visible) -# --------------------------------------------------------------------------- + st.markdown('
ContraLegal
', unsafe_allow_html=True) st.markdown( '
Intelligent Contract Risk Analysis
', @@ -488,9 +493,7 @@ st.markdown('
', unsafe_allow_html=True) -# --------------------------------------------------------------------------- -# Input parsing -# --------------------------------------------------------------------------- + if analyze_btn_upload or analyze_btn_paste or analyze_btn_demo: clauses = [] @@ -508,14 +511,15 @@ with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp: tmp.write(uploaded_file.getvalue()) tmp_path = tmp.name + st.session_state.original_pdf_path = tmp_path + try: from src.data_pipeline.pdf_extractor import PDFExtractor st.session_state.raw_contract_text = PDFExtractor().extract_text(tmp_path) pipeline = DataPipeline() clauses = pipeline.process_document(tmp_path) - finally: - if os.path.exists(tmp_path): - os.remove(tmp_path) + except Exception as e: + st.error(f"Error processing PDF: {e}") elif analyze_btn_paste and pasted_text.strip(): st.session_state.raw_contract_text = pasted_text.strip() @@ -538,7 +542,8 @@ st.session_state.analyzed_df = r_df st.session_state.contract_risk = c_risk - st.session_state.summary_clauses = s_clauses + st.session_state.summary_clauses = s_clauses + st.session_state.pdf_ready = False # Build RAG index for the new document if ai_enabled and st.session_state.raw_contract_text: @@ -573,9 +578,7 @@ st.session_state.current_view = "dashboard" st.rerun() -# =========================================================================== -# VIEW: Risk Dashboard -# =========================================================================== + if st.session_state.analyzed_df is not None and st.session_state.current_view == "dashboard": results_df = st.session_state.analyzed_df contract_risk = st.session_state.contract_risk @@ -645,7 +648,7 @@ st.markdown(f"- {row['clause_text']}") st.markdown('
', unsafe_allow_html=True) - dl_col1, dl_col2 = st.columns(2) + dl_col1, dl_col2, dl_col3= st.columns(3) with dl_col1: csv = results_df.drop(columns=["keyword_matches"], errors="ignore").to_csv(index=False).encode("utf-8") @@ -666,10 +669,49 @@ mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", use_container_width=True, ) + with dl_col3: + if "pdf_ready" not in st.session_state: + st.session_state.pdf_ready = False + + if st.session_state.get("original_pdf_path") and os.path.exists(st.session_state.original_pdf_path): + + output_pdf_path = "marked_contract.pdf" + + risky_data = [ + {"text": row["clause_text"], "risk": row["risk_label"]} + for _, row in results_df[ + results_df["risk_label"].isin(["High Risk", "Medium Risk"]) + ].iterrows() + ] + + if not st.session_state.pdf_ready: + if st.button("Generate Marked-up PDF", use_container_width=True): + with st.spinner("Applying spatial highlights..."): + try: + highlight_contract_risks( + st.session_state.original_pdf_path, + output_pdf_path, + risky_data + ) + st.session_state.pdf_ready = True + st.success("Highlights applied successfully!") + except Exception as e: + st.error("Could not apply highlights. PDF format may not be supported.") + print(e) + + if st.session_state.pdf_ready: + with open(output_pdf_path, "rb") as f: + st.download_button( + label="Download Highlighted PDF", + data=f, + file_name="ContraLegal_Spatial_Analysis.pdf", + mime="application/pdf", + use_container_width=True, + on_click=cleanup_files + ) + else: + st.caption("Upload a PDF to enable highlighted export.") -# =========================================================================== -# VIEW: AI Assistant -# =========================================================================== elif st.session_state.analyzed_df is not None and st.session_state.current_view == "assistant": results_df = st.session_state.analyzed_df @@ -681,7 +723,7 @@ unsafe_allow_html=True, ) - # -- Clause Analyzer -- + #Clause analyzer st.markdown('
', unsafe_allow_html=True) st.markdown("#### Clause Analyzer") st.caption("Select a risky clause to get an AI explanation or a fairer rewrite.") @@ -750,7 +792,7 @@ with st.expander("AI Suggested Rewrite", expanded=True): st.markdown(st.session_state.rewrite_results[selected_idx]) - # -- Chat with Contract -- + #Chat with Contract if st.session_state.chat_chain is not None: st.markdown('
', unsafe_allow_html=True) st.markdown("#### Chat with Your Contract") @@ -809,9 +851,7 @@ except Exception as e: st.error(f"AI error: {e}") -# =========================================================================== -# No analysis yet -# =========================================================================== +#no ananlysis yet elif st.session_state.analyzed_df is None: st.markdown( '
' @@ -819,3 +859,5 @@ "
", unsafe_allow_html=True, ) + + diff --git a/requirements.txt b/requirements.txt index 6e5c742..e95a297 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,17 +6,17 @@ pdfplumber>=0.10.0 spacy en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl -# Machine Learning +# Machine Learning & Data scikit-learn>=1.3.0 numpy>=1.24.0 pandas>=2.0.0 # Deep Learning (Legal-BERT) -torch>=2.0.0 +torch>=2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu transformers>=4.36.0 accelerate>=0.25.0 -# UI +# UI & Viz streamlit>=1.30.0 plotly>=5.18.0 matplotlib>=3.7.0 @@ -25,17 +25,17 @@ matplotlib>=3.7.0 regex tqdm openpyxl>=3.1.0 +python-dotenv +huggingface_hub # Generative AI & RAG langchain>=0.3.0 +langchain-core>=0.3.0 +langchain-community>=0.3.0 +langchain-text-splitters langchain-google-genai>=2.0.0 langchain-openai>=0.2.0 -langchain-community>=0.3.0 langchain-groq>=0.2.0 -langchain-huggingface>=1.0.0 +langchain-huggingface>=0.1.0 sentence-transformers>=3.0.0 -faiss-cpu>=1.8.0 - - -langchain-text-splitters -python-dotenv \ No newline at end of file +faiss-cpu>=1.8.0 \ No newline at end of file diff --git a/src/data_pipeline/clause_segment.py b/src/data_pipeline/clause_segment.py index a37f4b6..1c38d0e 100644 --- a/src/data_pipeline/clause_segment.py +++ b/src/data_pipeline/clause_segment.py @@ -1,5 +1,5 @@ from typing import List -from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_text_splitters import RecursiveCharacterTextSplitter class ClauseSegmenter: def __init__(self): diff --git a/src/inference/llm_engine.py b/src/inference/llm_engine.py index 9974f06..1535bd3 100644 --- a/src/inference/llm_engine.py +++ b/src/inference/llm_engine.py @@ -8,17 +8,15 @@ - Clause explainer: explains WHY a clause is risky - Clause rewriter: suggests safer alternative language """ - from __future__ import annotations - from typing import List, Tuple - import pandas as pd - import streamlit as st -from langchain_classic.chains import ConversationalRetrievalChain -from langchain_text_splitters import RecursiveCharacterTextSplitter -from langchain_huggingface import HuggingFaceEmbeddings + +# Fixed Imports +from langchain.chains import ConversationalRetrievalChain # Fixed from langchain_classic +from langchain_text_splitters import RecursiveCharacterTextSplitter +from langchain_community.embeddings import HuggingFaceEmbeddings # Adjusted for stability from langchain_community.vectorstores import FAISS from langchain_core.prompts import ( ChatPromptTemplate,