From 21b605a08ec641c2e7fe0e8fc81f4b77e26aa275 Mon Sep 17 00:00:00 2001 From: Joe Shamon Date: Sat, 7 Feb 2026 01:25:30 +0000 Subject: [PATCH] Add multi-model eval support and results for Sonnet 4.5, GPT-5.2, and Opus 4.6 --- .gitignore | 5 + docs/src/data/modelData.ts | 1 - rca/api_config.example.yaml | 4 + rca/api_config.yaml | 4 - rca/api_profiles.example.yaml | 32 +++ rca/api_router.py | 72 +++-- rca/baseline/rca_agent/controller.py | 34 ++- rca/run_agent_standard.py | 407 +++++++++++++++++++-------- 8 files changed, 413 insertions(+), 146 deletions(-) create mode 100644 rca/api_config.example.yaml delete mode 100644 rca/api_config.yaml create mode 100644 rca/api_profiles.example.yaml diff --git a/.gitignore b/.gitignore index 9cc8140..3731c2a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,10 @@ .temp/ dataset/ test/ +submission/ api_config.yaml +api_profiles.yaml +temp-share/ # Mac .DS_Store @@ -169,3 +172,5 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +.claude/ \ No newline at end of file diff --git a/docs/src/data/modelData.ts b/docs/src/data/modelData.ts index 3128f82..5d2d81a 100644 --- a/docs/src/data/modelData.ts +++ b/docs/src/data/modelData.ts @@ -38,7 +38,6 @@ export const modelData: Data[] = [ { name: 'RCA-Agent', model: 'Claude 3.5 Sonnet', org: 'Microsoft', correct: '11.34%', partial: '17.31%', date: '2025/1/23' }, { name: 'RCA-Agent', model: 'GPT-4o', org: 'Microsoft', correct: '8.96%', partial: '17.91%', date: '2025/1/23' }, { name: 'RCA-Agent', model: 'Gemini 1.5 Pro', org: 'Microsoft', correct: '2.69%', partial: '6.87%', date: '2025/1/23' }, - // Closed Models - Balanced { name: 'Prompting (Balanced)', model: 'Claude 3.5 Sonnet', org: 'Microsoft', correct: '3.88%', partial: '18.81%', date: '2025/1/23' }, { name: 'Prompting (Balanced)', model: 'GPT-4o', org: 'Microsoft', correct: '3.28%', partial: '14.33%', date: '2025/1/23' }, diff --git a/rca/api_config.example.yaml b/rca/api_config.example.yaml new file mode 100644 index 0000000..b9505ff --- /dev/null +++ b/rca/api_config.example.yaml @@ -0,0 +1,4 @@ +SOURCE: "Anthropic" +MODEL: "claude-sonnet-4-5-20250929" +API_KEY: "sk-ant-xxxxxxxxxxxxx" +API_BASE: "" diff --git a/rca/api_config.yaml b/rca/api_config.yaml deleted file mode 100644 index 439264a..0000000 --- a/rca/api_config.yaml +++ /dev/null @@ -1,4 +0,0 @@ -SOURCE: "OpenAI" -MODEL: "gpt-4o-2024-05-13" -API_KEY: "sk-xxxxxxxxxxxxxx" -API_BASE: "" \ No newline at end of file diff --git a/rca/api_profiles.example.yaml b/rca/api_profiles.example.yaml new file mode 100644 index 0000000..5c77f51 --- /dev/null +++ b/rca/api_profiles.example.yaml @@ -0,0 +1,32 @@ +# API Profiles — copy a profile's fields into api_config.yaml or use CLI overrides +# Usage: python -m rca.run_agent_standard --dataset Bank --source OpenAI --model gpt-4o --api_key sk-... 

+
+anthropic-sonnet:
+  SOURCE: "Anthropic"
+  MODEL: "claude-sonnet-4-5-20250929"
+  API_KEY: "sk-ant-xxxxxxxxxxxxx"
+
+anthropic-opus-4.5:
+  SOURCE: "Anthropic"
+  MODEL: "claude-opus-4-5-20251101"
+  API_KEY: "sk-ant-xxxxxxxxxxxxx"
+
+anthropic-opus-4.6:
+  SOURCE: "Anthropic"
+  MODEL: "claude-opus-4-6-20260110"
+  API_KEY: "sk-ant-xxxxxxxxxxxxx"
+
+openai-gpt5.2:
+  SOURCE: "OpenAI"
+  MODEL: "gpt-5.2-2025-12-11"
+  API_KEY: "sk-proj-xxxxxxxxxxxxx"
+
+openai-o3:
+  SOURCE: "OpenAI"
+  MODEL: "o3"
+  API_KEY: "sk-proj-xxxxxxxxxxxxx"
+
+google-gemini:
+  SOURCE: "Google"
+  MODEL: "gemini-3-pro-preview"
+  API_KEY: "AIza-xxxxxxxxxxxxx"
diff --git a/rca/api_router.py b/rca/api_router.py
index ecf53b5..ee2191b 100644
--- a/rca/api_router.py
+++ b/rca/api_router.py
@@ -23,33 +23,50 @@ def OpenAI_chat_completion(messages, temperature):
     ).choices[0].message.content
 
 def Google_chat_completion(messages, temperature):
-    import google.generativeai as genai
-    genai.configure(
-        api_key=configs["API_KEY"]
+    from google import genai
+    from google.genai import types
+    client = genai.Client(api_key=configs["API_KEY"], http_options=types.HttpOptions(timeout=120_000))
+    system_instruction = None
+    if messages and messages[0]["role"] == "system":
+        system_instruction = messages[0]["content"]
+        messages = messages[1:]
+    contents = []
+    for item in messages:
+        role = "model" if item["role"] == "assistant" else "user"
+        contents.append(types.Content(role=role, parts=[types.Part.from_text(text=item["content"])]))
+    config = types.GenerateContentConfig(
+        temperature=temperature,
+        system_instruction=system_instruction,
+    )
+    response = client.models.generate_content(
+        model=configs["MODEL"],
+        contents=contents,
+        config=config,
     )
-    genai.GenerationConfig(temperature=temperature)
-    system_instruction = messages[0]["content"] if messages[0]["role"] == "system" else None
-    messages = [item for item in messages if item["role"] != "system"]
-    messages = [{"role": "model" if item["role"] == "assistant" else item["role"], "parts": item["content"]} for item in messages]
-    history = messages[:-1]
-    message = messages[-1]
-    return genai.GenerativeModel(
-        model_name=configs["MODEL"],
-        system_instruction=system_instruction
-    ).start_chat(
-        history=history if history != [] else None
-    ).send_message(message).text
+    return response.text
 
 def Anthropic_chat_completion(messages, temperature):
     import anthropic
     client = anthropic.Anthropic(
         api_key=configs["API_KEY"]
     )
-    return client.messages.create(
+    system = None
+    if messages and messages[0]["role"] == "system":
+        system = messages[0]["content"]
+        messages = messages[1:]
+    kwargs = dict(
         model=configs["MODEL"],
         messages=messages,
-        temperature=temperature
-    ).content
+        temperature=temperature,
+        max_tokens=128000,
+    )
+    if system:
+        kwargs["system"] = system
+    text = ""
+    with client.messages.stream(**kwargs) as stream:
+        for chunk in stream.text_stream:
+            text += chunk
+    return text
 
 # for 3-rd party API which is compatible with OpenAI API (with different 'API_BASE')
 def AI_chat_completion(messages, temperature):
@@ -78,14 +95,25 @@ def send_request():
         else:
             raise ValueError("Invalid SOURCE in api_config file.")
 
-    for i in range(3):
+    max_retries = 60
+    for i in range(max_retries):
        try:
            return send_request()
        except Exception as e:
            print(e)
            if '429' in str(e):
-                print("Rate limit exceeded. Waiting for 1 second.")
-                time.sleep(1)
+                if 'insufficient_quota' in str(e):
+                    wait = 60
+                else:
+                    wait = min(2 ** i, 30)
+                print(f"Rate limit exceeded. Waiting for {wait} seconds (attempt {i+1}/{max_retries}).")
+                time.sleep(wait)
+                continue
+            elif 'Connection' in type(e).__name__ or 'ConnectionError' in str(type(e)):
+                wait = min(2 ** i, 30)
+                print(f"Connection error. Waiting for {wait} seconds (attempt {i+1}/{max_retries}).")
+                time.sleep(wait)
                continue
            else:
-                raise e
\ No newline at end of file
+                raise e
+    raise RuntimeError(f"API request failed after {max_retries} retries due to rate limiting or connection errors.")
\ No newline at end of file
diff --git a/rca/baseline/rca_agent/controller.py b/rca/baseline/rca_agent/controller.py
index 1683b80..722f3d2 100644
--- a/rca/baseline/rca_agent/controller.py
+++ b/rca/baseline/rca_agent/controller.py
@@ -81,12 +81,23 @@ def control_loop(objective:str, plan:str, ap, bp, logger, max_step = 15, max_tur
             note = [{'role': 'user', 'content': f"Continue your reasoning process for the target issue:\n\n{objective}\n\nFollow the rules during issue solving:\n\n{ap.rules}.\n\nResponse format:\n\n{format}"}]
             attempt_actor = []
+            response_raw = ""
             try:
                 response_raw = get_chat_completion(
                     messages=prompt + note,
                 )
+                if response_raw is None:
+                    logger.error("API returned None response")
+                    prompt.append({'role': 'user', 'content': "The API request failed. Please provide your analysis in the requested JSON format."})
+                    continue
                 if "```json" in response_raw:
-                    response_raw = re.search(r"```json\n(.*)\n```", response_raw, re.S).group(1).strip()
+                    m = re.search(r"```json\s*\n(.*?)\n\s*```", response_raw, re.S)
+                    if m:
+                        response_raw = m.group(1).strip()
+                    else:
+                        m2 = re.search(r"```json\s*(.*?)```", response_raw, re.S)
+                        if m2:
+                            response_raw = m2.group(1).strip()
                 logger.debug(f"Raw Response:\n{response_raw}")
                 if '"analysis":' not in response_raw or '"instruction":' not in response_raw or '"completed":' not in response_raw:
                     logger.warning("Invalid response format. Please provide a valid JSON response.")
@@ -107,10 +118,18 @@
                 answer = get_chat_completion(
                     messages=prompt,
                 )
+                if answer is None:
+                    answer = "API request failed. No root cause found."
                 logger.debug(f"Raw Final Answer:\n{answer}")
                 prompt.append({'role': 'assistant', 'content': answer})
                 if "```json" in answer:
-                    answer = re.search(r"```json\n(.*)\n```", answer, re.S).group(1).strip()
+                    m = re.search(r"```json\s*\n(.*?)\n\s*```", answer, re.S)
+                    if m:
+                        answer = m.group(1).strip()
+                    else:
+                        m2 = re.search(r"```json\s*(.*?)```", answer, re.S)
+                        if m2:
+                            answer = m2.group(1).strip()
                 return answer, trajectory, prompt
 
             code, result, status, new_history = execute_act(instruction, bp.schema, history, attempt_actor, kernel, logger)
@@ -144,8 +163,17 @@
                 answer = get_chat_completion(
                     messages=prompt,
                 )
+                if answer is None:
+                    answer = "API request failed. No root cause found."
                 
logger.debug(f"Raw Final Answer:\n{answer}") prompt.append({'role': 'assistant', 'content': answer}) if "```json" in answer: - answer = re.search(r"```json\n(.*)\n```", answer, re.S).group(1).strip() + m = re.search(r"```json\s*\n(.*?)\n\s*```", answer, re.S) + if m: + answer = m.group(1).strip() + else: + # Fallback: try to extract JSON object directly after ```json + m2 = re.search(r"```json\s*(.*?)```", answer, re.S) + if m2: + answer = m2.group(1).strip() return answer, trajectory, prompt diff --git a/rca/run_agent_standard.py b/rca/run_agent_standard.py index 7aac11b..70a9178 100644 --- a/rca/run_agent_standard.py +++ b/rca/run_agent_standard.py @@ -1,7 +1,9 @@ import os import sys import json +import yaml import argparse +import multiprocessing project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.insert(0, project_root) from main.evaluate import evaluate @@ -16,8 +18,8 @@ def handler(signum, frame): raise TimeoutError("Loop execution exceeded the time limit") -def main(args, uid, dataset): +def _get_prompt_modules(dataset): from rca.baseline.rca_agent.rca_agent import RCA_Agent import rca.baseline.rca_agent.prompt.agent_prompt as ap if dataset == "Telecom": @@ -26,6 +28,100 @@ def main(args, uid, dataset): import rca.baseline.rca_agent.prompt.basic_prompt_Bank as bp elif dataset == "Market/cloudbed-1" or dataset == "Market/cloudbed-2": import rca.baseline.rca_agent.prompt.basic_prompt_Market as bp + return RCA_Agent, ap, bp + + +def process_query(query_args): + """Process a single query in its own process. Top-level function for pickling.""" + idx, instruction, task_index, scoring_points, gt_columns, dataset, uid, unique_obs_path, timeout, controller_max_step, controller_max_turn, sample_num = query_args + + signal.signal(signal.SIGALRM, handler) + RCA_Agent, ap, bp = _get_prompt_modules(dataset) + + task_id = int(task_index.split('_')[1]) + best_score = 0 + results = [] + + if task_id <= 3: + catalog = "easy" + elif task_id <= 6: + catalog = "middle" + else: + catalog = "hard" + + for i in range(sample_num): + uuid = uid + f"_#{idx}-{i}" + nb = nbf.new_notebook() + nbfile = f"{unique_obs_path}/trajectory/{uuid}.ipynb" + promptfile = f"{unique_obs_path}/prompt/{uuid}.json" + logfile = f"{unique_obs_path}/history/{uuid}.log" + + proc_logger = logger.bind() + proc_logger.remove() + proc_logger.add(sys.stdout, colorize=True, enqueue=True, level="INFO") + proc_logger.add(logfile, colorize=True, enqueue=True, level="INFO") + proc_logger.debug('\n' + "#"*80 + f"\n{uuid}: {task_index}\n" + "#"*80) + + try: + signal.alarm(timeout) + + agent = RCA_Agent(ap, bp) + prediction, trajectory, prompt = agent.run(instruction, + proc_logger, + max_step=controller_max_step, + max_turn=controller_max_turn) + + signal.alarm(0) + + for step in trajectory: + code_cell = nbf.new_code_cell(step['code']) + result_cell = nbf.new_markdown_cell(f"```\n{step['result']}\n```") + nb.cells.append(code_cell) + nb.cells.append(result_cell) + with open(nbfile, 'w', encoding='utf-8') as f: + json.dump(nb, f, ensure_ascii=False, indent=4) + proc_logger.info(f"Trajectory has been saved to {nbfile}") + + with open(promptfile, 'w', encoding='utf-8') as f: + json.dump({"messages": prompt}, f, ensure_ascii=False, indent=4) + proc_logger.info(f"Prompt has been saved to {promptfile}") + + passed_criteria, failed_criteria, score = evaluate(prediction, scoring_points) + + proc_logger.info(f"Prediction: {prediction}") + proc_logger.info(f"Scoring Points: {scoring_points}") + 
proc_logger.info(f"Passed Criteria: {passed_criteria}") + proc_logger.info(f"Failed Criteria: {failed_criteria}") + proc_logger.info(f"Score: {score}") + best_score = max(best_score, score) + + results.append({ + "row_id": idx, + "task_index": task_index, + "instruction": instruction, + "prediction": prediction, + "groundtruth": gt_columns, + "passed": '\n'.join(passed_criteria), + "failed": '\n'.join(failed_criteria), + "score": score, + }) + + except TimeoutError: + proc_logger.error(f"Loop {i} exceeded the time limit and was skipped") + continue + except Exception as e: + proc_logger.error(f"Loop {i} failed with error: {e}") + continue + + return { + "idx": idx, + "catalog": catalog, + "best_score": best_score, + "results": results, + } + + +def main(args, uid, dataset): inst_file = f"dataset/{dataset}/query.csv" gt_file = f"dataset/{dataset}/record.csv" @@ -38,135 +134,192 @@ def main(args, uid, dataset): if not os.path.exists(inst_file) or not os.path.exists(gt_file): raise FileNotFoundError(f"Please download the dataset first.") - if not os.path.exists(f"{unique_obs_path}/history"): - os.makedirs(f"{unique_obs_path}/history") - if not os.path.exists(f"{unique_obs_path}/trajectory"): - os.makedirs(f"{unique_obs_path}/trajectory") - if not os.path.exists(f"{unique_obs_path}/prompt"): - os.makedirs(f"{unique_obs_path}/prompt") + os.makedirs(f"{unique_obs_path}/history", exist_ok=True) + os.makedirs(f"{unique_obs_path}/trajectory", exist_ok=True) + os.makedirs(f"{unique_obs_path}/prompt", exist_ok=True) if not os.path.exists(eval_file): - if not os.path.exists(f"test/result/{dataset}"): - os.makedirs(f"test/result/{dataset}") + os.makedirs(f"test/result/{dataset}", exist_ok=True) eval_df = pd.DataFrame(columns=["instruction", "prediction", "groundtruth", "passed", "failed", "score"]) else: eval_df = pd.read_csv(eval_file) - scores = { - "total": 0, - "easy": 0, - "middle": 0, - "hard": 0, - } - nums = { - "total": 0, - "easy": 0, - "middle": 0, - "hard": 0, - } + scores = {"total": 0, "easy": 0, "middle": 0, "hard": 0} + nums = {"total": 0, "easy": 0, "middle": 0, "hard": 0} + + # Determine already-completed row_ids for resume + completed_ids = set() + if "row_id" in eval_df.columns: + completed_ids = set(eval_df["row_id"].dropna().astype(int).tolist()) + # Re-tally scores from existing results + for _, existing_row in eval_df.iterrows(): + if "task_index" in existing_row and "score" in existing_row: + try: + tid = int(str(existing_row["task_index"]).split('_')[1]) + s = float(existing_row["score"]) + except (ValueError, IndexError): + continue + cat = "easy" if tid <= 3 else "middle" if tid <= 6 else "hard" + scores[cat] += s + scores["total"] += s + nums[cat] += 1 + nums["total"] += 1 - signal.signal(signal.SIGALRM, handler) logger.info(f"Using dataset: {dataset}") logger.info(f"Using model: {configs['MODEL'].split('/')[-1]}") - - for idx, row in instruct_data.iterrows(): + if completed_ids: + logger.info(f"Resuming: {len(completed_ids)} queries already completed, skipping them") + # Build list of query args, skipping completed + query_args_list = [] + for idx, row in instruct_data.iterrows(): if idx < args.start_idx: - continue + continue if idx > args.end_idx: break - - instruction = row["instruction"] - task_index = row["task_index"] - scoring_points = row["scoring_points"] - task_id = int(task_index.split('_')[1]) - best_score = 0 - - if task_id <= 3: - catalog = "easy" - elif task_id <= 6: - catalog = "middle" - elif task_id <= 7: - catalog = "hard" - - for i in 
range(args.sample_num): - uuid = uid + f"_#{idx}-{i}" - nb = nbf.new_notebook() - nbfile = f"{unique_obs_path}/trajectory/{uuid}.ipynb" - promptfile = f"{unique_obs_path}/prompt/{uuid}.json" - logfile = f"{unique_obs_path}/history/{uuid}.log" - logger.remove() - logger.add(sys.stdout, colorize=True, enqueue=True, level="INFO") - logger.add(logfile, colorize=True, enqueue=True, level="INFO") - logger.debug('\n' + "#"*80 + f"\n{uuid}: {task_index}\n" + "#"*80) - try: - signal.alarm(args.timeout) - - agent = RCA_Agent(ap, bp) - prediction, trajectory, prompt = agent.run(instruction, - logger, - max_step=args.controller_max_step, - max_turn=args.controller_max_turn) - - signal.alarm(0) - - for step in trajectory: - code_cell = nbf.new_code_cell(step['code']) - result_cell = nbf.new_markdown_cell(f"```\n{step['result']}\n```") - nb.cells.append(code_cell) - nb.cells.append(result_cell) - with open(nbfile, 'w', encoding='utf-8') as f: - json.dump(nb, f, ensure_ascii=False, indent=4) - logger.info(f"Trajectory has been saved to {nbfile}") - - with open(promptfile, 'w', encoding='utf-8') as f: - json.dump({"messages": prompt}, f, ensure_ascii=False, indent=4) - logger.info(f"Prompt has been saved to {promptfile}") - - new_eval_df = pd.DataFrame([{"row_id": idx, - "task_index": task_index, - "instruction": instruction, - "prediction": prediction, - "groundtruth": '\n'.join([f'{col}: {gt_data.iloc[idx][col]}' for col in gt_data.columns if col != 'description']), - "passed": "N/A", - "failed": "N/A", - "score": "N/A"}]) - eval_df = pd.concat([eval_df, new_eval_df], - ignore_index=True) - eval_df.to_csv(eval_file, - index=False) - - passed_criteria, failed_criteria, score = evaluate(prediction, scoring_points) - - logger.info(f"Prediction: {prediction}") - logger.info(f"Scoring Points: {scoring_points}") - logger.info(f"Passed Criteria: {passed_criteria}") - logger.info(f"Failed Criteria: {failed_criteria}") - logger.info(f"Score: {score}") - best_score = max(best_score, score) - - eval_df.loc[eval_df.index[-1], "passed"] = '\n'.join(passed_criteria) - eval_df.loc[eval_df.index[-1], "failed"] = '\n'.join(failed_criteria) - eval_df.loc[eval_df.index[-1], "score"] = score - eval_df.to_csv(eval_file, - index=False) - - temp_scores = scores.copy() - temp_scores[catalog] += best_score - temp_scores["total"] += best_score - temp_nums = nums.copy() - temp_nums[catalog] += 1 - temp_nums["total"] += 1 - - except TimeoutError: - logger.error(f"Loop {i} exceeded the time limit and was skipped") - continue - - scores = temp_scores - nums = temp_nums + if idx in completed_ids: + continue + gt_columns = '\n'.join([f'{col}: {gt_data.iloc[idx][col]}' for col in gt_data.columns if col != 'description']) + query_args_list.append(( + idx, + row["instruction"], + row["task_index"], + row["scoring_points"], + gt_columns, + dataset, + uid, + unique_obs_path, + args.timeout, + args.controller_max_step, + args.controller_max_turn, + args.sample_num, + )) + + if args.workers <= 1: + # Sequential — same as original behavior + signal.signal(signal.SIGALRM, handler) + RCA_Agent, ap, bp = _get_prompt_modules(dataset) + + for qa in query_args_list: + idx = qa[0] + instruction = qa[1] + task_index = qa[2] + scoring_points = qa[3] + gt_columns = qa[4] + task_id = int(task_index.split('_')[1]) + best_score = 0 + + if task_id <= 3: + catalog = "easy" + elif task_id <= 6: + catalog = "middle" + else: + catalog = "hard" + + for i in range(args.sample_num): + uuid = uid + f"_#{idx}-{i}" + nb = nbf.new_notebook() + nbfile = 
f"{unique_obs_path}/trajectory/{uuid}.ipynb" + promptfile = f"{unique_obs_path}/prompt/{uuid}.json" + logfile = f"{unique_obs_path}/history/{uuid}.log" + logger.remove() + logger.add(sys.stdout, colorize=True, enqueue=True, level="INFO") + logger.add(logfile, colorize=True, enqueue=True, level="INFO") + logger.debug('\n' + "#"*80 + f"\n{uuid}: {task_index}\n" + "#"*80) + try: + signal.alarm(args.timeout) + + agent = RCA_Agent(ap, bp) + prediction, trajectory, prompt = agent.run(instruction, + logger, + max_step=args.controller_max_step, + max_turn=args.controller_max_turn) + + signal.alarm(0) + + for step in trajectory: + code_cell = nbf.new_code_cell(step['code']) + result_cell = nbf.new_markdown_cell(f"```\n{step['result']}\n```") + nb.cells.append(code_cell) + nb.cells.append(result_cell) + with open(nbfile, 'w', encoding='utf-8') as f: + json.dump(nb, f, ensure_ascii=False, indent=4) + logger.info(f"Trajectory has been saved to {nbfile}") + + with open(promptfile, 'w', encoding='utf-8') as f: + json.dump({"messages": prompt}, f, ensure_ascii=False, indent=4) + logger.info(f"Prompt has been saved to {promptfile}") + + new_eval_df = pd.DataFrame([{"row_id": idx, + "task_index": task_index, + "instruction": instruction, + "prediction": prediction, + "groundtruth": gt_columns, + "passed": "N/A", + "failed": "N/A", + "score": "N/A"}]) + eval_df = pd.concat([eval_df, new_eval_df], + ignore_index=True) + eval_df.to_csv(eval_file, + index=False) + + passed_criteria, failed_criteria, score = evaluate(prediction, scoring_points) + + logger.info(f"Prediction: {prediction}") + logger.info(f"Scoring Points: {scoring_points}") + logger.info(f"Passed Criteria: {passed_criteria}") + logger.info(f"Failed Criteria: {failed_criteria}") + logger.info(f"Score: {score}") + best_score = max(best_score, score) + + eval_df.loc[eval_df.index[-1], "passed"] = '\n'.join(passed_criteria) + eval_df.loc[eval_df.index[-1], "failed"] = '\n'.join(failed_criteria) + eval_df.loc[eval_df.index[-1], "score"] = score + eval_df.to_csv(eval_file, + index=False) + + temp_scores = scores.copy() + temp_scores[catalog] += best_score + temp_scores["total"] += best_score + temp_nums = nums.copy() + temp_nums[catalog] += 1 + temp_nums["total"] += 1 + + except TimeoutError: + logger.error(f"Loop {i} exceeded the time limit and was skipped") + continue + + scores = temp_scores + nums = temp_nums + else: + # Parallel execution + total = len(query_args_list) + logger.info(f"Running {total} queries with {args.workers} workers") + completed = 0 + + with multiprocessing.Pool(processes=args.workers) as pool: + for result in pool.imap_unordered(process_query, query_args_list): + completed += 1 + catalog = result["catalog"] + best_score = result["best_score"] + + scores[catalog] += best_score + scores["total"] += best_score + nums[catalog] += 1 + nums["total"] += 1 + + for r in result["results"]: + new_eval_df = pd.DataFrame([r]) + eval_df = pd.concat([eval_df, new_eval_df], ignore_index=True) + + eval_df.to_csv(eval_file, index=False) + logger.info(f"Progress: {completed}/{total} queries completed (idx={result['idx']}, score={best_score})") + + logger.info(f"Final scores: {scores}") + logger.info(f"Final nums: {nums}") if __name__ == "__main__": - + uid = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, default="Market/cloudbed-1") @@ -178,9 +331,31 @@ def main(args, uid, dataset): parser.add_argument("--timeout", type=int, default=600) parser.add_argument("--tag", 
type=str, default='rca') parser.add_argument("--auto", type=bool, default=False) + parser.add_argument("--workers", type=int, default=1) + parser.add_argument("--source", type=str, default=None, help="Override SOURCE (OpenAI, Anthropic, Google, AI)") + parser.add_argument("--model", type=str, default=None, help="Override MODEL") + parser.add_argument("--api_key", type=str, default=None, help="Override API_KEY") + parser.add_argument("--api_base", type=str, default=None, help="Override API_BASE") + parser.add_argument("--profile", type=str, default=None, help="Load a named profile from rca/api_profiles.yaml") args = parser.parse_args() + if args.profile: + with open("rca/api_profiles.yaml", "r") as f: + profiles = yaml.safe_load(f) + if args.profile not in profiles: + raise ValueError(f"Profile '{args.profile}' not found. Available: {list(profiles.keys())}") + for k, v in profiles[args.profile].items(): + configs[k] = v + if args.source: + configs["SOURCE"] = args.source + if args.model: + configs["MODEL"] = args.model + if args.api_key: + configs["API_KEY"] = args.api_key + if args.api_base: + configs["API_BASE"] = args.api_base + if args.auto: print(f"Auto mode is on. Model is fixed to {configs['MODEL']}") datasets = ["Market/cloudbed-1", "Market/cloudbed-2", "Bank", "Telecom"] @@ -188,4 +363,4 @@ def main(args, uid, dataset): main(args, uid, dataset) else: dataset = args.dataset - main(args, uid, dataset) \ No newline at end of file + main(args, uid, dataset)
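
-- 
Example invocations for the new flags (a sketch based on the example profiles
and defaults added above; the profile name must exist in your local
rca/api_profiles.yaml, and the commands assume the repository root as the
working directory):

    # Copy the template, then fill in real keys
    cp rca/api_profiles.example.yaml rca/api_profiles.yaml

    # Run one dataset with a named profile and four worker processes
    python -m rca.run_agent_standard --dataset Bank --profile anthropic-sonnet --workers 4

    # Override the model inline without a profile
    python -m rca.run_agent_standard --dataset Telecom --source OpenAI --model gpt-5.2-2025-12-11 --api_key sk-proj-...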