-
Notifications
You must be signed in to change notification settings - Fork 42
[Bugfix] Add pred and choices parsing to fix the issue of score=0 for… #258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever | ||
| from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer | ||
| from ais_bench.benchmark.datasets import MMStarDataset, MMStarEvaluator | ||
| from ais_bench.benchmark.utils.postprocess.text_postprocessors import last_option_postprocess | ||
|
|
||
|
|
||
| mmstar_reader_cfg = dict( | ||
|
|
@@ -29,7 +30,8 @@ | |
| ) | ||
|
|
||
| mmstar_eval_cfg = dict( | ||
| evaluator=dict(type=MMStarEvaluator) | ||
| evaluator=dict(type=MMStarEvaluator), | ||
| pred_postprocessor=dict(type=last_option_postprocess, options="ABCD"), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using Since the prompt explicitly asks for the |
||
| ) | ||
|
|
||
| mmstar_datasets = [ | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,53 +1,70 @@ | ||||||||||
| import json | ||||||||||
| import os | ||||||||||
| import re | ||||||||||
| import string | ||||||||||
| import pandas as pd | ||||||||||
| import numpy as np | ||||||||||
|
|
||||||||||
| import numpy as np | ||||||||||
| import pandas as pd | ||||||||||
| from datasets import Dataset, DatasetDict | ||||||||||
|
|
||||||||||
| from ais_bench.benchmark.datasets import build_choices, can_infer, dump_image, split_MMMU | ||||||||||
| from ais_bench.benchmark.datasets.utils.datasets import get_data_path, toliststr | ||||||||||
| from ais_bench.benchmark.openicl import BaseEvaluator | ||||||||||
| from ais_bench.benchmark.registry import LOAD_DATASET | ||||||||||
| from ais_bench.benchmark.datasets.utils.datasets import get_data_path, toliststr | ||||||||||
| from ais_bench.benchmark.utils.logging import AISLogger | ||||||||||
| from ais_bench.benchmark.datasets import dump_image, split_MMMU, build_choices, can_infer | ||||||||||
| from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_TEXT_START, AIS_IMAGE_START | ||||||||||
| from ais_bench.benchmark.utils.prompt import AIS_CONTENT_TAG, AIS_IMAGE_START, AIS_TEXT_START | ||||||||||
|
|
||||||||||
| from .base import BaseDataset | ||||||||||
|
|
||||||||||
| IMAGE_MAP_LEN = 64 | ||||||||||
| logger = AISLogger() | ||||||||||
|
|
||||||||||
|
|
||||||||||
# An option marker is an uppercase letter at a token boundary (start of text,
# or right after whitespace / "," / ";" / "(") followed by ":"; "." and ")"
# only count when followed by whitespace, so sentence-final text such as
# "Vitamin C." is not mistaken for the label "C".
_OPTION_MARKER_RE = re.compile(r"(?:^|(?<=[\s,;(]))([A-Z])(:|[.)](?=\s))")


def extract_options_from_question(question_text: str):
    """Parse the lettered options that follow "Options:" in a question.

    The parser does not depend on a particular separator: options may be
    split by commas, semicolons, newlines or plain spaces, and labels may be
    written as ``A:``, ``A.`` or ``(A)``. To limit false positives, the label
    style of the first marker is required for all later markers, and labels
    must be consecutive letters (``A``, ``B``, ``C``, ...).

    Args:
        question_text: Raw question text. Anything that is not a ``str``
            (for example a missing value) is treated as having no options.

    Returns:
        dict mapping option letter to option text, in order of appearance.
        Surrounding whitespace, a trailing separator and one trailing "."
        are removed from each option text. Empty when no "Options:" section
        or no option marker is found.
    """
    options = {}
    if not isinstance(question_text, str) or "Options:" not in question_text:
        return options

    # Keep everything after the FIRST "Options:" so a later occurrence of the
    # word inside an option does not truncate the list.
    options_part = question_text.split("Options:", 1)[1].strip()

    # Collect (letter, terminator, marker_start, content_start) for every
    # marker that continues the sequence started by the first marker.
    markers = []
    for match in _OPTION_MARKER_RE.finditer(options_part):
        letter, terminator = match.groups()
        if markers:
            prev_letter = markers[-1][0]
            first_terminator = markers[0][1]
            if terminator != first_terminator or ord(letter) != ord(prev_letter) + 1:
                continue
        markers.append((letter, terminator, match.start(), match.end()))

    for position, (letter, _, _, content_start) in enumerate(markers):
        if position + 1 < len(markers):
            content_end = markers[position + 1][2]
        else:
            content_end = len(options_part)
        # Drop the separator (and the "(" of a following "(B)" label) that
        # trails the option text.
        content = options_part[content_start:content_end].strip().rstrip(",;(").rstrip()
        if content.endswith("."):
            content = content[:-1]
        # An empty option text carries no information for answer matching.
        if content:
            options[letter] = content

    return options
|
|
||||||||||
|
|
||||||||||
| @LOAD_DATASET.register_module() | ||||||||||
| class MMStarDataset(BaseDataset): | ||||||||||
|
|
||||||||||
| @staticmethod | ||||||||||
| def load(path): | ||||||||||
| path = get_data_path(path) | ||||||||||
| image_root_path = os.path.join(os.path.dirname(path), "MMStar_images") | ||||||||||
| logger.info(f"Convert base64 to image and save it in {image_root_path}") | ||||||||||
| skip_noimg = True | ||||||||||
| data = pd.read_csv(path, sep='\t') | ||||||||||
| if skip_noimg and 'image' in data: | ||||||||||
| data = data[~pd.isna(data['image'])] | ||||||||||
|
|
||||||||||
| data = pd.read_csv(path, sep="\t") | ||||||||||
| if skip_noimg and "image" in data: | ||||||||||
| data = data[~pd.isna(data["image"])] | ||||||||||
| # The image field can store the base64 encoded image or another question index (for saving space) | ||||||||||
| if 'image' in data: | ||||||||||
| data['image'] = [str(x) for x in data['image']] | ||||||||||
| image_map = {x: y for x, y in zip(data['index'], data['image'])} | ||||||||||
| if "image" in data: | ||||||||||
| data["image"] = [str(x) for x in data["image"]] | ||||||||||
| image_map = {x: y for x, y in zip(data["index"], data["image"])} | ||||||||||
| for k in image_map: | ||||||||||
| if len(image_map[k]) <= IMAGE_MAP_LEN: | ||||||||||
| idx = image_map[k] | ||||||||||
| image_map[k] = image_map[idx] | ||||||||||
|
|
||||||||||
| images = [toliststr(image_map[k]) for k in data['index']] | ||||||||||
| data['image'] = [x[0] if len(x) == 1 else x for x in images] | ||||||||||
| if 'image_path' in data: | ||||||||||
| paths = [toliststr(x) for x in data['image_path']] | ||||||||||
| data['image_path'] = [x[0] if len(x) == 1 else x for x in paths] | ||||||||||
| images = [toliststr(image_map[k]) for k in data["index"]] | ||||||||||
| data["image"] = [x[0] if len(x) == 1 else x for x in images] | ||||||||||
| if "image_path" in data: | ||||||||||
| paths = [toliststr(x) for x in data["image_path"]] | ||||||||||
| data["image_path"] = [x[0] if len(x) == 1 else x for x in paths] | ||||||||||
|
|
||||||||||
| if np.all([isinstance(x, int) for x in data['index']]): | ||||||||||
| data['index'] = [int(x) for x in data['index']] | ||||||||||
| if np.all([isinstance(x, int) for x in data["index"]]): | ||||||||||
| data["index"] = [int(x) for x in data["index"]] | ||||||||||
|
|
||||||||||
| sheet_indices = list(range(0, len(data), 1)) | ||||||||||
| data = data.iloc[sheet_indices] | ||||||||||
|
|
@@ -61,59 +78,67 @@ def load(path): | |||||||||
| for cand in string.ascii_uppercase | ||||||||||
| if cand in line and not pd.isna(line[cand]) | ||||||||||
| } | ||||||||||
| options_prompt = 'Options:\n' | ||||||||||
| options_prompt = "Options:\n" | ||||||||||
| for key, item in options.items(): | ||||||||||
| options_prompt += f'{key}. {item}\n' | ||||||||||
| hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None | ||||||||||
| # get text prompt | ||||||||||
| prompt = '' | ||||||||||
| options_prompt += f"{key}. {item}\n" | ||||||||||
|
|
||||||||||
| hint = line["hint"] if ("hint" in line and not pd.isna(line["hint"])) else None | ||||||||||
| # get text prompt | ||||||||||
| prompt = "" | ||||||||||
| if hint is not None: | ||||||||||
| prompt += f'Hint: {hint}\n' | ||||||||||
| prompt += f"Hint: {hint}\n" | ||||||||||
| prompt += line["question"] | ||||||||||
| if len(options): | ||||||||||
| prompt += options_prompt | ||||||||||
| prompt += 'Please select the correct answer from the options above. \n' | ||||||||||
| prompt += "Please select the correct answer from the options above. \n" | ||||||||||
| # add image info | ||||||||||
| if isinstance(tgt_path, list): | ||||||||||
| tgt_path = tgt_path[0] | ||||||||||
|
|
||||||||||
| content = AIS_IMAGE_START + tgt_path + AIS_CONTENT_TAG \ | ||||||||||
| + AIS_TEXT_START + prompt + AIS_CONTENT_TAG | ||||||||||
| choices = build_choices(line) | ||||||||||
| dataset.append({"content": content, | ||||||||||
| "answer": {'choices': json.dumps(choices), | ||||||||||
| 'answer': line['answer'], | ||||||||||
| 'split': line.get('split'), | ||||||||||
| 'l2-category': line.get('l2-category'), | ||||||||||
| 'category': line.get('category')}}) | ||||||||||
|
|
||||||||||
| content = ( | ||||||||||
| AIS_IMAGE_START | ||||||||||
| + tgt_path | ||||||||||
| + AIS_CONTENT_TAG | ||||||||||
| + AIS_TEXT_START | ||||||||||
| + prompt | ||||||||||
| + AIS_CONTENT_TAG | ||||||||||
| ) | ||||||||||
| choices = build_choices(extract_options_from_question(line['question'])) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change introduces a regression. By calling You should prioritize the options from the columns and only fallback to parsing the question text if they are missing.
Suggested change
|
||||||||||
| dataset.append( | ||||||||||
| { | ||||||||||
| "content": content, | ||||||||||
| "answer": { | ||||||||||
| "choices": json.dumps(choices), | ||||||||||
| "answer": line["answer"], | ||||||||||
| "split": line.get("split"), | ||||||||||
| "l2-category": line.get("l2-category"), | ||||||||||
| "category": line.get("category"), | ||||||||||
| }, | ||||||||||
| } | ||||||||||
| ) | ||||||||||
| return Dataset.from_list(dataset) | ||||||||||
|
|
||||||||||
class MMStarEvaluator(BaseEvaluator):
    """Accuracy evaluator reporting an overall and a per-category score."""

    def score(self, predictions, references):
        """Score predictions against reference records.

        Args:
            predictions: Sequence of model outputs.
            references: Sequence of mappings; each is read for the keys
                "choices" (a JSON string), "answer" and "category".

        Returns:
            A dict with one percentage (0-100) under "Overall" and one per
            category value, plus "details" holding a per-sample record of
            prediction, reference and correctness. When the two inputs
            differ in length, a dict with a single "error" message instead.
        """
        if len(predictions) != len(references):
            return {"error": "predictions and references have different length"}

        overall_key = "Overall"
        hits_by_group = {}
        details = []
        for pred, refer in zip(predictions, references):
            # can_infer maps the raw prediction onto the choices; its exact
            # matching rules live outside this class.
            inferred = can_infer(pred, json.loads(refer["choices"]))
            category = refer["category"]
            is_correct = bool(inferred == refer["answer"])
            details.append({"pred": pred, "answer": refer, "correct": is_correct})

            hit = 1 if is_correct else 0
            # Every sample counts once towards the overall bucket and once
            # towards the bucket of its own category.
            for group in (overall_key, category):
                hits_by_group.setdefault(group, []).append(hit)

        result = {
            group: 100 * sum(hits) / len(hits)
            for group, hits in hits_by_group.items()
        }
        result["details"] = details
        return result
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using `last_option_postprocess` is risky for multiple-choice evaluation, especially if the model provides reasoning (Chain-of-Thought). This function extracts the last occurrence of any character in the `options` string ("ABCD"). If the model's explanation mentions another option after stating its answer (e.g., "The answer is A, because option B is incorrect"), this will incorrectly return 'B' as the prediction.
Consider using `first_option_postprocess` or a more specific regex-based postprocessor that targets a final answer pattern (e.g., `ANSWER: [A-D]`).