main.py
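"""Knowledge Base Answer Scorer.

Sends each question from an Excel sheet to an OpenWebUI-hosted model, matches
the response against known solutions, and scores it with BERTScore, token-level
F1, and BLEU. Results can be written to a report and optionally exported to Excel.

Expected .env entries (names taken from the code below; values are illustrative):

    DATA_DIR_PATH=./data
    QUESTION_EXCEL=questions.xlsx
    SOLUTION_EXCEL=solutions.xlsx
    QUESTION_SHEET_NAME=Sheet1
"""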
import os
import time
import logging
import argparse
import traceback
from tqdm import tqdm
from dotenv import load_dotenv
from utils.data_extractor import DataExtractor
from opwebui.api_client import OpenWebUIClient
from utils.query_enhancer import QueryEnchancer
from metrics.metrics_evaluator import ScoreCalculator, SolutionMatcher
from utils.evaluation_utils import (
generate_report,
assess_response_quality,
export_report_to_excel
)
load_dotenv()

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Validate the environment variables before building paths, so a missing value
# raises a clear error instead of a TypeError from os.path.
_data_dir = os.getenv("DATA_DIR_PATH")
_question_excel = os.getenv("QUESTION_EXCEL")
_solution_excel = os.getenv("SOLUTION_EXCEL")
QUESTION_SHEET_NAME = os.getenv("QUESTION_SHEET_NAME")

if not _data_dir or not _question_excel or not _solution_excel:
    raise ValueError("DATA_DIR_PATH, QUESTION_EXCEL and SOLUTION_EXCEL must be set in the environment variables.")

DATA_PATH = os.path.abspath(_data_dir)
QUESTION_PATH = os.path.join(DATA_PATH, _question_excel)
SOLUTION_PATH = os.path.join(DATA_PATH, _solution_excel)
def display_results(question, model_response, best_solution, metrics):
"""
Display formatted results of model response evaluation.
Args:
question: The Question object
model_response: The model's generated response text
best_solution: The best matching Solution object
metrics: Dictionary of evaluation metrics
"""
# Log the best match
    logging.info(f"Best match for question {question.id} with BERTScore F1={metrics['bert_f1']:.4f}")
# Print formatted output
print(f"\n{'='*50}")
print(f"=== Question {question.id} ===")
    issue_preview = question.issue[:100] + "..." if len(question.issue) > 100 else question.issue
    print(f"Issue: {issue_preview}")

    print("\n=== Model Response ===")
    response_preview = model_response[:300] + "..." if len(model_response) > 300 else model_response
    print(response_preview)
print(f"\n=== Best Matching Solution ({best_solution.id}) ===")
print(f"Title: {best_solution.title}")
# Print up to 5 steps
for j, step in enumerate(best_solution.steps[:5]):
print(f" {j+1}. {step}")
# Show ellipsis if there are more steps
if len(best_solution.steps) > 5:
print(f" ...(+{len(best_solution.steps) - 5} more steps)")
# Print metrics
    print("\nEvaluation Metrics:")
print(f" BERTScore: {metrics['bert_f1']:.4f} (P={metrics['bert_precision']:.4f}, R={metrics['bert_recall']:.4f})")
print(f" F1 Score: {metrics['trad_f1']:.4f}")
print(f" BLEU: {metrics['bleu']:.4f}")
print(f"\n{'='*50}\n")
def main():
try:
# Parse command line arguments
args = parse_args()
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
logging.debug("Verbose mode enabled")
# Ensure data directory exists, create if not
os.makedirs(DATA_PATH, exist_ok=True)
if not os.path.exists(QUESTION_PATH) or not os.path.exists(SOLUTION_PATH):
            logging.error(f"One or more required files are missing. Please check that {QUESTION_PATH} and {SOLUTION_PATH} both exist.")
return
# Initialize the data extractor
extractor = DataExtractor(
questions_path=QUESTION_PATH,
answers_path=SOLUTION_PATH,
questions_config={
"sheet_name": QUESTION_SHEET_NAME,
"header_row": 2,
"issue_col": 'B',
'solutions_col': 'C',
'ai_solutions_col': 'D'
},
answers_config={
"header_row": 0,
"title_col": 'A',
'steps_col': 'B'
}
)
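        # NOTE: the column letters and header rows above mirror the expected Excel
        # layout; header_row is assumed to be a 0-based index into the sheet.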
# Load and parse the data
extractor.load_and_parse_data()
# Get the parsed data
questions = extractor.get_questions()
solutions = extractor.get_solutions()
# Initialize the score calculator and matcher
score_calculator = ScoreCalculator()
matcher = SolutionMatcher(score_calculator)
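        # SolutionMatcher.find_best_solution (used below) is assumed to score the
        # model response against every candidate solution and return the
        # highest-scoring one together with its metrics dict.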
# Prepare metrics storage for report generation
metrics_by_question = {}
        # Select the questions to process: a single ID, a limited subset, or all
        if args.question_id:
            questions_to_process = [q for q in questions if q.id == args.question_id]
            if not questions_to_process:
                logging.error(f"Question ID {args.question_id} not found.")
                return
        else:
            questions_to_process = questions if args.limit <= 0 else questions[:args.limit]

        total_questions = len(questions_to_process)
        query_enhancer = QueryEnchancer()
        # Reuse a single API client for every request instead of recreating it per question
        client = OpenWebUIClient()

        for i, question in enumerate(tqdm(questions_to_process, desc="Processing questions", unit="question")):
            logging.info(f"Processing question {i+1}/{total_questions} (ID: {question.id})")

            # Get the model's response for this question
            prompt = question.issue
            # Optionally enhance the query before sending it to the model
            if args.pre_process:
                enhanced_prompt = query_enhancer.pre_process(prompt)
                logging.info(f"Pre-request enhanced prompt: {enhanced_prompt[:100]}...")
                prompt = enhanced_prompt
logging.info(f"Sending prompt to model: {prompt[:50]}...")
response = client.chat_with_model(prompt)
if not response:
logging.error(f"No response received for question {question.id}")
continue
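            # The response is assumed to follow the OpenAI-style chat-completion
            # shape, hence choices[0].message.content.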
model_response = response.choices[0].message.content
logging.info(f"Received model response: {len(model_response)} chars")
            # Pick the reference solution(s) from the solutions_used / ai_solutions_used fields
            if question.ai_solutions_used:
                # Prefer AI solutions if specified
                solution_indices = question.ai_solutions_used
                logging.info(f"Using AI solutions {solution_indices} for question {question.id}")
            elif question.solutions_used:
                # Fall back to regular solutions
                solution_indices = question.solutions_used
                logging.info(f"Using regular solutions {solution_indices} for question {question.id}")
else:
# If no solutions are marked, compare with all solutions
solution_indices = list(range(1, len(solutions) + 1)) # Use 1-based indices to match Excel
logging.info(f"No specific solution marked for question {question.id}, comparing with all solutions")
            # Convert the 1-based Excel indices to 0-based list indices, dropping any out of range
            valid_indices = [idx - 1 for idx in solution_indices if 1 <= idx <= len(solutions)]
            solutions_to_compare = [solutions[idx] for idx in valid_indices]
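            # e.g. solution_indices == [1, 3] selects solutions[0] and solutions[2]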
# Skip if no valid solutions to compare against
if not solutions_to_compare:
logging.warning(f"No valid solutions found for question {question.id}")
continue
best_solution, metrics = matcher.find_best_solution(
model_response,
solutions_to_compare
)
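            # metrics is expected to carry at least: bert_f1, bert_precision,
            # bert_recall, trad_f1 and bleu (see display_results above).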
question.bert_score = metrics['bert_f1']
question.f1_score = metrics['trad_f1']
question.bleu_score = metrics['bleu']
# Store metrics for report generation
metrics_by_question[question.id] = {
'metrics': metrics,
'best_solution_id': best_solution.id,
'model_response': model_response
}
# Check quality and potentially improve prompt
is_acceptable, feedback = assess_response_quality(
metrics,
bert_threshold=args.bert_threshold,
f1_threshold=args.f1_threshold,
bleu_threshold=args.bleu_threshold,
combined_threshold=args.combined_threshold
)
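            # assess_response_quality is assumed to compare each metric against its
            # threshold and return (bool, human-readable feedback).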
# INCOMPLETE - for future use
if not is_acceptable and args.post_process:
# Log the feedback and warning
logging.warning(f"Question {question.id} response quality below threshold")
logging.info(f"{feedback}\n")
# Generate improved prompt for future use
                improved_prompt = query_enhancer.post_process(
prompt,
metrics,
bert_threshold=args.bert_threshold,
f1_threshold=args.f1_threshold,
bleu_threshold=args.bleu_threshold,
combined_threshold=args.combined_threshold
)
logging.info(f"Post-request improved prompt: {improved_prompt[:100]}...")
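                # NOTE: the improved prompt is only logged; re-querying the model
                # with it is left as future work (see INCOMPLETE above).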
if args.verbose:
display_results(question, model_response, best_solution, metrics)
time.sleep(args.wait_time)
# Generate comprehensive report after all questions processed
if metrics_by_question and not args.skip_report:
report_path = generate_report(questions, solutions, metrics_by_question,
output_dir=args.report_dir,
bert_threshold=args.bert_threshold,
f1_threshold=args.f1_threshold,
bleu_threshold=args.bleu_threshold,
combined_threshold=args.combined_threshold)
            logging.info(f"Evaluation report generated: {report_path}")

            if args.export_excel and report_path:
                excel_path = export_report_to_excel(report_path)
                logging.info(f"Evaluation metrics exported to Excel: {excel_path}")
except Exception as e:
logging.error(f"An error occurred: {e}")
traceback.print_exc()
return
def parse_args():
"""
Parse command line arguments.
Returns:
Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(description="Knowledge Base Answer Scorer")
parser.add_argument("--bert-threshold", "--bt", type=float, default=0.5, help="BERT score threshold for quality assessment")
parser.add_argument("--f1-threshold", "--f1", type=float, default=0.3, help="F1 score threshold for quality assessment")
parser.add_argument("--bleu-threshold", "--bl", type=float, default=0.1, help="BLEU score threshold for quality assessment")
parser.add_argument("--combined-threshold", "--ct", type=float, default=0.4, help="Combined score threshold for quality assessment")
    parser.add_argument("--limit", "--l", type=int, default=0, help="Limit the number of questions to process (0 = process all)")
parser.add_argument("--question-id", "--id", type=str, help="Process only a specific question ID")
parser.add_argument("--pre-process", "--pre", action="store_true", help="Enable query enhancement before sending to model")
parser.add_argument("--post-process", "--post", action="store_true", help="Enable query enhancement after receiving model response")
parser.add_argument("--verbose", "--v", action="store_true", help="Display detailed logs")
parser.add_argument("--report-dir", "--rd", type=str, default="reports", help="Directory to save reports")
parser.add_argument("--wait-time", "--wt", type=float, default=1.0, help="Wait time between API calls in seconds")
parser.add_argument("--skip-report", "--sr", action="store_true", help="Skip report generation")
parser.add_argument("--export-excel", "--ee", action="store_true", help="Export evaluation report to Excel")
return parser.parse_args()
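
# Example invocations (flags as defined above; values are illustrative):
#   python main.py --limit 10 --verbose
#   python main.py --question-id Q42 --pre-process --export-excel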
if __name__ == "__main__":
main()