From 9cf5ef81f6a31c80cda9f03cb71670336f8c82ea Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Sat, 28 Mar 2026 16:00:49 +0100 Subject: [PATCH 1/8] Add multilingual task suite aligned with unified :cf/:mcf/:gen conventions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor existing multilingual tasks and add new ones to match the English task conventions (BPB merged into CF, explicit few-shot counts, English explicitly excluded from all multilingual task registrations). Refactored: - mlmm_arc_challenge: cf+mcf variants, 26 langs, hf_revision pinned - global_mmlu: cf+mcf variants, 33 langs (English removed), both formulations - mlmm_hellaswag: cf only, 33 langs, few_shots_split=train added - mgsm: :gen suffix, generation_size=512, both expr_gold + multilingual_quasi_em New tasks: - global_mmlu_lite: CohereLabs/Global-MMLU-Lite, 17 langs, cf+mcf - mmlu_prox (multilingual): li-lab/MMLU-ProX, 28 langs, 10-option cf+mcf - mmlu_prox (English): li-lab/MMLU-ProX English subset in tasks/tasks/ - wmt24pp: google/wmt24pp, 24 lang pairs × 2 directions, 0-shot gen Tooling: - scripts/multilingual_aggregate.py: cross-language average post-processor - TASK_NAMING_Multilingual.md: naming conventions and language inventory doc add comet22 metrics --- TASK_NAMING_Multilingual.md | 487 ++++++++++++++++++ pyproject.toml | 1 + scripts/multilingual_aggregate.py | 301 +++++++++++ src/lighteval/logging/info_loggers.py | 5 +- src/lighteval/metrics/dynamic_metrics.py | 121 ++++- src/lighteval/metrics/metrics.py | 9 + src/lighteval/metrics/metrics_corpus.py | 53 +- src/lighteval/metrics/metrics_sample.py | 10 +- src/lighteval/metrics/sample_preparator.py | 20 + src/lighteval/metrics/utils/nltk_resources.py | 28 + .../models/transformers/transformers_model.py | 7 +- src/lighteval/tasks/lighteval_task.py | 311 ++++++++++- .../tasks/multilingual/tasks/flores200.py | 80 +-- .../tasks/multilingual/tasks/global_mmlu.py | 275 +++++----- .../multilingual/tasks/global_mmlu_lite.py | 131 +++++ .../tasks/multilingual/tasks/mgsm.py | 57 +- .../multilingual/tasks/mlmm_arc_challenge.py | 169 +++--- .../multilingual/tasks/mlmm_hellaswag.py | 180 ++++--- .../tasks/multilingual/tasks/mmlu_prox.py | 170 ++++++ .../tasks/multilingual/tasks/wmt24pp.py | 114 ++++ .../tasks/multilingual/utils/translation.py | 103 ++++ src/lighteval/tasks/tasks/mmlu_prox.py | 133 +++++ src/lighteval/tasks/templates/translation.py | 6 +- .../templates/utils/translation_literals.py | 24 + 24 files changed, 2448 insertions(+), 347 deletions(-) create mode 100644 TASK_NAMING_Multilingual.md create mode 100644 scripts/multilingual_aggregate.py create mode 100644 src/lighteval/metrics/utils/nltk_resources.py create mode 100644 src/lighteval/tasks/multilingual/tasks/global_mmlu_lite.py create mode 100644 src/lighteval/tasks/multilingual/tasks/mmlu_prox.py create mode 100644 src/lighteval/tasks/multilingual/tasks/wmt24pp.py create mode 100644 src/lighteval/tasks/multilingual/utils/translation.py create mode 100644 src/lighteval/tasks/tasks/mmlu_prox.py diff --git a/TASK_NAMING_Multilingual.md b/TASK_NAMING_Multilingual.md new file mode 100644 index 000000000..26b833aaa --- /dev/null +++ b/TASK_NAMING_Multilingual.md @@ -0,0 +1,487 @@ +# Multilingual Task Naming Conventions + +Most multilingual tasks follow the 3-part naming rule: `{base}:{lang}:{suffix}`. Translation tasks +are now English-centric exceptions: the public selector is `{base}:{lang}`, and it expands to the +two internal directional leaf tasks `{base}:en_to_x:{lang}` and `{base}:x_to_en:{lang}`. + +**Key rule**: English is **explicitly excluded** from every multilingual task. Use the +corresponding English task file instead (e.g., `arc.py`, `mmlu.py`, `hellaswag.py`, `gsm8k.py`). + +--- + +## Naming pattern + +``` +{base}:{lang}:{suffix}|{n_shot} +``` + +- `{base}` — task name (e.g., `global_mmlu`, `mlmm_arc`, `mmlu_prox`) +- `{lang}` — ISO 639-3 language value from `Language` enum (e.g., `deu`, `fra`, `zho`) +- `{suffix}` — `:cf`, `:mcf`, or `:gen` +- `{n_shot}` — appended at runtime (e.g., `|5` for 5-shot) + +**Group selectors** (auto-generated by `_task_superset_dict`): + +| Selector | Matches | +|---|---| +| `global_mmlu:cf\|5` | all 41 languages, CF variant | +| `global_mmlu:fra:cf\|5` | French only, CF variant | +| `global_mmlu\|5` | all 41 languages × CF + MCF | +| `mlmm_arc:cf\|5` | all 26 languages, CF variant | +| `mmlu_prox:cf\|5` | all 28 multilingual languages, CF variant | +| `mlmm_hellaswag:cf\|5` | all 32 languages | +| `mgsm:gen\|8` | all 10 languages | +| `wmt24pp\|0` | all ~24 English-centric language slices (both directions each) | +| `wmt24pp:de_DE\|0` | German slice only: en→de_DE + de_DE→en | +| `flores200\|0` | all English-centric FLORES language slices | +| `flores200:deu_Latn\|0` | German slice only: eng_Latn→deu_Latn + deu_Latn→eng_Latn | + +--- + +## Suffix reference (same as English tasks) + +| Suffix | Metrics reported | Description | +|--------|-----------------|-------------| +| `:cf` | `acc`, `acc_norm` (char), `target_bpb` | Score full answer text via log p(choice\|context); BPB merged in | +| `:mcf` | `acc`, `acc_norm` (char) | Score label tokens only (A, B, …); no BPB | +| `:mcf_em` | `em` (exact match accuracy) | MCF prompt + greedy generation (1–5 tokens) + exact match on predicted label | +| `:gen` | task-specific (F1, EM, expr_gold, quasi_em; or `chrf++`, `bleu`, `comet22` for translation) | Greedy generation | + +--- + +## Task inventory + +### Multilingual ARC Challenge — `mlmm_arc` + +| Task pattern | Dataset | Eval | FS split | ICL | Metrics | +|---|---|---|---|---|---| +| `mlmm_arc:{lang}:cf\|5` | `jon-tow/okapi_arc_challenge` | test | train | 5 | acc, acc_norm, target_bpb | +| `mlmm_arc:{lang}:mcf\|5` | same | test | train | 5 | acc, acc_norm | +| `mlmm_arc:{lang}:mcf_em\|5` | same | test | train | 5 | em | + +**26 languages** (English excluded): arabic, bengali, catalan, chinese, croatian, danish, dutch, +french, german, hindi, hungarian, indonesian, italian, kannada, malayalam, marathi, nepali, +romanian, russian, serbian, slovak, spanish, tamil, telugu, ukrainian, vietnamese. + +**Language codes**: +`ara`, `ben`, `cat`, `zho`, `hrv`, `dan`, `nld`, `fra`, `deu`, `hin`, `hun`, `ind`, `ita`, +`kan`, `mal`, `mar`, `nep`, `ron`, `rus`, `srp`, `slk`, `spa`, `tam`, `tel`, `ukr`, `vie`. + +File: `multilingual/tasks/mlmm_arc_challenge.py` + +**Copy-pasteable examples:** +``` +# All 26 languages, CF +mlmm_arc:cf|5 + +# All 26 languages, MCF +mlmm_arc:mcf|5 + +# All 26 languages, MCF greedy +mlmm_arc:mcf_em|5 + +# Single language +mlmm_arc:deu:cf|5 +mlmm_arc:fra:mcf|5 +mlmm_arc:fra:mcf_em|5 +``` + +--- + +### Global MMLU — `global_mmlu` + +| Task pattern | Dataset | Eval | FS split | ICL | Metrics | +|---|---|---|---|---|---| +| `global_mmlu:{lang}:cf\|5` | `CohereForAI/Global-MMLU` | test | dev | 5 | see below | +| `global_mmlu:{lang}:mcf\|5` | same | test | dev | 5 | see below | +| `global_mmlu:{lang}:mcf_em\|5` | same | test | dev | 5 | em | + +**41 languages** (English excluded): amharic, arabic, bengali, chinese, czech, dutch, french, +german, greek, hausa, hebrew, hindi, igbo, indonesian, italian, japanese, kirghiz, korean, +lithuanian, malagasy, malay, nepali, nyanja, persian, polish, portuguese, romanian, russian, +serbian, shona, sinhala, somali, spanish, swahili, swedish, tagalog, telugu, turkish, +ukrainian, vietnamese, yoruba. + +**Language codes**: +`amh`, `ara`, `ben`, `zho`, `ces`, `nld`, `fra`, `deu`, `ell`, `hau`, `heb`, `hin`, `ibo`, +`ind`, `ita`, `jpn`, `kir`, `kor`, `lit`, `mlg`, `msa`, `nep`, `nya`, `fas`, `pol`, `por`, +`ron`, `rus`, `srp`, `sna`, `sin`, `som`, `spa`, `swa`, `swe`, `tgl`, `tel`, `tur`, `ukr`, +`vie`, `yor`. + +**All 57 MMLU subjects are evaluated** — no per-subject task splitting. Results are +automatically routed to STEM / Humanities / Social / Other categories inside the metric. + +#### CF metrics (`global_mmlu:{lang}:cf`) + +| Metric key | Description | +|---|---| +| `acc` | overall accuracy across all subjects | +| `acc_norm` | overall char-normalised accuracy | +| `bpb` | overall bits-per-byte (lower = better) | +| `acc_stem`, `acc_norm_stem`, `bpb_stem` | STEM category (20 subjects) | +| `acc_humanities`, `acc_norm_humanities`, `bpb_humanities` | Humanities category (13 subjects) | +| `acc_social`, `acc_norm_social`, `bpb_social` | Social Sciences category (12 subjects) | +| `acc_other`, `acc_norm_other`, `bpb_other` | Other category (remaining subjects) | + +Per-category keys return `None` for samples outside the category; the corpus aggregator +uses `mean_notnone` so they average only over samples that belong to the category. + +#### MCF metrics (`global_mmlu:{lang}:mcf`) + +| Metric key | Description | +|---|---| +| `acc` | overall accuracy | +| `acc_norm` | overall char-normalised accuracy | +| `acc_stem`, `acc_norm_stem` | STEM category | +| `acc_humanities`, `acc_norm_humanities` | Humanities category | +| `acc_social`, `acc_norm_social` | Social Sciences category | +| `acc_other`, `acc_norm_other` | Other category | + +File: `multilingual/tasks/global_mmlu.py` +Metrics class: `MMLUCategoryGroupingCF` / `MMLUCategoryGroupingMCF` in `metrics/dynamic_metrics.py` + +**Copy-pasteable examples:** +``` +# All 41 languages, CF (with category-level metrics) +global_mmlu:cf|5 + +# All 41 languages, MCF +global_mmlu:mcf|5 + +# All 41 languages, MCF greedy +global_mmlu:mcf_em|5 + +# All 41 languages, both CF and MCF +global_mmlu|5 + +# Single language +global_mmlu:fra:cf|5 +global_mmlu:zho:mcf|5 +global_mmlu:zho:mcf_em|5 +``` + +--- + +### Global MMLU Lite — `global_mmlu_lite` + +| Task pattern | Dataset | Eval | FS split | ICL | Metrics | +|---|---|---|---|---|---| +| `global_mmlu_lite:{lang}:cf\|5` | `CohereLabs/Global-MMLU-Lite` | test | dev | 5 | acc, acc_norm, target_bpb | +| `global_mmlu_lite:{lang}:mcf\|5` | same | test | dev | 5 | acc, acc_norm | +| `global_mmlu_lite:{lang}:mcf_em\|5` | same | test | dev | 5 | em | + +**17 languages** (English excluded): arabic, bengali, welsh, german, spanish, french, hindi, +indonesian, italian, japanese, korean, burmese, portuguese, albanian, swahili, yoruba, chinese. + +**Language codes**: +`ara`, `ben`, `cym`, `deu`, `spa`, `fra`, `hin`, `ind`, `ita`, `jpn`, `kor`, `mya`, `por`, +`sqi`, `swa`, `yor`, `zho`. + +~400 test samples and ~215 dev samples per language across 43 subjects. + +File: `multilingual/tasks/global_mmlu_lite.py` + +**Copy-pasteable examples:** +``` +# All 17 languages, CF +global_mmlu_lite:cf|5 + +# All 17 languages, MCF +global_mmlu_lite:mcf|5 + +# All 17 languages, MCF greedy +global_mmlu_lite:mcf_em|5 + +# Single language +global_mmlu_lite:ara:cf|5 +global_mmlu_lite:ara:mcf_em|5 +``` + +--- + +### MMLU-ProX English — `mmlu_prox_eng` + +| Task | Dataset | Eval | FS split | ICL | Metrics | +|---|---|---|---|---|---| +| `mmlu_prox_eng:cf\|5` | `li-lab/MMLU-ProX` | test | validation | 5 | acc, acc_norm, target_bpb | +| `mmlu_prox_eng:mcf\|5` | `li-lab/MMLU-ProX` | test | validation | 5 | acc, acc_norm | +| `mmlu_prox_eng:mcf_em\|5` | `li-lab/MMLU-ProX` | test | validation | 5 | em | + +English subset (`hf_subset="en"`). Up to 10 answer options (labels A–J). + +**Language code**: `eng` + +File: `tasks/tasks/mmlu_prox.py` (2-part name — not captured by `mmlu_prox:cf` superset) + +**Copy-pasteable examples:** +``` +mmlu_prox_eng:cf|5 +mmlu_prox_eng:mcf|5 +mmlu_prox_eng:mcf_em|5 +``` + +--- + +### MMLU-ProX Multilingual — `mmlu_prox` + +| Task pattern | Dataset | Eval | FS split | ICL | Metrics | +|---|---|---|---|---|---| +| `mmlu_prox:{lang}:cf\|5` | `li-lab/MMLU-ProX` | test | validation | 5 | acc, acc_norm, target_bpb | +| `mmlu_prox:{lang}:mcf\|5` | same | test | validation | 5 | acc, acc_norm | +| `mmlu_prox:{lang}:mcf_em\|5` | same | test | validation | 5 | em | + +**28 languages** (English excluded): afrikaans, arabic, bengali, czech, german, spanish, french, +hindi, hungarian, indonesian, italian, japanese, korean, marathi, nepali, portuguese, russian, +serbian, swahili, telugu, thai, ukrainian, urdu, vietnamese, wolof, yoruba, chinese, zulu. + +**Language codes**: +`afr`, `ara`, `ben`, `ces`, `deu`, `spa`, `fra`, `hin`, `hun`, `ind`, `ita`, `jpn`, `kor`, +`mar`, `nep`, `por`, `rus`, `srp`, `swa`, `tel`, `tha`, `ukr`, `urd`, `vie`, `wol`, `yor`, +`zho`, `zul`. + +Up to 10 options per question (labels A–J). + +File: `multilingual/tasks/mmlu_prox.py` + +**Copy-pasteable examples:** +``` +# All 28 languages, CF +mmlu_prox:cf|5 + +# All 28 languages, MCF +mmlu_prox:mcf|5 + +# All 28 languages, MCF greedy +mmlu_prox:mcf_em|5 + +# Single language +mmlu_prox:deu:cf|5 +mmlu_prox:deu:mcf_em|5 +``` + +--- + +### Multilingual HellaSwag — `mlmm_hellaswag` + +| Task pattern | Dataset | Eval | FS split | ICL | Metrics | +|---|---|---|---|---|---| +| `mlmm_hellaswag:{lang}:cf\|5` | `alexandrainst/m_hellaswag` (most) / `jon-tow/okapi_hellaswag` (zh) | val/validation | None | 5 | acc, acc_norm, target_bpb | +| `mlmm_hellaswag:{lang}:mcf\|5` | same | val/validation | None | 5 | acc, acc_norm | +| `mlmm_hellaswag:{lang}:mcf_em\|5` | same | val/validation | None | 5 | em | + +HellaSwag has 4 candidate continuations; the MCF formulation labels them A–D and scores (or generates) the label token. + +**32 languages** (English excluded): arabic, armenian, basque, bengali, catalan, chinese, +croatian, danish, dutch, french, german, gujarati, hindi, hungarian, icelandic, indonesian, +italian, kannada, malayalam, marathi, nepali, portuguese, romanian, russian, +serbian, slovak, spanish, swedish, tamil, telugu, ukrainian, vietnamese. + +**Language codes**: +`ara`, `hye`, `eus`, `ben`, `cat`, `zho`, `hrv`, `dan`, `nld`, `fra`, `deu`, `guj`, `hin`, +`hun`, `isl`, `ind`, `ita`, `kan`, `mal`, `mar`, `nep`, `por`, `ron`, `rus`, `srp`, `slk`, +`spa`, `swe`, `tam`, `tel`, `ukr`, `vie`. + +File: `multilingual/tasks/mlmm_hellaswag.py` + +**Copy-pasteable examples:** +``` +# All 32 languages, CF +mlmm_hellaswag:cf|5 + +# All 32 languages, MCF +mlmm_hellaswag:mcf|5 + +# All 32 languages, MCF greedy +mlmm_hellaswag:mcf_em|5 + +# Single language +mlmm_hellaswag:deu:cf|5 +mlmm_hellaswag:zho:mcf|5 +mlmm_hellaswag:zho:mcf_em|5 +``` + +--- + +### Multilingual GSM — `mgsm` + +| Task pattern | Dataset | Eval | FS split | ICL | Metrics | +|---|---|---|---|---|---| +| `mgsm:{lang}:gen\|8` | `juletxara/mgsm` | test | train | 8 | expr_gold_metric, multilingual_quasi_em | + +Both `expr_gold_metric` (math expression parser, for Arabic-numeral answers) and +`MultilingualQuasiExactMatchMetric` (language-aware fuzzy match, for non-ASCII digit systems) +are reported in the same pass. + +**10 languages** (English excluded): bengali, french, german, japanese, russian, spanish, +swahili, telugu, thai, chinese. + +**Language codes**: +`ben`, `fra`, `deu`, `jpn`, `rus`, `spa`, `swa`, `tel`, `tha`, `zho`. + +File: `multilingual/tasks/mgsm.py` + +**Copy-pasteable examples:** +``` +# All 10 languages +mgsm:gen|8 + +# Single language +mgsm:deu:gen|8 +mgsm:zho:gen|8 +``` + +--- + +### WMT24++ Translation — `wmt24pp` + +| Task pattern | Dataset | Eval | FS | ICL | Metrics | +|---|---|---|---|---|---| +| `wmt24pp:{lp}\|0` | `google/wmt24pp` | train | none | 0 | chrf++, bleu, comet22 | + +Running `wmt24pp:{lp}|0` expands to both internal directional tasks: +`wmt24pp:en_to_x:{lp}|0` and `wmt24pp:x_to_en:{lp}|0`. + +`{lp}` is the language-pair code from the dataset's `lp` column (e.g., `de_DE`, `fr_FR`, +`zh_CN`). Rows with `is_bad_source=True` are filtered out. + +**Available language codes** in `wmt24pp.py`: +`de_DE`, `fr_FR`, `cs_CZ`, `es_MX`, `ru_RU`, `zh_CN`, `ja_JP`, `uk_UA`, `hi_IN`, `ar_EG`, +`ko_KR`, `pt_BR`, `tr_TR`, `pl_PL`, `he_IL`, `nl_NL`, `it_IT`, `sv_SE`, `fi_FI`, `vi_VN`, +`bn_IN`, `th_TH`, `id_ID`, `hu_HU`. + +File: `multilingual/tasks/wmt24pp.py` + +**Copy-pasteable examples:** +``` +# All ~24 language slices, both directions each +wmt24pp|0 + +# Single language slice +wmt24pp:de_DE|0 +wmt24pp:zh_CN|0 +``` + +--- + +### FLORES-200 Translation — `flores200` + +| Task pattern | Dataset | Eval | FS split | ICL | Metrics | +|---|---|---|---|---|---| +| `flores200:{lang}\|0` | `facebook/flores` | devtest | dev | 0 | chrf++, bleu, comet22 | + +Task names use FLORES-200 language codes: `{iso639-3}_{script}` (e.g., `deu_Latn`, `zho_Hans`, +`arb_Arab`). English is implicit: running `flores200:{lang}|0` expands to both internal +directional tasks `flores200:en_to_x:{lang}|0` and `flores200:x_to_en:{lang}|0`. + +**~200 languages** are listed in `flores_200_languages`, but only the English-centric bilingual +slices are registered here. + +File: `multilingual/tasks/flores200.py` + +**Copy-pasteable examples:** +``` +# German slice (eng_Latn ↔ deu_Latn) +flores200:deu_Latn|0 + +# Chinese (Simplified) slice +flores200:zho_Hans|0 + +# Arabic slice +flores200:arb_Arab|0 +``` + +Note: FLORES-200 uses 4-part language codes (e.g., `eng_Latn`) not the 3-letter ISO codes +used by other tasks. The public selector has 2 components (`flores200:{lang}`), while the +directional leaf tasks use 3 components so they still participate in the `a:b:c` superset and +averaging behavior. + +--- + +## Two-level averaging + +### Level 1 — per-language average (automatic) + +For any 3-component task name `a:b:c`, lighteval auto-generates an `a:b:c:_average` aggregate +key by grouping over the middle component. For tasks with a single subset per language (all +tasks here except the old subject-split global_mmlu), this average equals the task score. + +``` +# Auto-generated by lighteval when running group selectors: +global_mmlu:cf:_average|5 # mean of global_mmlu:{lang}:cf across all langs +mlmm_arc:cf:_average|5 +mmlu_prox:cf:_average|5 +mlmm_hellaswag:cf:_average|5 +mgsm:gen:_average|8 +wmt24pp:de_DE:_average|0 # mean of en_to_x + x_to_en for de_DE +flores200:deu_Latn:_average|0 # mean of eng_Latn→deu_Latn + deu_Latn→eng_Latn +``` + +For `global_mmlu`, the per-language aggregate gives you `acc`, `acc_norm`, `bpb` overall +plus the 4×3 (CF) or 4×2 (MCF) per-category variants, all averaged across all 57 subjects +for that language. + +### Level 2 — cross-language average (post-hoc script) + +After evaluation, compute a mean across all languages: + +```bash +python scripts/multilingual_aggregate.py results.json +python scripts/multilingual_aggregate.py results.json --prefix global_mmlu +``` + +Output keys: +``` +global_mmlu:cf:cross_lang_average|5 # mean of per-lang CF averages +global_mmlu:mcf:cross_lang_average|5 +mlmm_arc:cf:cross_lang_average|5 +mlmm_hellaswag:cf:cross_lang_average|5 +mgsm:gen:cross_lang_average|8 +mmlu_prox:cf:cross_lang_average|5 +wmt24pp:de_DE:en_to_x|0 # alias for the raw directional task +wmt24pp:de_DE:x_to_en|0 # alias for the raw directional task +wmt24pp:de_DE:bidirectional_average|0 # alias for wmt24pp:de_DE:_average|0 +wmt24pp:en_to_x:cross_lang_average|0 # mean across all English→X directions +wmt24pp:x_to_en:cross_lang_average|0 # mean across all X→English directions +wmt24pp:bidirectional_cross_lang_average|0 +flores200:deu_Latn:bidirectional_average|0 +flores200:bidirectional_cross_lang_average|0 +``` + +--- + +## Quick reference: all copy-pasteable group commands + +```bash +# --- Multiple-choice tasks (CF + BPB) --- +global_mmlu:cf|5 # Global MMLU, all 41 langs, CF (acc+norm+bpb + 4 categories each) +global_mmlu:mcf|5 # Global MMLU, all 41 langs, MCF (acc+norm + 4 categories each) +global_mmlu:mcf_em|5 # Global MMLU, all 41 langs, MCF greedy (em) +global_mmlu|5 # Global MMLU, all 41 langs, CF + MCF + +global_mmlu_lite:cf|5 # Global MMLU Lite, all 17 langs, CF +global_mmlu_lite:mcf|5 # Global MMLU Lite, all 17 langs, MCF +global_mmlu_lite:mcf_em|5 # Global MMLU Lite, all 17 langs, MCF greedy (em) + +mlmm_arc:cf|5 # ARC Challenge, all 26 langs, CF +mlmm_arc:mcf|5 # ARC Challenge, all 26 langs, MCF +mlmm_arc:mcf_em|5 # ARC Challenge, all 26 langs, MCF greedy (em) + +mmlu_prox:cf|5 # MMLU-ProX multilingual, all 28 langs, CF +mmlu_prox:mcf|5 # MMLU-ProX multilingual, all 28 langs, MCF +mmlu_prox:mcf_em|5 # MMLU-ProX multilingual, all 28 langs, MCF greedy (em) +mmlu_prox_eng:cf|5 # MMLU-ProX English only +mmlu_prox_eng:mcf|5 +mmlu_prox_eng:mcf_em|5 + +# --- Reasoning / generation --- +mlmm_hellaswag:cf|5 # HellaSwag, all 32 langs, CF +mlmm_hellaswag:mcf|5 # HellaSwag, all 32 langs, MCF +mlmm_hellaswag:mcf_em|5 # HellaSwag, all 32 langs, MCF greedy (em) +mgsm:gen|8 # MGSM math, all 10 langs + +# --- Translation --- +wmt24pp|0 # WMT24++, all English-centric language slices +wmt24pp:de_DE|0 # WMT24++, German slice (both directions) +flores200|0 # FLORES-200, all English-centric language slices +flores200:deu_Latn|0 # FLORES-200, German slice (both directions) +``` diff --git a/pyproject.toml b/pyproject.toml index 576a93122..253f0ea35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,7 @@ multilingual = [ "spacy[ja,ko,th]>=3.8.0", "jieba", # for chinese tokenizer "pyvi", # for vietnamese tokenizer + "unbabel-comet>=2.0.0", # COMET-22 metric for translation tasks ] math = ["latex2sympy2_extended==1.0.6"] wandb = ["wandb"] diff --git a/scripts/multilingual_aggregate.py b/scripts/multilingual_aggregate.py new file mode 100644 index 000000000..8347cd7a9 --- /dev/null +++ b/scripts/multilingual_aggregate.py @@ -0,0 +1,301 @@ +""" +multilingual_aggregate.py — Aggregation helpers for multilingual task results. + +Usage: + python scripts/multilingual_aggregate.py results.json [--prefix PREFIX] [--output out.json] + + results.json — lighteval output JSON file (the "results" field) + --prefix — optional task prefix filter, e.g. "global_mmlu" (default: all) + --output — write aggregate outputs to this JSON file (default: stdout) + +How it works +------------ +For standard multilingual tasks, lighteval auto-generates per-language subtask +averages such as: + + global_mmlu_fra:cf:_average|5 ← mean over all 57 MMLU subjects for French + global_mmlu_deu:cf:_average|5 ← mean over all 57 MMLU subjects for German + +This script groups those into cross-language averages such as: + + global_mmlu:cf:cross_lang_average|5 + +For the English-centric translation tasks, the raw task names are direction-first: + + wmt24pp:en_to_x:de_DE|0 + wmt24pp:x_to_en:de_DE|0 + wmt24pp:de_DE:_average|0 + +This script additionally emits language-centered aliases and translation-specific +cross-language averages such as: + + wmt24pp:de_DE:en_to_x|0 + wmt24pp:de_DE:x_to_en|0 + wmt24pp:de_DE:bidirectional_average|0 + wmt24pp:en_to_x:cross_lang_average|0 + wmt24pp:x_to_en:cross_lang_average|0 + wmt24pp:bidirectional_cross_lang_average|0 +""" + +import argparse +import json +import math +import re +import sys +from collections import defaultdict +from pathlib import Path + +# All 3-letter ISO 639-3 language codes used in the Language enum. +# Derived from lighteval.utils.language.Language; kept as a static set here +# to avoid importing the full lighteval package. +_LANG_CODES = { + "afr", "aka", "amh", "ara", "aze", "bam", "bel", "ben", "bos", "bul", + "cat", "ces", "ckb", "cmn", "cym", "dan", "deu", "ell", "eng", "epo", + "est", "eus", "ewe", "fas", "fin", "fra", "ful", "gla", "gle", "glg", + "grn", "guj", "hau", "heb", "hin", "hrv", "hun", "hye", "ibo", "ind", + "isl", "ita", "jpn", "kan", "kat", "kaz", "khm", "kin", "kir", "kon", + "kor", "lao", "lat", "lin", "lit", "ltz", "lug", "luo", "lvs", "mal", + "mar", "mkd", "mlt", "mri", "msa", "mya", "nep", "nld", "nor", "nob", + "nno", "nya", "ory", "pan", "pol", "por", "pus", "ron", "run", "rus", + "sag", "sin", "slk", "slv", "smo", "sna", "som", "sot", "spa", "sqi", + "srp", "ssw", "swa", "swe", "tam", "tel", "tgk", "tgl", "tha", "tir", + "ton", "tsn", "tso", "tuk", "tur", "twi", "uig", "ukr", "urd", "uzn", + "vie", "wol", "xho", "yor", "zho", "zsm", "zul", +} + +# Pattern: task names whose first component ends with _{lang_code} +# e.g. "global_mmlu_fra" → base="global_mmlu", lang="fra" +_LANG_RE = re.compile(r"^(.+?)_(" + "|".join(_LANG_CODES) + r")(:.+)?$") +_TRANSLATION_DIRECTION_RE = re.compile(r"^(?P[^:]+):(?Pen_to_x|x_to_en):(?P[^:]+)$") +_TRANSLATION_AVERAGE_RE = re.compile(r"^(?P[^:]+):(?P[^:]+):_average$") + + +def _strip_lang(task_no_fewshot: str): + """ + Given a task name without the |n suffix, return (base_name, lang_code). + base_name is the task name with the language component replaced by nothing. + Returns (None, None) if no language component found. + + Example: + "global_mmlu_fra:cf:_average" → ("global_mmlu:cf:_average", "fra") + "mlmm_arc_deu:challenge:cf:_average" → ("mlmm_arc:challenge:cf:_average", "deu") + "gsm8k" → (None, None) + """ + m = _LANG_RE.match(task_no_fewshot) + if m: + base_prefix = m.group(1) + lang = m.group(2) + rest = m.group(3) or "" + return f"{base_prefix}{rest}", lang + return None, None + + +def _avg(values): + return sum(values) / len(values) if values else float("nan") + + +def _propagate_se(se_values): + """SE of mean of independent estimates: sqrt(sum(SE_i^2)) / n.""" + n = len(se_values) + return math.sqrt(sum(v ** 2 for v in se_values)) / n if n else float("nan") + + +def _split_fewshot(task_name: str): + if "|" in task_name: + task_no_fs, fs_suffix = task_name.rsplit("|", 1) + return task_no_fs, "|" + fs_suffix + return task_name, "" + + +def _is_valid_number(value): + return isinstance(value, (int, float)) and not ( + isinstance(value, float) and math.isnan(value) + ) + + +def _append_numeric_metrics(metric_lists: dict[str, list[float]], metrics: dict): + for metric_name, metric_val in metrics.items(): + if _is_valid_number(metric_val): + metric_lists[metric_name].append(metric_val) + + +def _average_metric_lists(metric_lists: dict[str, list[float]]) -> dict: + averaged = {} + for metric_name, vals in metric_lists.items(): + if metric_name.endswith("_stderr"): + averaged[metric_name] = _propagate_se(vals) + else: + averaged[metric_name] = _avg(vals) + return averaged + + +def compute_cross_lang_averages(results: dict, prefix: str | None = None) -> dict: + """ + Given the 'results' dict from a lighteval JSON output, compute cross-language + averages for all multilingual per-language-average tasks. + + Returns a dict mapping cross-language-average task name → metric dict. + """ + # Only consider tasks that are per-language averages (contain :_average) + average_tasks = {k: v for k, v in results.items() if ":_average" in k} + + if prefix: + average_tasks = {k: v for k, v in average_tasks.items() if k.startswith(prefix)} + + # Group by (base_name_with_fewshot) where lang is stripped out + groups: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) + + for task_name, metrics in average_tasks.items(): + # Split off |n fewshot suffix + if "|" in task_name: + task_no_fs, fs_suffix = task_name.rsplit("|", 1) + fs_part = "|" + fs_suffix + else: + task_no_fs = task_name + fs_part = "" + + base_name, lang = _strip_lang(task_no_fs) + if base_name is None: + continue # not a multilingual task we can aggregate + + group_key = base_name + fs_part + for metric_name, metric_val in metrics.items(): + if isinstance(metric_val, (int, float)) and not ( + isinstance(metric_val, float) and math.isnan(metric_val) + ): + groups[group_key][metric_name].append(metric_val) + + # Compute cross-language averages + cross_lang: dict[str, dict] = {} + for group_key, metric_lists in groups.items(): + if not metric_lists: + continue + n_langs = max(len(v) for v in metric_lists.values()) + if n_langs < 2: + continue # only one language — no meaningful cross-lang average + + # Rename key: replace ":_average" with ":cross_lang_average" + out_key = group_key.replace(":_average", ":cross_lang_average") + cross_lang[out_key] = {} + for metric_name, vals in metric_lists.items(): + if metric_name.endswith("_stderr"): + cross_lang[out_key][metric_name] = _propagate_se(vals) + else: + cross_lang[out_key][metric_name] = _avg(vals) + cross_lang[out_key]["_n_languages"] = n_langs + + return cross_lang + + +def compute_translation_aggregates(results: dict, prefix: str | None = None) -> dict: + directional_metrics = {} + bidirectional_averages = {} + + for task_name, metrics in results.items(): + if prefix and not task_name.startswith(prefix): + continue + + task_no_fs, fs_part = _split_fewshot(task_name) + + directional_match = _TRANSLATION_DIRECTION_RE.match(task_no_fs) + if directional_match: + key = ( + directional_match.group("base"), + directional_match.group("lang"), + fs_part, + ) + directional_metrics.setdefault(key, {})[ + directional_match.group("direction") + ] = metrics + continue + + average_match = _TRANSLATION_AVERAGE_RE.match(task_no_fs) + if average_match: + key = ( + average_match.group("base"), + average_match.group("lang"), + fs_part, + ) + bidirectional_averages[key] = metrics + + if not directional_metrics: + return {} + + aliases = {} + directional_groups: dict[str, dict[str, list[float]]] = defaultdict( + lambda: defaultdict(list) + ) + bidirectional_groups: dict[str, dict[str, list[float]]] = defaultdict( + lambda: defaultdict(list) + ) + + for key, per_direction in directional_metrics.items(): + base, lang, fs_part = key + + for direction, metrics in per_direction.items(): + alias_key = f"{base}:{lang}:{direction}{fs_part}" + aliases[alias_key] = metrics + _append_numeric_metrics( + directional_groups[f"{base}:{direction}:cross_lang_average{fs_part}"], + metrics, + ) + + average_metrics = bidirectional_averages.get(key) + if average_metrics is not None: + aliases[f"{base}:{lang}:bidirectional_average{fs_part}"] = average_metrics + _append_numeric_metrics( + bidirectional_groups[f"{base}:bidirectional_cross_lang_average{fs_part}"], + average_metrics, + ) + + for out_key, metric_lists in directional_groups.items(): + n_langs = max(len(v) for v in metric_lists.values()) + if n_langs < 2: + continue + aliases[out_key] = _average_metric_lists(metric_lists) + aliases[out_key]["_n_languages"] = n_langs + + for out_key, metric_lists in bidirectional_groups.items(): + n_langs = max(len(v) for v in metric_lists.values()) + if n_langs < 2: + continue + aliases[out_key] = _average_metric_lists(metric_lists) + aliases[out_key]["_n_languages"] = n_langs + + return aliases + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("results_file", help="Path to lighteval results JSON file") + parser.add_argument("--prefix", default=None, help="Filter tasks by prefix (e.g. 'global_mmlu')") + parser.add_argument("--output", default=None, help="Write output to this JSON file (default: stdout)") + args = parser.parse_args() + + with open(args.results_file) as f: + data = json.load(f) + + # Support both the raw results dict and the full lighteval JSON format + if "results" in data: + results = data["results"] + else: + results = data + + aggregated = {} + aggregated.update(compute_cross_lang_averages(results, prefix=args.prefix)) + aggregated.update(compute_translation_aggregates(results, prefix=args.prefix)) + + if not aggregated: + print("No aggregate outputs found.", file=sys.stderr) + sys.exit(0) + + out = json.dumps(aggregated, indent=2, ensure_ascii=False) + if args.output: + Path(args.output).write_text(out) + print(f"Written {len(aggregated)} aggregate outputs to {args.output}", file=sys.stderr) + else: + print(out) + + +if __name__ == "__main__": + main() diff --git a/src/lighteval/logging/info_loggers.py b/src/lighteval/logging/info_loggers.py index 7309ea552..3cbf5c693 100644 --- a/src/lighteval/logging/info_loggers.py +++ b/src/lighteval/logging/info_loggers.py @@ -364,9 +364,10 @@ def aggregate(self, task_dict: dict[str, LightevalTask], bootstrap_iters: int = ) else: stderr = get_stderr_function(aggregation=aggregation, number_experiments=bootstrap_iters) - if stderr is not None and len(metric_values) > 1: + stderr_values = [v for v in metric_values if v is not None] + if stderr is not None and len(stderr_values) > 1: try: - self.metric_aggregated[task_name][f"{metric_name}_stderr"] = stderr(metric_values) + self.metric_aggregated[task_name][f"{metric_name}_stderr"] = stderr(stderr_values) except OverflowError: # Is this need or should we just pass? self.metric_aggregated[task_name][f"{metric_name}_stderr"] = float("nan") diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py index 66ed91c3a..4de1c33ad 100644 --- a/src/lighteval/metrics/dynamic_metrics.py +++ b/src/lighteval/metrics/dynamic_metrics.py @@ -33,6 +33,7 @@ Probability, ) from lighteval.metrics.normalizations import ( + LogProbCharNorm, LogProbNormalization, LogProbTokenNorm, get_multilingual_normalizer, @@ -45,10 +46,11 @@ get_extraction_regexes, ) from lighteval.metrics.utils.math_comparison import compare_gold_target -from lighteval.metrics.utils.metric_utils import SampleLevelComputation, SampleLevelMetric +from lighteval.metrics.utils.metric_utils import SampleLevelComputation, SampleLevelMetric, SampleLevelMetricGrouping from lighteval.models.model_output import ModelResponse from lighteval.tasks.requests import Doc, SamplingMethod from lighteval.utils.language import Language +from lighteval.utils.utils import as_list from lighteval.utils.timeout import timeout @@ -266,3 +268,120 @@ def compute(self, doc: Doc, model_response: ModelResponse) -> float: for pred in extracted_predictions ] ) + + +# --------------------------------------------------------------------------- +# MMLU per-category accuracy metric +# --------------------------------------------------------------------------- +import math as _math + +_MMLU_STEM = frozenset([ + "abstract_algebra", "anatomy", "astronomy", "college_biology", "college_chemistry", + "college_computer_science", "college_mathematics", "college_medicine", "college_physics", + "computer_security", "conceptual_physics", "electrical_engineering", "elementary_mathematics", + "high_school_biology", "high_school_chemistry", "high_school_computer_science", + "high_school_mathematics", "high_school_physics", "high_school_statistics", "machine_learning", +]) +_MMLU_HUMANITIES = frozenset([ + "formal_logic", "high_school_european_history", "high_school_us_history", + "high_school_world_history", "international_law", "jurisprudence", "logical_fallacies", + "moral_disputes", "moral_scenarios", "philosophy", "prehistory", "professional_law", + "world_religions", +]) +_MMLU_SOCIAL = frozenset([ + "econometrics", "high_school_geography", "high_school_government_and_politics", + "high_school_macroeconomics", "high_school_microeconomics", "high_school_psychology", + "human_sexuality", "professional_psychology", "public_relations", "security_studies", + "sociology", "us_foreign_policy", +]) +# "Other" = any subject not in the above three sets + +_MMLU_CATS = ("stem", "humanities", "social", "other") + + +def _subject_to_mmlu_cat(subject: str) -> str: + s = subject.lower() + if s in _MMLU_STEM: return "stem" + if s in _MMLU_HUMANITIES: return "humanities" + if s in _MMLU_SOCIAL: return "social" + return "other" + + +def _mean_notnone(vals): + """Mean of non-None values; returns nan if all are None.""" + filtered = [v for v in vals if v is not None] + return float(np.mean(filtered)) if filtered else float("nan") + + +class _MMLUCategoryFn(SampleLevelComputation): + """Computes acc, acc_norm (char), and optionally per-sample BPB, routed to + the correct MMLU category (stem/humanities/social/other) via + ``doc.specific["subject"]``.""" + + def __init__(self, include_bpb: bool = False): + self._acc_fn = LoglikelihoodAcc(logprob_normalization=None) + self._acc_norm_fn = LoglikelihoodAcc(logprob_normalization=LogProbCharNorm()) + self.include_bpb = include_bpb + + def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict: + acc = self._acc_fn.compute(doc=doc, model_response=model_response) + acc_norm = self._acc_norm_fn.compute(doc=doc, model_response=model_response) + + subject = (doc.specific or {}).get("subject", "") + cat = _subject_to_mmlu_cat(subject) + + result: dict = {"acc": acc, "acc_norm": acc_norm} + for c in _MMLU_CATS: + result[f"acc_{c}"] = acc if c == cat else None + result[f"acc_norm_{c}"] = acc_norm if c == cat else None + + if self.include_bpb: + gold_ix = as_list(doc.gold_index)[0] + gold_logprob = model_response.logprobs[gold_ix] + gold_bytes = max(len(doc.choices[gold_ix].encode("utf-8")), 1) + bpb = -gold_logprob / _math.log(2) / gold_bytes + result["bpb"] = bpb + for c in _MMLU_CATS: + result[f"bpb_{c}"] = bpb if c == cat else None + + return result + + +def _make_mmlu_category_grouping(include_bpb: bool) -> SampleLevelMetricGrouping: + keys_acc = ( + ["acc", "acc_norm"] + + [f"acc_{c}" for c in _MMLU_CATS] + + [f"acc_norm_{c}" for c in _MMLU_CATS] + ) + keys_bpb = ["bpb"] + [f"bpb_{c}" for c in _MMLU_CATS] if include_bpb else [] + all_keys = keys_acc + keys_bpb + + corpus_fns: dict = { + "acc": np.mean, + "acc_norm": np.mean, + **{f"acc_{c}": _mean_notnone for c in _MMLU_CATS}, + **{f"acc_norm_{c}": _mean_notnone for c in _MMLU_CATS}, + } + if include_bpb: + corpus_fns["bpb"] = np.mean + corpus_fns.update({f"bpb_{c}": _mean_notnone for c in _MMLU_CATS}) + + higher_is_better = { + **{k: True for k in keys_acc}, + **{k: False for k in keys_bpb}, + } + + return SampleLevelMetricGrouping( + metric_name=all_keys, + sample_level_fn=_MMLUCategoryFn(include_bpb=include_bpb), + corpus_level_fn=corpus_fns, + category=SamplingMethod.LOGPROBS, + higher_is_better=higher_is_better, + ) + + +MMLUCategoryGroupingCF = _make_mmlu_category_grouping(include_bpb=True) +"""CF variant: reports acc, acc_norm, bpb — overall and per MMLU category (stem/humanities/social/other).""" + +MMLUCategoryGroupingMCF = _make_mmlu_category_grouping(include_bpb=False) +"""MCF variant: reports acc, acc_norm — overall and per MMLU category (no BPB for label-token scoring).""" diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index c1164b9a3..41053de70 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -32,6 +32,7 @@ from lighteval.metrics.harness_compatibility.drop import DropMetrics from lighteval.metrics.harness_compatibility.truthful_qa import TruthfulqaMCMetrics from lighteval.metrics.metrics_corpus import ( + CorpusLevelCOMETMetric, CorpusLevelF1Score, CorpusLevelPerplexityMetric, CorpusLevelTranslationMetric, @@ -66,6 +67,7 @@ remove_braces_and_strip, ) from lighteval.metrics.sample_preparator import ( + COMETPreparator, GenerativePreparator, LoglikelihoodPreparator, PerplexityPreparator, @@ -290,6 +292,13 @@ class Metrics(Enum): corpus_level_fn=CorpusLevelTranslationMetric("chrf++"), higher_is_better=True, ) + comet22 = CorpusLevelMetric( + metric_name="comet22", + sample_level_fn=COMETPreparator(), + category=SamplingMethod.GENERATIVE, + corpus_level_fn=CorpusLevelCOMETMetric("Unbabel/wmt22-comet-da"), + higher_is_better=True, + ) copyright = SampleLevelMetricGrouping( metric_name=[ "longest_common_prefix_length", diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py index 92c2c574a..7b4b9b4cf 100644 --- a/src/lighteval/metrics/metrics_corpus.py +++ b/src/lighteval/metrics/metrics_corpus.py @@ -35,6 +35,7 @@ import sklearn.metrics from lighteval.metrics.sample_preparator import ( + COMETCorpusMetricInput, GenerativeCorpusMetricInput, LogprobCorpusMetricInput, PerplexityCorpusMetricInput, @@ -126,9 +127,6 @@ def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""): def get_metric(self): if self.metric_type == "bleu": - import nltk - - nltk.download("punkt_tab") return sacrebleu.BLEU(trg_lang=self.lang) elif self.metric_type == "chrf": return sacrebleu.CHRF() @@ -190,3 +188,52 @@ def compute_corpus(self, items: list[PerplexityCorpusMetricInput]): return math.exp(-sum(logprobs) / sum(weights)) if self.metric_type == "bits_per_byte": return -sum(logprobs) / sum(weights) * 1 / math.log(2) + + +class CorpusLevelCOMETMetric(CorpusLevelComputation): + """Corpus-level COMET-22 metric using the Unbabel/wmt22-comet-da model. + + Requires `unbabel-comet`: pip install lighteval[multilingual] (or pip install unbabel-comet). + + The COMET model is loaded once per process (class-level cache) and reused across all + translation language pairs in the same evaluation run. + + On multi-GPU nodes, all available GPUs are used automatically via the `gpus` parameter + of unbabel-comet's predict() API. batch_size_per_gpu=256 is tuned for H100 96GB; + reduce to 64 for A100 40GB or 32 for 16GB GPUs. + """ + + _model = None # class-level cache: loaded once per process + + def __init__(self, model_name: str = "Unbabel/wmt22-comet-da", batch_size_per_gpu: int = 256): + self.model_name = model_name + self.batch_size_per_gpu = batch_size_per_gpu + + def _load_model(self): + if CorpusLevelCOMETMetric._model is None: + try: + from comet import download_model, load_from_checkpoint + except ImportError: + raise ImportError( + "COMET metric requires `unbabel-comet`. " + "Install with: pip install lighteval[multilingual]" + ) + CorpusLevelCOMETMetric._model = load_from_checkpoint(download_model(self.model_name)) + return CorpusLevelCOMETMetric._model + + def compute_corpus(self, items: list[COMETCorpusMetricInput]) -> float: + import torch + + model = self._load_model() + data = [ + {"src": i.source, "mt": i.hyp, "ref": i.ref[0] if i.ref else ""} + for i in items + ] + num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + output = model.predict( + data, + batch_size=self.batch_size_per_gpu, + gpus=num_gpus, + progress_bar=False, + ) + return float(np.mean(output.scores)) diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 5a1176b40..4863d7239 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -35,7 +35,6 @@ from huggingface_hub import HfApi from nltk.metrics.distance import edit_distance from nltk.tokenize import word_tokenize -from nltk.tokenize.treebank import TreebankWordTokenizer from nltk.translate.bleu_score import sentence_bleu from pydantic import BaseModel from scipy.stats import hypergeom @@ -51,6 +50,7 @@ remove_braces, remove_braces_and_strip, ) +from lighteval.metrics.utils.nltk_resources import ensure_nltk_resource from lighteval.metrics.utils.judge_utils import ( get_judge_prompt_simpleqa, process_judge_response_simpleqa, @@ -870,9 +870,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs): Returns: float: Score over the current sample's items. """ - import nltk - - nltk.download("punkt_tab") + ensure_nltk_resource("tokenizers/punkt_tab", "punkt_tab") golds = doc.get_golds() predictions = model_response.final_text return np.mean([self._bleu_score(golds, p) for p in predictions]) @@ -889,7 +887,9 @@ def _bleu_score(self, gold: list[str], pred: str): """ weights = [1 if ix == self.n_gram else 0 for ix in range(1, 5)] return sentence_bleu( - [word_tokenize(g) for g in gold], word_tokenize(pred), weights=weights + [word_tokenize(g) for g in gold], + word_tokenize(pred), + weights=weights, ) diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py index 277dde993..a80a1e573 100644 --- a/src/lighteval/metrics/sample_preparator.py +++ b/src/lighteval/metrics/sample_preparator.py @@ -60,6 +60,13 @@ class PerplexityCorpusMetricInput(CorpusMetricInput): weights: list[int] +@dataclass +class COMETCorpusMetricInput(CorpusMetricInput): + source: str + hyp: str + ref: list[str] + + class Preparator: pass @@ -93,6 +100,19 @@ def __str__(self): return f"{self.__class__.__name__}({', '.join(attr_strs)})" +class COMETPreparator(Preparator): + @staticmethod + def prepare(doc: Doc, model_response: ModelResponse, **kwargs) -> "COMETCorpusMetricInput": + """Prepares a translation example for COMET scoring. + + Reads the source sentence from doc.specific["source_text"] (set by the translation + template), combines it with the model's hypothesis and the gold reference(s). + """ + source = (doc.specific or {}).get("source_text", "") + golds = as_list(doc.get_golds()) + return COMETCorpusMetricInput(source=source, hyp=model_response.final_text, ref=golds) + + class LoglikelihoodPreparator(Preparator): def __init__(self, is_single_token: bool = False): """Init. diff --git a/src/lighteval/metrics/utils/nltk_resources.py b/src/lighteval/metrics/utils/nltk_resources.py new file mode 100644 index 000000000..bf0a9a474 --- /dev/null +++ b/src/lighteval/metrics/utils/nltk_resources.py @@ -0,0 +1,28 @@ +import nltk + + +def ensure_nltk_resource(resource_path: str, package_name: str) -> None: + try: + nltk.data.find(resource_path) + return + except LookupError: + pass + + error_message = ( + f"Required NLTK resource '{package_name}' is not available locally. " + f"Tried to resolve '{resource_path}' and attempted an automatic download, but it failed. " + "Download it in advance on a machine with network access and set NLTK_DATA to the cached directory." + ) + + try: + downloaded = nltk.download(package_name, quiet=True) + except Exception as exc: + raise RuntimeError(error_message) from exc + + if not downloaded: + raise RuntimeError(error_message) + + try: + nltk.data.find(resource_path) + except LookupError as exc: + raise RuntimeError(error_message) from exc diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index 2a59bc6d0..736e06a38 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -707,7 +707,12 @@ def _padded_greedy_until( # NOTE: we are assuming all items in a batch behave similarly (same # stop_tokens and max_tokens genrated) which is not necessarily # the case! Because of that we only use batch size of 1 - stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences if len(batch[0].stop_sequences) > 0 else [self.tokenizer.eos_token] + batch_stop_sequences = list(batch[0].stop_sequences) + stop_tokens = ( + [self.tokenizer.eos_token] + batch_stop_sequences + if len(batch_stop_sequences) > 0 + else [self.tokenizer.eos_token] + ) max_new_tokens = batch[0].generation_size num_samples = batch[0].num_samples diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 04629379c..3f45315ba 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -21,13 +21,18 @@ # SOFTWARE. import functools +import json import logging +import os import random +import tarfile +import urllib.request from dataclasses import asdict, dataclass, field +from pathlib import Path, PurePosixPath from typing import Callable -from datasets import DatasetDict, load_dataset -from huggingface_hub import TextGenerationInputGrammarType +from datasets import Dataset, DatasetDict, load_dataset +from huggingface_hub import TextGenerationInputGrammarType, hf_hub_download, list_repo_files from inspect_ai.dataset import Sample from multiprocess import Pool from pytablewriter import MarkdownTableWriter @@ -45,6 +50,249 @@ logger = logging.getLogger(__name__) +def _get_manual_dataset_cache_dir() -> Path: + hf_home = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")) + cache_dir = hf_home / "lighteval_manual_datasets" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def _matches_hub_split_file(path: str, split: str, ext: str, config_name: str | None) -> bool: + posix_path = PurePosixPath(path) + filename = posix_path.name + if not filename.endswith(ext): + return False + + if filename.startswith(f"{split}-") or filename == f"{split}{ext}": + if config_name: + return posix_path.parent.as_posix() == config_name + return posix_path.parent.as_posix() in (".", "") + + if not config_name: + return False + + if filename == f"{config_name}_{split}{ext}" and posix_path.parent.as_posix() in ( + "data", + ".", + "", + ): + return True + + return False + + +def _load_hub_raw_dataset_files( + dataset_path: str, + config_name: str | None, + splits: list[str], + *, + revision: str | None = None, + local_files_only: bool = False, +) -> DatasetDict | None: + repo_files = list_repo_files( + repo_id=dataset_path, + repo_type="dataset", + revision=revision, + ) + + data_files: dict[str, list[str]] = {} + data_format: str | None = None + for split in splits: + matched_files = [] + matched_format = None + for ext, fmt in ((".parquet", "parquet"), (".jsonl", "json"), (".json", "json")): + matched_files = [ + hf_hub_download( + repo_id=dataset_path, + repo_type="dataset", + filename=repo_file, + revision=revision, + local_files_only=local_files_only, + ) + for repo_file in repo_files + if _matches_hub_split_file(repo_file, split, ext, config_name) + ] + if matched_files: + matched_format = fmt + break + + if not matched_files or matched_format is None: + return None + + if data_format is None: + data_format = matched_format + elif data_format != matched_format: + raise ValueError( + f"Inconsistent raw file formats for dataset {dataset_path}: {data_format} vs {matched_format}" + ) + + data_files[split] = sorted(matched_files) + + if not data_files or data_format is None: + return None + + return load_dataset(data_format, data_files=data_files) + + +def _load_mgsm_dataset(config_name: str) -> DatasetDict: + import csv + import importlib.util + + from huggingface_hub import hf_hub_download + + exemplars_path = hf_hub_download( + repo_id="juletxara/mgsm", + repo_type="dataset", + filename="exemplars.py", + ) + tsv_path = hf_hub_download( + repo_id="juletxara/mgsm", + repo_type="dataset", + filename=f"mgsm_{config_name}.tsv", + ) + + spec = importlib.util.spec_from_file_location("mgsm_exemplars", exemplars_path) + exemplars_module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(exemplars_module) + + train_rows = [] + examples = exemplars_module.MGSM_EXEMPLARS[config_name] + number_answers = exemplars_module.EXEMPLAR_NUMBER_ANSWERS + equation_solutions = exemplars_module.EXEMPLAR_EQUATION_SOLUTIONS + for key, data in examples.items(): + idx = int(key) - 1 + train_rows.append( + { + "question": data["q"], + "answer": data["a"], + "answer_number": number_answers[idx], + "equation_solution": equation_solutions[idx], + } + ) + + test_rows = [] + with open(tsv_path, encoding="utf-8") as csv_file: + csv_reader = csv.reader(csv_file, quotechar='"', delimiter="\t") + for row in csv_reader: + test_rows.append( + { + "question": row[0], + "answer": None, + "answer_number": int(row[1].replace(",", "")), + "equation_solution": None, + } + ) + + return DatasetDict( + { + "train": Dataset.from_list(train_rows), + "test": Dataset.from_list(test_rows), + } + ) + + +def _load_wmt24pp_dataset(config_name: str, local_files_only: bool = False) -> DatasetDict: + from huggingface_hub import hf_hub_download + + jsonl_path = hf_hub_download( + repo_id="google/wmt24pp", + repo_type="dataset", + filename=f"{config_name}.jsonl", + local_files_only=local_files_only, + ) + + rows = [] + with open(jsonl_path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + + return DatasetDict({"train": Dataset.from_list(rows)}) + + +def _load_flores_dataset(config_name: str, local_files_only: bool = False) -> DatasetDict: + cache_dir = _get_manual_dataset_cache_dir() / "flores" + cache_dir.mkdir(parents=True, exist_ok=True) + + archive_path = cache_dir / "flores200_dataset.tar.gz" + extract_dir = cache_dir / "flores200_dataset" + + if not extract_dir.exists(): + if not archive_path.exists(): + if local_files_only: + raise FileNotFoundError( + f"FLORES archive not found in local cache: {archive_path}" + ) + urllib.request.urlretrieve( + "https://dl.fbaipublicfiles.com/nllb/flores200_dataset.tar.gz", + archive_path, + ) + + with tarfile.open(archive_path, "r:gz") as tar: + tar.extractall(cache_dir) + + if "-" in config_name: + langs = config_name.split("-", 1) + else: + langs = [config_name] + + def _sentence_path(split: str, lang: str) -> Path: + return extract_dir / split / f"{lang}.{split}" + + def _metadata_path(split: str) -> Path: + return extract_dir / f"metadata_{split}.tsv" + + def _build_split(split: str) -> Dataset: + with open(_metadata_path(split), encoding="utf-8") as meta_file: + metadata_lines = [line.rstrip("\n") for line in meta_file] + header, rows = metadata_lines[0], metadata_lines[1:] + if header.split("\t")[:5] != [ + "URL", + "domain", + "topic", + "has_image", + "has_hyperlink", + ]: + raise ValueError(f"Unexpected FLORES metadata header: {header}") + + sentences_by_lang: dict[str, list[str]] = {} + for lang in langs: + with open(_sentence_path(split, lang), encoding="utf-8") as sentence_file: + sentences_by_lang[lang] = [ + line.rstrip("\n") for line in sentence_file + ] + + data = [] + for idx, metadata in enumerate(rows, start=1): + url, domain, topic, has_image, has_hyperlink = metadata.split("\t") + row = { + "id": idx, + "URL": url, + "domain": domain, + "topic": topic, + "has_image": 1 if has_image == "yes" else 0, + "has_hyperlink": 1 if has_hyperlink == "yes" else 0, + } + if len(langs) == 1: + row["sentence"] = sentences_by_lang[langs[0]][idx - 1] + else: + for lang in langs: + row[f"sentence_{lang}"] = sentences_by_lang[lang][idx - 1] + data.append(row) + + return Dataset.from_list(data) + + return DatasetDict( + { + "dev": _build_split("dev"), + "devtest": _build_split("devtest"), + } + ) + + @dataclass class LightevalTaskConfig: """Configuration dataclass for a LightevalTask. @@ -521,6 +769,42 @@ def download_dataset_worker( if not (_is_script_err or _is_offline): raise + if _is_script_err and task.dataset_path == "juletxara/mgsm": + dataset = _load_mgsm_dataset(task.dataset_config_name) + if task.dataset_filter is not None: + dataset = dataset.filter(task.dataset_filter) + return dataset # type: ignore + + if (_is_script_err or _is_offline) and task.dataset_path == "google/wmt24pp": + dataset = _load_wmt24pp_dataset( + task.dataset_config_name, + local_files_only=_is_offline, + ) + if task.dataset_filter is not None: + dataset = dataset.filter(task.dataset_filter) + return dataset # type: ignore + + if (_is_script_err or _is_offline) and task.dataset_path == "facebook/flores": + dataset = _load_flores_dataset( + task.dataset_config_name, + local_files_only=_is_offline, + ) + if task.dataset_filter is not None: + dataset = dataset.filter(task.dataset_filter) + return dataset # type: ignore + + if _is_script_err: + dataset = _load_hub_raw_dataset_files( + dataset_path=task.dataset_path, + config_name=task.dataset_config_name, + splits=list(task.config.hf_avail_splits or []), + revision=task.dataset_revision, + ) + if dataset is not None: + if task.dataset_filter is not None: + dataset = dataset.filter(task.dataset_filter) + return dataset # type: ignore + _splits = list(task.config.hf_avail_splits or []) if not _splits: _splits = ( @@ -554,9 +838,24 @@ def download_dataset_worker( if task.dataset_config_name else _rev_dir ) + def _matches_cached_file(_filename: str, _split: str, _ext: str) -> bool: + if not _filename.endswith(_ext): + return False + if _filename.startswith(f"{_split}-") or _filename == f"{_split}{_ext}": + return True + return bool( + task.dataset_config_name + and _filename == f"{task.dataset_config_name}_{_split}{_ext}" + ) + # Use os.listdir (not glob) so symlinks in the hub # cache are resolved correctly. - for _search_dir in [_config_dir, _os.path.join(_config_dir, "data")]: + for _search_dir in [ + _config_dir, + _os.path.join(_config_dir, "data"), + _rev_dir, + _os.path.join(_rev_dir, "data"), + ]: if not _os.path.isdir(_search_dir): continue try: @@ -568,11 +867,7 @@ def download_dataset_worker( _pq = sorted( _os.path.join(_search_dir, f) for f in _entries - if f.endswith(_ext) - and ( - f.startswith(f"{_split}-") - or f == f"{_split}{_ext}" - ) + if _matches_cached_file(f, _split, _ext) ) if _pq: _data_files[_split] = _pq diff --git a/src/lighteval/tasks/multilingual/tasks/flores200.py b/src/lighteval/tasks/multilingual/tasks/flores200.py index f6adbef78..3d536bb3b 100644 --- a/src/lighteval/tasks/multilingual/tasks/flores200.py +++ b/src/lighteval/tasks/multilingual/tasks/flores200.py @@ -6,7 +6,8 @@ facebook/flores abstract: -Flores200 multilingual benchmark. +Flores200 English-centric translation benchmark. +Few-shot evaluation uses `dev` for exemplars and `devtest` for scoring. languages: arabic, armenian, bengali, cyrillic, devanagari, ethiopic, georgian, greek, @@ -20,19 +21,16 @@ paper: """ -from itertools import permutations - -from lighteval.metrics.metrics import Metrics -from lighteval.tasks.lighteval_task import LightevalTaskConfig -from lighteval.tasks.templates.translation import get_translation_prompt_function -from lighteval.tasks.templates.utils.formulation import ( - CFFormulation, +from lighteval.tasks.multilingual.utils.translation import ( + TRANSLATION_METRICS, + EnglishCentricTranslationConfig, + build_english_centric_translation_tasks, ) from lighteval.utils.language import Language, manage_duplicate_language_codes flores_200_languages = [ - # "ace_Arab", + "ace_Arab", "ace_Latn", "acm_Arab", "acq_Arab", @@ -43,7 +41,7 @@ "amh_Ethi", "apc_Arab", "arb_Arab", - # "arb_Latn", + "arb_Latn", "ars_Arab", "ary_Arab", "arz_Arab", @@ -60,7 +58,7 @@ "bem_Latn", "ben_Beng", "bho_Deva", - # "bjn_Arab", + "bjn_Arab", "bjn_Latn", "bod_Tibt", "bos_Latn", @@ -115,10 +113,10 @@ "kac_Latn", "kam_Latn", "kan_Knda", - # "kas_Arab", + "kas_Arab", "kas_Deva", "kat_Geor", - # "knc_Arab", + "knc_Arab", "knc_Latn", "kaz_Cyrl", "kbp_Latn", @@ -148,7 +146,7 @@ "mai_Deva", "mal_Mlym", "mar_Deva", - # "min_Arab", + "min_Arab", "min_Latn", "mkd_Cyrl", "plt_Latn", @@ -233,11 +231,14 @@ "yor_Latn", "yue_Hant", "zho_Hans", - # "zho_Hant", + "zho_Hant", "zsm_Latn", "zul_Latn", ] +_ENGLISH_FLORES_CODE = "eng_Latn" +_FLORES_METRICS = TRANSLATION_METRICS + def flores_adapter(lang1, lang2): return lambda line: { @@ -246,25 +247,30 @@ def flores_adapter(lang1, lang2): } -TASKS_TABLE = [ - LightevalTaskConfig( - name=f"flores200:{lang1}-{lang2}", - prompt_function=get_translation_prompt_function( - source_language=Language(manage_duplicate_language_codes(lang1.split("_")[0])), - target_language=Language(manage_duplicate_language_codes(lang2.split("_")[0])), - adapter=flores_adapter(lang1, lang2), - formulation=CFFormulation(), - ), - hf_repo="facebook/flores", - hf_subset=f"{lang1}-{lang2}", - hf_avail_splits=["dev", "devtest"], - evaluation_splits=["devtest"], - few_shots_split="dev", - few_shots_select=None, - generation_size=300, - metrics=[Metrics.chrf_plus, Metrics.bleu, Metrics.bleu_1, Metrics.bleu_4], - stop_sequence=["\n"], - version=0, - ) - for (lang1, lang2) in permutations(flores_200_languages, 2) -] +def _flores_pair_subset(lang_code: str) -> str: + return "-".join(sorted((_ENGLISH_FLORES_CODE, lang_code))) + + +TASKS_TABLE = build_english_centric_translation_tasks( + base_name="flores200", + hf_repo="facebook/flores", + language_configs=[ + EnglishCentricTranslationConfig( + task_id=lang_code, + target_language=Language( + manage_duplicate_language_codes(lang_code.split("_")[0]) + ), + forward_hf_subset=_flores_pair_subset(lang_code), + reverse_hf_subset=_flores_pair_subset(lang_code), + forward_adapter=flores_adapter(_ENGLISH_FLORES_CODE, lang_code), + reverse_adapter=flores_adapter(lang_code, _ENGLISH_FLORES_CODE), + hf_avail_splits=("dev", "devtest"), + evaluation_splits=("devtest",), + few_shots_split="dev", + few_shots_select="random_sampling_from_train", + ) + for lang_code in flores_200_languages + if lang_code != _ENGLISH_FLORES_CODE + ], + metrics=_FLORES_METRICS, +) diff --git a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py index 894f15a3c..6167a5018 100644 --- a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py @@ -9,11 +9,19 @@ Translated MMLU using both professional and non-professional translators. Contains tags for cultural sensitivity. +Refactored to use unified :cf / :mcf suffixes with 3-part names (base:lang:suffix) +so group selectors like `global_mmlu:cf|5` expand automatically. +English is excluded — use mmlu.py for English evaluation. + +Per-category metrics (STEM / Humanities / Social / Other) are reported via +MMLUCategoryGroupingCF / MCF, routed via doc.specific["subject"]. + languages: -amharic, arabic, bengali, chinese, czech, dutch, english, french, german, -hebrew, hindi, indonesian, italian, japanese, korean, malay, norwegian, polish, -portuguese, romanian, russian, serbian, spanish, swahili, swedish, tamil, -telugu, thai, turkish, ukrainian, urdu, vietnamese, yoruba, zulu +amharic, arabic, bengali, chinese, czech, dutch, french, german, greek, +hausa, hebrew, hindi, igbo, indonesian, italian, japanese, kirghiz, korean, +lithuanian, malagasy, malay, nepali, nyanja, persian, polish, portuguese, +romanian, russian, serbian, shona, sinhala, somali, spanish, swahili, +swedish, tagalog, telugu, turkish, ukrainian, vietnamese, yoruba tags: knowledge, multilingual, multiple-choice @@ -22,158 +30,149 @@ https://huggingface.co/papers/2412.03304 """ -from functools import partial from string import ascii_uppercase from langcodes import standardize_tag -from lighteval.metrics.dynamic_metrics import ( - LogLikelihoodAccMetric, -) -from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm +from lighteval.metrics.dynamic_metrics import MMLUCategoryGroupingCF, MMLUCategoryGroupingMCF +from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig -from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation +from lighteval.tasks.requests import Doc from lighteval.tasks.templates.multichoice import get_mcq_prompt_function -from lighteval.tasks.templates.utils.formulation import ( - MCFFormulation, -) +from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation from lighteval.utils.language import Language -MMLU_SUBSETS = [ - "abstract_algebra", - "anatomy", - "astronomy", - "business_ethics", - "clinical_knowledge", - "college_biology", - "college_chemistry", - "college_computer_science", - "college_mathematics", - "college_medicine", - "college_physics", - "computer_security", - "conceptual_physics", - "econometrics", - "electrical_engineering", - "elementary_mathematics", - "formal_logic", - "global_facts", - "high_school_biology", - "high_school_chemistry", - "high_school_computer_science", - "high_school_european_history", - "high_school_geography", - "high_school_government_and_politics", - "high_school_macroeconomics", - "high_school_mathematics", - "high_school_microeconomics", - "high_school_physics", - "high_school_psychology", - "high_school_statistics", - "high_school_us_history", - "high_school_world_history", - "human_aging", - "human_sexuality", - "international_law", - "jurisprudence", - "logical_fallacies", - "machine_learning", - "management", - "marketing", - "medical_genetics", - "miscellaneous", - "moral_disputes", - "moral_scenarios", - "nutrition", - "philosophy", - "prehistory", - "professional_accounting", - "professional_law", - "professional_medicine", - "professional_psychology", - "public_relations", - "security_studies", - "sociology", - "us_foreign_policy", - "virology", - "world_religions", +# All available non-English configs in Global-MMLU. The dataset exposes +# language configs as BCP-47-style short tags (for example, `fa`, `fil`, `sw`) +# so the task names stay on Language enum values while hf_subset uses +# standardize_tag(language.value). +_LANGUAGES = [ + Language.AMHARIC, + Language.ARABIC, + Language.BENGALI, + Language.CHINESE, + Language.CZECH, + Language.DUTCH, + Language.FRENCH, + Language.GERMAN, + Language.GREEK, + Language.HAUSA, + Language.HEBREW, + Language.HINDI, + Language.IGBO, + Language.INDONESIAN, + Language.ITALIAN, + Language.JAPANESE, + Language.KIRGHIZ, + Language.KOREAN, + Language.LITHUANIAN, + Language.MALAGASY, + Language.MALAY, + Language.NEPALI, + Language.NYANJA, + Language.PERSIAN, + Language.SPANISH, + Language.POLISH, + Language.PORTUGUESE, + Language.ROMANIAN, + Language.RUSSIAN, + Language.SERBIAN, + Language.SHONA, + Language.SINHALA, + Language.SOMALI, + Language.SWEDISH, + Language.SWAHILI, + Language.TAGALOG, + Language.TELUGU, + Language.TURKISH, + Language.UKRAINIAN, + Language.VIETNAMESE, + Language.YORUBA, ] +def _make_global_mmlu_cf_prompt(language: Language): + """CF prompt: score each full answer text. Sets doc.specific with subject.""" + inner = get_mcq_prompt_function(language, _cf_adapter, formulation=CFFormulation()) + + def prompt_fn(line, task_name=None): + doc = inner(line, task_name) + if doc is not None: + doc.specific = {"subject": (line.get("subject") or "").lower()} + return doc + + return prompt_fn + + +def _make_global_mmlu_mcf_prompt(language: Language): + """MCF prompt: score label tokens A/B/C/D. Sets doc.specific with subject.""" + inner = get_mcq_prompt_function(language, _mcf_adapter, formulation=MCFFormulation()) + + def prompt_fn(line, task_name=None): + doc = inner(line, task_name) + if doc is not None: + doc.specific = {"subject": (line.get("subject") or "").lower()} + return doc + + return prompt_fn + + +def _cf_adapter(line): + choices = [line["option_a"], line["option_b"], line["option_c"], line["option_d"]] + if any(c is None or not str(c).strip() for c in choices): + return None + return { + "question": line["question"], + "choices": choices, + "gold_idx": ascii_uppercase.index(line["answer"]), + } + + +def _mcf_adapter(line): + choices = [line["option_a"], line["option_b"], line["option_c"], line["option_d"]] + if any(c is None or not str(c).strip() for c in choices): + return None + return { + "question": line["question"], + "choices": choices, + "gold_idx": ascii_uppercase.index(line["answer"]), + } + + +_MMLU_CF_METRICS = [MMLUCategoryGroupingCF] +_MMLU_MCF_METRICS = [MMLUCategoryGroupingMCF] + + TASKS_TABLE = [ LightevalTaskConfig( - name=f"global_mmlu_{sensitivity_label.lower()}_{language.value}_{formulation.name.lower()}:{subset}", - prompt_function=get_mcq_prompt_function( - language, - lambda line: { - "question": line["question"], - "choices": [line["option_a"], line["option_b"], line["option_c"], line["option_d"]], - "gold_idx": ascii_uppercase.index(line["answer"]), - }, - formulation=formulation, - ), + name=f"global_mmlu:{language.value}:{suffix}", + prompt_function=prompt_fn(language), hf_repo="CohereForAI/Global-MMLU", hf_subset=standardize_tag(language.value), evaluation_splits=("test",), few_shots_split="dev", - hf_filter=partial( - lambda subset, sensitivity_label, x: x["subject"].lower() == subset - and ( - sensitivity_label == "ALL" or sensitivity_label in x["cultural_sensitivity_label"].replace("-", "UNK") - ) - and all(x[f"option_{opt}"] is not None and x[f"option_{opt}"].strip() for opt in "abcd"), - subset, - sensitivity_label, - ), - metrics=get_metrics_for_formulation( - formulation, - [ - LogLikelihoodAccMetric(normalization=LogProbTokenNorm()), - LogLikelihoodAccMetric(normalization=LogProbCharNorm()), - LogLikelihoodAccMetric(normalization=LogProbPMINorm()), - ], - ), + metrics=metrics, ) - for subset in MMLU_SUBSETS - for language in [ - Language.AMHARIC, - Language.ARABIC, - Language.BENGALI, - Language.CHINESE, - Language.CZECH, - Language.GERMAN, - Language.ENGLISH, - Language.SPANISH, - Language.FRENCH, - Language.HEBREW, - Language.HINDI, - Language.INDONESIAN, - Language.ITALIAN, - Language.JAPANESE, - Language.KOREAN, - Language.MALAY, - Language.DUTCH, - Language.NORWEGIAN, - Language.POLISH, - Language.PORTUGUESE, - Language.ROMANIAN, - Language.RUSSIAN, - Language.SERBIAN, - Language.SWEDISH, - Language.SWAHILI, - Language.TAMIL, - Language.TELUGU, - Language.THAI, - Language.TURKISH, - Language.UKRAINIAN, - Language.URDU, - Language.VIETNAMESE, - Language.YORUBA, - Language.ZULU, + for language in _LANGUAGES + for suffix, prompt_fn, metrics in [ + ("cf", _make_global_mmlu_cf_prompt, _MMLU_CF_METRICS), + ("mcf", _make_global_mmlu_mcf_prompt, _MMLU_MCF_METRICS), ] - for formulation in [ - MCFFormulation(), - ] - for sensitivity_label in ["ALL"] +] + +# Greedy variant: MCF-style prompt, generate up to 5 tokens, exact match +TASKS_TABLE += [ + LightevalTaskConfig( + name=f"global_mmlu:{language.value}:mcf_em", + prompt_function=_make_global_mmlu_mcf_prompt(language), + hf_repo="CohereForAI/Global-MMLU", + hf_subset=standardize_tag(language.value), + evaluation_splits=("test",), + few_shots_split="dev", + generation_size=5, + stop_sequence=["\n"], + metrics=[Metrics.exact_match], + ) + for language in _LANGUAGES ] diff --git a/src/lighteval/tasks/multilingual/tasks/global_mmlu_lite.py b/src/lighteval/tasks/multilingual/tasks/global_mmlu_lite.py new file mode 100644 index 000000000..1de77591d --- /dev/null +++ b/src/lighteval/tasks/multilingual/tasks/global_mmlu_lite.py @@ -0,0 +1,131 @@ +""" +name: +Global Mmlu Lite + +dataset: +CohereLabs/Global-MMLU-Lite + +abstract: +A lighter, culturally-annotated subset of MMLU covering 18 languages total; +17 non-English languages evaluated here (400 test samples and 215 dev samples +per language across 43 subjects). Designed for quick multilingual MMLU-style +evaluation. + +English is excluded — use mmlu.py for English evaluation. + +Metrics: +- :cf — completion formulation; reports acc, acc_norm (char), target_bpb +- :mcf — multiple-choice formulation; reports acc, acc_norm (char) + +languages: +arabic, bengali, welsh, german, spanish, french, hindi, indonesian, italian, +japanese, korean, burmese, portuguese, albanian, swahili, yoruba, chinese + +tags: +knowledge, multilingual, multiple-choice + +paper: +https://huggingface.co/datasets/CohereLabs/Global-MMLU-Lite +""" + +from string import ascii_uppercase + +from langcodes import standardize_tag + +from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric +from lighteval.metrics.metrics import Metrics +from lighteval.metrics.normalizations import LogProbCharNorm +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.templates.multichoice import get_mcq_prompt_function +from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation +from lighteval.utils.language import Language + + +_CF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), + Metrics.target_bits_per_byte, +] + +_MCF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), +] + +# 18 languages in Global-MMLU-Lite, English excluded +_LANGUAGES = [ + Language.ARABIC, # ar + Language.BENGALI, # bn + Language.WELSH, # cy + Language.GERMAN, # de + Language.SPANISH, # es + Language.FRENCH, # fr + Language.HINDI, # hi + Language.INDONESIAN, # id + Language.ITALIAN, # it + Language.JAPANESE, # ja + Language.KOREAN, # ko + Language.BURMESE, # my + Language.PORTUGUESE, # pt + Language.ALBANIAN, # sq + Language.SWAHILI, # sw + Language.YORUBA, # yo + Language.CHINESE, # zh +] + + +def _global_mmlu_lite_adapter(line): + return { + "question": line["question"], + "choices": [line["option_a"], line["option_b"], line["option_c"], line["option_d"]], + "gold_idx": ascii_uppercase.index(line["answer"]), + } + + +def _valid_options_filter(x): + """Skip rows where any option is missing.""" + return all(x[f"option_{opt}"] is not None and x[f"option_{opt}"].strip() for opt in "abcd") + + +TASKS_TABLE = [ + LightevalTaskConfig( + name=f"global_mmlu_lite:{language.value}:{suffix}", + prompt_function=get_mcq_prompt_function( + language, + _global_mmlu_lite_adapter, + formulation=formulation, + ), + hf_repo="CohereLabs/Global-MMLU-Lite", + hf_subset=standardize_tag(language.value), + evaluation_splits=("test",), + few_shots_split="dev", + hf_filter=_valid_options_filter, + metrics=metrics, + ) + for language in _LANGUAGES + for suffix, formulation, metrics in [ + ("cf", CFFormulation(), _CF_METRICS), + ("mcf", MCFFormulation(), _MCF_METRICS), + ] +] + +# Greedy variant: MCF-style prompt, generate up to 5 tokens, exact match +TASKS_TABLE += [ + LightevalTaskConfig( + name=f"global_mmlu_lite:{language.value}:mcf_em", + prompt_function=get_mcq_prompt_function( + language, + _global_mmlu_lite_adapter, + formulation=MCFFormulation(), + ), + hf_repo="CohereLabs/Global-MMLU-Lite", + hf_subset=standardize_tag(language.value), + evaluation_splits=("test",), + few_shots_split="dev", + hf_filter=_valid_options_filter, + generation_size=5, + stop_sequence=["\n"], + metrics=[Metrics.exact_match], + ) + for language in _LANGUAGES +] diff --git a/src/lighteval/tasks/multilingual/tasks/mgsm.py b/src/lighteval/tasks/multilingual/tasks/mgsm.py index ba5d9a323..1f03a0b26 100644 --- a/src/lighteval/tasks/multilingual/tasks/mgsm.py +++ b/src/lighteval/tasks/multilingual/tasks/mgsm.py @@ -6,37 +6,57 @@ juletxara/mgsm abstract: -Mgsm multilingual benchmark. +MGSM (Multilingual Grade School Math) is a multilingual benchmark testing +mathematical reasoning across languages, derived from GSM8K. + +Refactored to use unified :gen suffix consistent with English gsm8k.py. +English is excluded — use gsm8k.py for English evaluation. +Reports both expr_gold_metric (math expression parser) and +MultilingualQuasiExactMatchMetric (language-aware fuzzy match, handles +non-ASCII digit systems like Japanese/Thai). languages: -bengali, chinese, english, french, german, japanese, russian, spanish, swahili, -telugu, thai +bengali, french, german, japanese, russian, spanish, swahili, telugu, thai, +chinese tags: math, multilingual, reasoning paper: +https://arxiv.org/abs/2210.03057 """ from langcodes import standardize_tag -from lighteval.metrics.dynamic_metrics import ( - MultilingualQuasiExactMatchMetric, -) +from lighteval.metrics.dynamic_metrics import MultilingualQuasiExactMatchMetric +from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.templates.qa import get_qa_prompt_function from lighteval.utils.language import Language +# Languages covered by juletxara/mgsm (English excluded — use gsm8k.py) +_LANGUAGES = [ + Language.SPANISH, + Language.FRENCH, + Language.GERMAN, + Language.RUSSIAN, + Language.CHINESE, + Language.JAPANESE, + Language.THAI, + Language.SWAHILI, + Language.BENGALI, + Language.TELUGU, +] + + TASKS_TABLE = [ LightevalTaskConfig( - name=f"mgsm_{language.value}", + name=f"mgsm:{language.value}:gen", prompt_function=get_qa_prompt_function( language, lambda line: { "question": line["question"], - # The cot is available but we have no use: - # line["answer"] "choices": [str(line["answer_number"])], }, ), @@ -44,23 +64,12 @@ hf_subset=standardize_tag(language.value), evaluation_splits=("test",), few_shots_split="train", - generation_size=25, + generation_size=512, metrics=[ + Metrics.expr_gold_metric, MultilingualQuasiExactMatchMetric(language, "full"), ], - stop_sequence=("\n",), + stop_sequence=["\n"], ) - for language in [ - Language.ENGLISH, - Language.SPANISH, - Language.FRENCH, - Language.GERMAN, - Language.RUSSIAN, - Language.CHINESE, - Language.JAPANESE, - Language.THAI, - Language.SWAHILI, - Language.BENGALI, - Language.TELUGU, - ] + for language in _LANGUAGES ] diff --git a/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py b/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py index f7485124f..200f14f11 100644 --- a/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py +++ b/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py @@ -9,10 +9,11 @@ ARC (AI2 Reasoning Challenge) is a dataset for question answering that requires reasoning. It consists of multiple-choice science questions from 3rd to 9th grade exams. The dataset is split into two parts: ARC-Easy and ARC-Challenge. -ARC-Easy contains questions that can be answered correctly by both humans and -simple baseline models. ARC-Challenge contains questions that are difficult for -both humans and current AI systems. Similar to MMLU, ARC tasks uses PMI -normalization by default but only for the challenge set. +ARC-Challenge contains questions that are difficult for both humans and current +AI systems. + +Refactored to use unified :cf / :mcf suffixes consistent with English arc.py. +English is excluded — use arc.py for English evaluation. languages: arabic, bengali, catalan, chinese, croatian, danish, dutch, french, german, @@ -31,33 +32,91 @@ from langcodes import standardize_tag -from lighteval.metrics.dynamic_metrics import ( - LogLikelihoodAccMetric, -) -from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm +from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric +from lighteval.metrics.metrics import Metrics +from lighteval.metrics.normalizations import LogProbCharNorm from lighteval.tasks.lighteval_task import LightevalTaskConfig -from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.multichoice import get_mcq_prompt_function -from lighteval.tasks.templates.utils.formulation import ( - CFFormulation, - HybridFormulation, - MCFFormulation, -) +from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation from lighteval.utils.language import Language +_CF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), + Metrics.target_bits_per_byte, +] + +_MCF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), +] + +# Languages covered by jon-tow/okapi_arc_challenge (English excluded) +_LANGUAGES = [ + Language.RUSSIAN, + Language.GERMAN, + Language.CHINESE, + Language.FRENCH, + Language.SPANISH, + Language.ITALIAN, + Language.DUTCH, + Language.VIETNAMESE, + Language.INDONESIAN, + Language.ARABIC, + Language.HUNGARIAN, + Language.ROMANIAN, + Language.DANISH, + Language.SLOVAK, + Language.UKRAINIAN, + Language.CATALAN, + Language.SERBIAN, + Language.CROATIAN, + Language.HINDI, + Language.BENGALI, + Language.TAMIL, + Language.NEPALI, + Language.MALAYALAM, + Language.MARATHI, + Language.TELUGU, + Language.KANNADA, +] + + +def _arc_adapter(line): + if "question" in line and "choices" in line: + choices = line["choices"]["text"] + answer_key = line["answerKey"] + else: + choices = [ + line[key] + for key in ("option_a", "option_b", "option_c", "option_d", "option_e") + if line.get(key) + ] + answer_key = line["answer"] + return { + "question": line["instruction"], + "choices": choices, + "gold_idx": int(answer_key) - 1 + if answer_key.isdigit() + else ascii_uppercase.index(answer_key), + } + + return { + "question": line["question"], + "choices": choices, + "gold_idx": int(answer_key) - 1 + if answer_key.isdigit() + else ascii_uppercase.index(answer_key), + } + + TASKS_TABLE = [ LightevalTaskConfig( - name=f"mlmm_arc_{language.value}_{formulation.name.lower()}:challenge", + name=f"mlmm_arc:{language.value}:{suffix}", prompt_function=get_mcq_prompt_function( language, - lambda line: { - "question": line["question"], - "choices": line["choices"]["text"], - "gold_idx": int(line["answerKey"]) - 1 - if line["answerKey"].isdigit() - else ascii_uppercase.index(line["answerKey"]), - }, + _arc_adapter, formulation=formulation, ), hf_repo="jon-tow/okapi_arc_challenge", @@ -65,46 +124,32 @@ hf_revision="823d5d7bfaf8974a3ab52a825b6cf4903b35dbc4", evaluation_splits=("test",), few_shots_split="train", - metrics=get_metrics_for_formulation( - formulation, - [ - LogLikelihoodAccMetric(normalization=LogProbTokenNorm()), - LogLikelihoodAccMetric(normalization=LogProbCharNorm()), - LogLikelihoodAccMetric(normalization=LogProbPMINorm()), - ], - ), + metrics=metrics, ) - for language in [ - Language.RUSSIAN, - Language.GERMAN, - Language.CHINESE, - Language.FRENCH, - Language.SPANISH, - Language.ITALIAN, - Language.DUTCH, - Language.VIETNAMESE, - Language.INDONESIAN, - Language.ARABIC, - Language.HUNGARIAN, - Language.ROMANIAN, - Language.DANISH, - Language.SLOVAK, - Language.UKRAINIAN, - Language.CATALAN, - Language.SERBIAN, - Language.CROATIAN, - Language.HINDI, - Language.BENGALI, - Language.TAMIL, - Language.NEPALI, - Language.MALAYALAM, - Language.MARATHI, - Language.TELUGU, - Language.KANNADA, - ] - for formulation in [ - MCFFormulation(), - CFFormulation(), - HybridFormulation(), + for language in _LANGUAGES + for suffix, formulation, metrics in [ + ("cf", CFFormulation(), _CF_METRICS), + ("mcf", MCFFormulation(), _MCF_METRICS), ] ] + +# Greedy variant: MCF-style prompt, generate 1 token, exact match +TASKS_TABLE += [ + LightevalTaskConfig( + name=f"mlmm_arc:{language.value}:mcf_em", + prompt_function=get_mcq_prompt_function( + language, + _arc_adapter, + formulation=MCFFormulation(), + ), + hf_repo="jon-tow/okapi_arc_challenge", + hf_subset=standardize_tag(language.value), + hf_revision="823d5d7bfaf8974a3ab52a825b6cf4903b35dbc4", + evaluation_splits=("test",), + few_shots_split="train", + generation_size=1, + stop_sequence=["\n"], + metrics=[Metrics.exact_match], + ) + for language in _LANGUAGES +] diff --git a/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py b/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py index 3945b8947..bbd631278 100644 --- a/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py +++ b/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py @@ -3,18 +3,23 @@ Mlmm Hellaswag dataset: -jon-tow/okapi_hellaswag +alexandrainst/m_hellaswag (most languages) +jon-tow/okapi_hellaswag (Chinese) abstract: Hellaswag is a commonsense reasoning task that requires models to complete a given scenario with the most plausible ending. It tests the model's ability to understand and reason about everyday situations and human behavior. -MLMM-Hellaswag: Multilingual adaptation of Hellaswag +MLMM-Hellaswag: Multilingual adaptation of Hellaswag. + +Refactored to use unified :cf / :mcf / :mcf_em suffixes consistent with English hellaswag.py. +HellaSwag has 4 candidate continuations; the MCF formulation labels them A–D and scores (or +generates) the label token. English is excluded — use hellaswag.py for English evaluation. languages: arabic, armenian, basque, bengali, catalan, chinese, croatian, danish, dutch, french, german, gujarati, hindi, hungarian, icelandic, indonesian, italian, -kannada, malayalam, marathi, nepali, norwegian, portuguese, romanian, russian, +kannada, malayalam, marathi, nepali, portuguese, romanian, russian, serbian, slovak, spanish, swedish, tamil, telugu, ukrainian, vietnamese tags: @@ -26,82 +31,127 @@ from langcodes import standardize_tag -from lighteval.metrics.dynamic_metrics import ( - LogLikelihoodAccMetric, -) -from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm +from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric +from lighteval.metrics.metrics import Metrics +from lighteval.metrics.normalizations import LogProbCharNorm from lighteval.tasks.lighteval_task import LightevalTaskConfig -from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function -from lighteval.tasks.templates.utils.formulation import ( - CFFormulation, - HybridFormulation, - MCFFormulation, -) +from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation from lighteval.utils.language import Language +_CF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), + Metrics.target_bits_per_byte, +] + +_MCF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), +] + +# Languages covered by alexandrainst/m_hellaswag and jon-tow/okapi_hellaswag. +# English excluded — use hellaswag.py for English evaluation. +_LANGUAGES = [ + Language.ARABIC, + Language.BENGALI, + Language.CATALAN, + Language.DANISH, + Language.GERMAN, + Language.SPANISH, + Language.BASQUE, + Language.FRENCH, + Language.GUJARATI, + Language.HINDI, + Language.CROATIAN, + Language.HUNGARIAN, + Language.ARMENIAN, + Language.INDONESIAN, + Language.ICELANDIC, + Language.ITALIAN, + Language.KANNADA, + Language.MALAYALAM, + Language.MARATHI, + Language.NEPALI, + Language.DUTCH, + Language.PORTUGUESE, + Language.ROMANIAN, + Language.RUSSIAN, + Language.SLOVAK, + Language.SERBIAN, + Language.SWEDISH, + Language.TAMIL, + Language.TELUGU, + Language.UKRAINIAN, + Language.VIETNAMESE, + Language.CHINESE, +] + + +def _hellaswag_adapter(line): + return { + "ctx_a": line["ctx_a"], + "ctx_b": line["ctx_b"], + "continuations": line["endings"], + "gold_idx": int(line["label"]), + } + + TASKS_TABLE = [ LightevalTaskConfig( - name=f"mlmm_hellaswag_{lang.value}_{formulation.name.lower()}", + name=f"mlmm_hellaswag:{lang.value}:cf", prompt_function=get_hellaswag_prompt_function( language=lang, - adapter=lambda line: { - # We don't use activity_label as they are not available - "ctx_a": line["ctx_a"], - "ctx_b": line["ctx_b"], - "continuations": line["endings"], - "gold_idx": int(line["label"]), - }, - formulation=formulation, + adapter=_hellaswag_adapter, + formulation=CFFormulation(), ), hf_repo="alexandrainst/m_hellaswag" if lang != Language.CHINESE else "jon-tow/okapi_hellaswag", hf_subset=standardize_tag(lang.value), - # hf_revision="96ed8e0dfc6172dad1d3df338d7b8ba6c1ff9d83", evaluation_splits=["val" if lang != Language.CHINESE else "validation"], hf_avail_splits=["val" if lang != Language.CHINESE else "validation"], - metrics=get_metrics_for_formulation( - formulation, - [ - LogLikelihoodAccMetric(normalization=LogProbTokenNorm()), - LogLikelihoodAccMetric(normalization=LogProbCharNorm()), - ], + few_shots_split=None, + metrics=_CF_METRICS, + ) + for lang in _LANGUAGES +] + +# MCF variant: labeled options in prompt, score label tokens via logprobs +TASKS_TABLE += [ + LightevalTaskConfig( + name=f"mlmm_hellaswag:{lang.value}:mcf", + prompt_function=get_hellaswag_prompt_function( + language=lang, + adapter=_hellaswag_adapter, + formulation=MCFFormulation(), ), + hf_repo="alexandrainst/m_hellaswag" if lang != Language.CHINESE else "jon-tow/okapi_hellaswag", + hf_subset=standardize_tag(lang.value), + evaluation_splits=["val" if lang != Language.CHINESE else "validation"], + hf_avail_splits=["val" if lang != Language.CHINESE else "validation"], + few_shots_split=None, + metrics=_MCF_METRICS, + ) + for lang in _LANGUAGES +] + +# Greedy variant: MCF-style prompt, generate 1 token, exact match +TASKS_TABLE += [ + LightevalTaskConfig( + name=f"mlmm_hellaswag:{lang.value}:mcf_em", + prompt_function=get_hellaswag_prompt_function( + language=lang, + adapter=_hellaswag_adapter, + formulation=MCFFormulation(), + ), + hf_repo="alexandrainst/m_hellaswag" if lang != Language.CHINESE else "jon-tow/okapi_hellaswag", + hf_subset=standardize_tag(lang.value), + evaluation_splits=["val" if lang != Language.CHINESE else "validation"], + hf_avail_splits=["val" if lang != Language.CHINESE else "validation"], + few_shots_split=None, + generation_size=1, + stop_sequence=["\n"], + metrics=[Metrics.exact_match], ) - for lang in [ - Language.ARABIC, - Language.BENGALI, - Language.CATALAN, - Language.DANISH, - Language.GERMAN, - Language.SPANISH, - Language.BASQUE, - Language.FRENCH, - Language.GUJARATI, - Language.HINDI, - Language.CROATIAN, - Language.HUNGARIAN, - Language.ARMENIAN, - Language.INDONESIAN, - Language.ICELANDIC, - Language.ITALIAN, - Language.KANNADA, - Language.MALAYALAM, - Language.MARATHI, - Language.NORWEGIAN, - Language.NEPALI, - Language.DUTCH, - Language.PORTUGUESE, - Language.ROMANIAN, - Language.RUSSIAN, - Language.SLOVAK, - Language.SERBIAN, - Language.SWEDISH, - Language.TAMIL, - Language.TELUGU, - Language.UKRAINIAN, - Language.VIETNAMESE, - Language.CHINESE, - ] - for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()] + for lang in _LANGUAGES ] diff --git a/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py b/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py new file mode 100644 index 000000000..a21c3b5f2 --- /dev/null +++ b/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py @@ -0,0 +1,170 @@ +""" +name: +MMLU-ProX (Multilingual MMLU-Pro) + +dataset: +li-lab/MMLU-ProX + +abstract: +MMLU-ProX is a multilingual extension of MMLU-Pro covering 29 languages. It +uses up to 10 answer options per question (labels A–J), making it more +challenging than standard 4-option MMLU. + +English is excluded — use mmlu_pro.py for English evaluation. + +Metrics: +- :cf — completion formulation; scores each non-null option text; reports + acc, acc_norm (char), target_bpb +- :mcf — multiple-choice formulation; scores label tokens A–J; reports + acc, acc_norm (char) + +languages: +afrikaans, arabic, bengali, czech, german, spanish, french, hindi, hungarian, +indonesian, italian, japanese, korean, marathi, nepali, portuguese, russian, +serbian, swahili, telugu, thai, ukrainian, urdu, vietnamese, wolof, yoruba, +chinese, zulu + +tags: +knowledge, multilingual, multiple-choice + +paper: +https://huggingface.co/datasets/li-lab/MMLU-ProX +""" + +from string import ascii_uppercase + +from langcodes import standardize_tag + +from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric +from lighteval.metrics.metrics import Metrics +from lighteval.metrics.normalizations import LogProbCharNorm +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc +from lighteval.utils.language import Language + + +_CF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), + Metrics.target_bits_per_byte, +] + +_MCF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), +] + +# MMLU-ProX languages (English excluded — use mmlu_pro.py for English) +_LANGUAGES = [ + Language.AFRIKAANS, # af + Language.ARABIC, # ar + Language.BENGALI, # bn + Language.CZECH, # cs + Language.GERMAN, # de + Language.SPANISH, # es + Language.FRENCH, # fr + Language.HINDI, # hi + Language.HUNGARIAN, # hu + Language.INDONESIAN, # id + Language.ITALIAN, # it + Language.JAPANESE, # ja + Language.KOREAN, # ko + Language.MARATHI, # mr + Language.NEPALI, # ne + Language.PORTUGUESE, # pt + Language.RUSSIAN, # ru + Language.SERBIAN, # sr + Language.SWAHILI, # sw + Language.TELUGU, # te + Language.THAI, # th + Language.UKRAINIAN, # uk + Language.URDU, # ur + Language.VIETNAMESE, # vi + Language.WOLOF, # wo + Language.YORUBA, # yo + Language.CHINESE, # zh + Language.ZULU, # zu +] + +# option_0..option_9 column names +_OPTION_COLS = [f"option_{i}" for i in range(10)] + + +def _get_options(line): + """Return list of non-null options from option_0..option_9.""" + return [line[col] for col in _OPTION_COLS if line.get(col) is not None and str(line[col]).strip()] + + +def mmlu_prox_cf_prompt(line, task_name: str = None): + """CF: score each non-null option text directly.""" + options = _get_options(line) + if not options: + return None + choices = [c if c and c[0].isspace() else " " + c for c in options] + return Doc( + task_name=task_name, + query=f"Question: {line['question'].strip()}\nAnswer:", + choices=choices, + gold_index=line["answer_index"], + ) + + +def mmlu_prox_mcf_prompt(line, task_name: str = None): + """MCF: show labeled options A–J, score label tokens only.""" + options = _get_options(line) + if not options: + return None + labels = list(ascii_uppercase[: len(options)]) + query = f"Question: {line['question'].strip()}\n" + query += "".join([f"{lbl}. {opt}\n" for lbl, opt in zip(labels, options)]) + query += "Answer:" + return Doc( + task_name=task_name, + query=query, + choices=[" " + lbl for lbl in labels], + gold_index=line["answer_index"], + ) + + +def _valid_filter(x): + """Skip rows where no options are present.""" + return any( + x.get(col) is not None and str(x[col]).strip() + for col in _OPTION_COLS + ) + + +TASKS_TABLE = [ + LightevalTaskConfig( + name=f"mmlu_prox:{language.value}:{suffix}", + prompt_function=prompt_fn, + hf_repo="li-lab/MMLU-ProX", + hf_subset=standardize_tag(language.value), + evaluation_splits=("test",), + few_shots_split="validation", + hf_filter=_valid_filter, + metrics=metrics, + ) + for language in _LANGUAGES + for suffix, prompt_fn, metrics in [ + ("cf", mmlu_prox_cf_prompt, _CF_METRICS), + ("mcf", mmlu_prox_mcf_prompt, _MCF_METRICS), + ] +] + +# Greedy variant: MCF-style prompt, generate up to 5 tokens, exact match +TASKS_TABLE += [ + LightevalTaskConfig( + name=f"mmlu_prox:{language.value}:mcf_em", + prompt_function=mmlu_prox_mcf_prompt, + hf_repo="li-lab/MMLU-ProX", + hf_subset=standardize_tag(language.value), + evaluation_splits=("test",), + few_shots_split="validation", + hf_filter=_valid_filter, + generation_size=5, + stop_sequence=["\n"], + metrics=[Metrics.exact_match], + ) + for language in _LANGUAGES +] diff --git a/src/lighteval/tasks/multilingual/tasks/wmt24pp.py b/src/lighteval/tasks/multilingual/tasks/wmt24pp.py new file mode 100644 index 000000000..27c5244bd --- /dev/null +++ b/src/lighteval/tasks/multilingual/tasks/wmt24pp.py @@ -0,0 +1,114 @@ +""" +name: +WMT24++ + +dataset: +google/wmt24pp + +abstract: +WMT24++ is an extension of the WMT2024 general translation shared task test +sets, covering 55+ language pairs with English as one side. Data includes +post-edited professional translations and quality flags. + +Tasks are 0-shot generation with chrF++ and BLEU-4 metrics. +The public selector is `wmt24pp:{lp}|0`, which expands to the internal +directional tasks `wmt24pp:en_to_x:{lp}` and `wmt24pp:x_to_en:{lp}`. + +NOTE: The dataset's "train" split is the evaluation data (WMT test sets are +published in HuggingFace with split="train"). +Few-shot evaluation reuses this same split, since no separate few-shot split is +published for WMT24++ in the dataset. + +NOTE: `lp` column values use regional codes (e.g., "de_DE", "fr_FR"). Tasks +are named using these codes directly. Check the dataset for all available lp +values: https://huggingface.co/datasets/google/wmt24pp + +tags: +multilingual, translation + +paper: +https://arxiv.org/abs/2412.06378 +""" + +from functools import partial + +from lighteval.tasks.multilingual.utils.translation import ( + TRANSLATION_METRICS, + EnglishCentricTranslationConfig, + build_english_centric_translation_tasks, +) +from lighteval.utils.language import Language + + +_TRANSLATION_METRICS = TRANSLATION_METRICS + +_LANGUAGE_CONFIGS = [ + ("de_DE", Language.GERMAN), + ("fr_FR", Language.FRENCH), + ("cs_CZ", Language.CZECH), + ("es_MX", Language.SPANISH), + ("ru_RU", Language.RUSSIAN), + ("zh_CN", Language.CHINESE), + ("ja_JP", Language.JAPANESE), + ("uk_UA", Language.UKRAINIAN), + ("hi_IN", Language.HINDI), + ("ar_EG", Language.ARABIC), + ("ko_KR", Language.KOREAN), + ("pt_BR", Language.PORTUGUESE), + ("tr_TR", Language.TURKISH), + ("pl_PL", Language.POLISH), + ("he_IL", Language.HEBREW), + ("nl_NL", Language.DUTCH), + ("it_IT", Language.ITALIAN), + ("sv_SE", Language.SWEDISH), + ("fi_FI", Language.FINNISH), + ("vi_VN", Language.VIETNAMESE), + ("bn_IN", Language.BENGALI), + ("th_TH", Language.THAI), + ("id_ID", Language.INDONESIAN), + ("hu_HU", Language.HUNGARIAN), +] + + +def _make_forward_adapter(_lp_code): + """Adapter for en→X direction: source=English, target=other language.""" + return lambda line: { + "source_text": line["source"], + "target_text": line["target"], + } + + +def _make_reverse_adapter(_lp_code): + """Adapter for X→en direction: source=other language, target=English.""" + return lambda line: { + "source_text": line["target"], + "target_text": line["source"], + } + + +def _make_filter(lp_code): + """Skip bad source sentences inside a specific WMT24++ pair config.""" + return partial(lambda _lp, x: not x.get("is_bad_source", False), lp_code) + + +TASKS_TABLE = build_english_centric_translation_tasks( + base_name="wmt24pp", + hf_repo="google/wmt24pp", + language_configs=[ + EnglishCentricTranslationConfig( + task_id=lp_code, + target_language=target_lang, + forward_hf_subset=f"en-{lp_code}", + reverse_hf_subset=f"en-{lp_code}", + forward_adapter=_make_forward_adapter(lp_code), + reverse_adapter=_make_reverse_adapter(lp_code), + hf_filter=_make_filter(lp_code), + hf_avail_splits=("train",), + evaluation_splits=("train",), + few_shots_split="train", + few_shots_select="random_sampling", + ) + for lp_code, target_lang in _LANGUAGE_CONFIGS + ], + metrics=_TRANSLATION_METRICS, +) diff --git a/src/lighteval/tasks/multilingual/utils/translation.py b/src/lighteval/tasks/multilingual/utils/translation.py new file mode 100644 index 000000000..be7eaf227 --- /dev/null +++ b/src/lighteval/tasks/multilingual/utils/translation.py @@ -0,0 +1,103 @@ +from collections.abc import Callable, Sequence +from dataclasses import dataclass + +from lighteval.metrics.metrics import Metrics +from lighteval.metrics.utils.metric_utils import Metric +from lighteval.tasks.lighteval_task import LightevalTaskConfig + +# Canonical metric set for English-centric translation tasks. +# chrf++ and BLEU are instant (sacrebleu); COMET-22 requires unbabel-comet (pip install lighteval[multilingual]). +try: + import comet # noqa: F401 + + _comet_available = True +except ImportError: + _comet_available = False + +TRANSLATION_METRICS: list[Metrics] = [ + Metrics.chrf_plus, + Metrics.bleu, + *([ Metrics.comet22] if _comet_available else []), +] +from lighteval.tasks.templates.translation import ( + TranslationInput, + get_translation_prompt_function, +) +from lighteval.tasks.templates.utils.formulation import CFFormulation +from lighteval.utils.language import Language + + +@dataclass(frozen=True) +class EnglishCentricTranslationConfig: + task_id: str + target_language: Language + forward_hf_subset: str + reverse_hf_subset: str + forward_adapter: Callable[[dict], TranslationInput | None] + reverse_adapter: Callable[[dict], TranslationInput | None] + hf_filter: Callable[[dict], bool] | None = None + hf_avail_splits: Sequence[str] = ("train", "validation", "test") + evaluation_splits: Sequence[str] = ("validation",) + few_shots_split: str | None = None + few_shots_select: str | None = None + + +def build_english_centric_translation_tasks( + *, + base_name: str, + hf_repo: str, + language_configs: Sequence[EnglishCentricTranslationConfig], + metrics: Sequence[Metric | Metrics], + generation_size: int = 300, + stop_sequence: Sequence[str] = ("\n",), + version: int = 0, +) -> list[LightevalTaskConfig]: + tasks = [] + + for config in language_configs: + tasks.append( + LightevalTaskConfig( + name=f"{base_name}:en_to_x:{config.task_id}", + prompt_function=get_translation_prompt_function( + source_language=Language.ENGLISH, + target_language=config.target_language, + adapter=config.forward_adapter, + formulation=CFFormulation(), + ), + hf_repo=hf_repo, + hf_subset=config.forward_hf_subset, + hf_filter=config.hf_filter, + hf_avail_splits=config.hf_avail_splits, + evaluation_splits=config.evaluation_splits, + few_shots_split=config.few_shots_split, + few_shots_select=config.few_shots_select, + generation_size=generation_size, + metrics=metrics, + stop_sequence=stop_sequence, + version=version, + ) + ) + tasks.append( + LightevalTaskConfig( + name=f"{base_name}:x_to_en:{config.task_id}", + prompt_function=get_translation_prompt_function( + source_language=config.target_language, + target_language=Language.ENGLISH, + adapter=config.reverse_adapter, + formulation=CFFormulation(), + ), + hf_repo=hf_repo, + hf_subset=config.reverse_hf_subset, + hf_filter=config.hf_filter, + hf_avail_splits=config.hf_avail_splits, + evaluation_splits=config.evaluation_splits, + few_shots_split=config.few_shots_split, + few_shots_select=config.few_shots_select, + generation_size=generation_size, + metrics=metrics, + stop_sequence=stop_sequence, + version=version, + ) + ) + + return tasks diff --git a/src/lighteval/tasks/tasks/mmlu_prox.py b/src/lighteval/tasks/tasks/mmlu_prox.py new file mode 100644 index 000000000..08c04703b --- /dev/null +++ b/src/lighteval/tasks/tasks/mmlu_prox.py @@ -0,0 +1,133 @@ +""" +name: +MMLU-ProX (English) + +dataset: +li-lab/MMLU-ProX + +abstract: +MMLU-ProX is a multilingual extension of MMLU-Pro. This file covers the +English ("en") subset only. It uses up to 10 answer options per question +(labels A–J), consistent with MMLU-Pro's harder format. + +For multilingual evaluation see: +lighteval/tasks/multilingual/tasks/mmlu_prox.py + +languages: +english + +tags: +general-knowledge, knowledge, multiple-choice + +paper: +https://huggingface.co/datasets/li-lab/MMLU-ProX +""" + +from string import ascii_uppercase + +from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric +from lighteval.metrics.metrics import Metrics +from lighteval.metrics.normalizations import LogProbCharNorm +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +_CF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), + Metrics.target_bits_per_byte, +] + +_MCF_METRICS = [ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), +] + +# option_0..option_9 column names +_OPTION_COLS = [f"option_{i}" for i in range(10)] + + +def _get_options(line): + """Return list of non-null options from option_0..option_9.""" + return [line[col] for col in _OPTION_COLS if line.get(col) is not None and str(line[col]).strip()] + + +def mmlu_prox_cf_prompt(line, task_name: str = None): + """CF: score each non-null option text directly.""" + options = _get_options(line) + if not options: + return None + choices = [c if c and c[0].isspace() else " " + c for c in options] + return Doc( + task_name=task_name, + query=f"Question: {line['question'].strip()}\nAnswer:", + choices=choices, + gold_index=line["answer_index"], + ) + + +def mmlu_prox_mcf_prompt(line, task_name: str = None): + """MCF: show labeled options A–J, score label tokens only.""" + options = _get_options(line) + if not options: + return None + labels = list(ascii_uppercase[: len(options)]) + query = f"Question: {line['question'].strip()}\n" + query += "".join([f"{lbl}. {opt}\n" for lbl, opt in zip(labels, options)]) + query += "Answer:" + return Doc( + task_name=task_name, + query=query, + choices=[" " + lbl for lbl in labels], + gold_index=line["answer_index"], + ) + + +def _valid_filter(x): + """Skip rows where no options are present.""" + return any( + x.get(col) is not None and str(x[col]).strip() + for col in _OPTION_COLS + ) + + +mmlu_prox_cf = LightevalTaskConfig( + name="mmlu_prox_eng:cf", + prompt_function=mmlu_prox_cf_prompt, + hf_repo="li-lab/MMLU-ProX", + hf_subset="en", + evaluation_splits=("test",), + few_shots_split="validation", + hf_filter=_valid_filter, + metrics=_CF_METRICS, + version=0, +) + +mmlu_prox_mcf = LightevalTaskConfig( + name="mmlu_prox_eng:mcf", + prompt_function=mmlu_prox_mcf_prompt, + hf_repo="li-lab/MMLU-ProX", + hf_subset="en", + evaluation_splits=("test",), + few_shots_split="validation", + hf_filter=_valid_filter, + metrics=_MCF_METRICS, + version=0, +) + +# Greedy variant: MCF-style prompt, generate up to 5 tokens, exact match +mmlu_prox_mcf_em = LightevalTaskConfig( + name="mmlu_prox_eng:mcf_em", + prompt_function=mmlu_prox_mcf_prompt, + hf_repo="li-lab/MMLU-ProX", + hf_subset="en", + evaluation_splits=("test",), + few_shots_split="validation", + hf_filter=_valid_filter, + generation_size=5, + stop_sequence=["\n"], + metrics=[Metrics.exact_match], + version=0, +) + +TASKS_TABLE = [mmlu_prox_cf, mmlu_prox_mcf, mmlu_prox_mcf_em] diff --git a/src/lighteval/tasks/templates/translation.py b/src/lighteval/tasks/templates/translation.py index 6b4c54a62..174e05a38 100644 --- a/src/lighteval/tasks/templates/translation.py +++ b/src/lighteval/tasks/templates/translation.py @@ -145,7 +145,7 @@ def translation_prompt( for text in as_list(input_data["target_text"]) ] - return continuation_prompt_fn( + doc = continuation_prompt_fn( { "instruction": input_data.get("instruction", ""), "context": context, @@ -154,5 +154,9 @@ def translation_prompt( }, task_name, ) + # Store the cleaned source text so COMET metric can access it without re-parsing the prompt. + if doc is not None: + doc.specific = (doc.specific or {}) | {"source_text": source_text} + return doc return translation_prompt diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py index 96cd8fc2c..99405c45b 100644 --- a/src/lighteval/tasks/templates/utils/translation_literals.py +++ b/src/lighteval/tasks/templates/utils/translation_literals.py @@ -1426,3 +1426,27 @@ def __getattribute__(self, name: str) -> str: Language.YUE_CHINESE: TranslationLiterals(language=Language.YUE_CHINESE), Language.ZULU: TranslationLiterals(language=Language.ZULU), } + +# Some multilingual benchmark additions rely on the generic MCQ prompt template, +# which requires at least localized question/answer labels. Where dedicated +# translations are still missing, fall back to English labels so the task can run. +for _language in [ + Language.AMHARIC, + Language.HAUSA, + Language.HEBREW, + Language.IGBO, + Language.KIRGHIZ, + Language.KOREAN, + Language.LITHUANIAN, + Language.MALAGASY, + Language.MALAY, + Language.NEPALI, + Language.NYANJA, + Language.PERSIAN, + Language.SHONA, + Language.SINHALA, + Language.SOMALI, + Language.YORUBA, +]: + TRANSLATION_LITERALS[_language].question_word = "question" + TRANSLATION_LITERALS[_language].answer = "answer" From b6f86713bdc9cc7ddfc9c8eb59f3a26868e4721b Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Sun, 19 Apr 2026 13:22:03 +0200 Subject: [PATCH 2/8] fix msmg --- src/lighteval/tasks/lighteval_task.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 3f45315ba..ddbccb1ad 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -134,7 +134,7 @@ def _load_hub_raw_dataset_files( return load_dataset(data_format, data_files=data_files) -def _load_mgsm_dataset(config_name: str) -> DatasetDict: +def _load_mgsm_dataset(config_name: str, local_files_only: bool = False) -> DatasetDict: import csv import importlib.util @@ -144,11 +144,13 @@ def _load_mgsm_dataset(config_name: str) -> DatasetDict: repo_id="juletxara/mgsm", repo_type="dataset", filename="exemplars.py", + local_files_only=local_files_only, ) tsv_path = hf_hub_download( repo_id="juletxara/mgsm", repo_type="dataset", filename=f"mgsm_{config_name}.tsv", + local_files_only=local_files_only, ) spec = importlib.util.spec_from_file_location("mgsm_exemplars", exemplars_path) @@ -769,8 +771,11 @@ def download_dataset_worker( if not (_is_script_err or _is_offline): raise - if _is_script_err and task.dataset_path == "juletxara/mgsm": - dataset = _load_mgsm_dataset(task.dataset_config_name) + if (_is_script_err or _is_offline) and task.dataset_path == "juletxara/mgsm": + dataset = _load_mgsm_dataset( + task.dataset_config_name, + local_files_only=_is_offline, + ) if task.dataset_filter is not None: dataset = dataset.filter(task.dataset_filter) return dataset # type: ignore From 2ae0dbc5d601e4ddf6fe232bd56a383d6cbff7a0 Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Wed, 22 Apr 2026 22:36:35 +0200 Subject: [PATCH 3/8] fix bugs --- .../models/endpoints/litellm_model.py | 46 +++++++------------ .../multilingual/tasks/mlmm_hellaswag.py | 1 + src/lighteval/tasks/tasks/arc.py | 2 + src/lighteval/tasks/tasks/hellaswag.py | 13 +++--- src/lighteval/tasks/tasks/jeopardy.py | 2 +- src/lighteval/tasks/tasks/mmlu.py | 4 +- src/lighteval/tasks/tasks/winogrande.py | 4 +- 7 files changed, 33 insertions(+), 39 deletions(-) diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index f0036854d..65885d032 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -638,12 +638,13 @@ def _left_truncate_tokens(self, text: str, max_tokens: int) -> str: keep_chars = max(1, int(len(text) * max_tokens / total)) return text[-keep_chars:] - def _score_single(self, prompt: str, ctx_token_count: int) -> tuple[float, bool]: + def _score_single(self, prompt: str, choice_len: int) -> tuple[float, bool]: """Score one (context + choice) string via a single echo-based completions call. Args: prompt: Full text (context + choice) to score. - ctx_token_count: Number of context tokens; continuation logprobs start here. + choice_len: Number of choice tokens; sliced from the right of the echoed logprobs, + before the one generated token. This approach is server-BOS-agnostic. Returns: (logprob_sum, is_greedy) @@ -680,17 +681,12 @@ def _score_single(self, prompt: str, ctx_token_count: int) -> tuple[float, bool] logprobs_obj = response.choices[0].logprobs token_logprobs = logprobs_obj.token_logprobs or [] - # Infer choice token count from response: len = ctx + choice + 1 (generated token) - # :-1 would be equivalent but fails when ctx_token_count is off by 1 due to BPE - # boundary merges (e.g. "Answer:" + " Yes" -> "Answer:▁Yes" in joint tokenization), - # producing an empty slice and logprob_sum = -inf, which poisons corpus BPB. - start = ctx_token_count - n_choice = len(token_logprobs) - 1 - ctx_token_count - if n_choice <= 0 and ctx_token_count > 0: - # ctx_token_count overestimates by 1 — shift start back by 1 - start -= 1 - n_choice = len(token_logprobs) - 1 - start - end = start + max(n_choice, 0) + # Slice choice tokens from the right: the response is + # [BOS?] [context tokens...] [choice tokens...] [1 generated token] + # Anchoring from the right is BOS-agnostic and robust to any server-side + # special-token prepending. + end = len(token_logprobs) - 1 # exclude the one generated token + start = max(0, end - choice_len) cont_logprobs = [lp for lp in token_logprobs[start:end] if lp is not None] logprob_sum = sum(cont_logprobs) if cont_logprobs else float("-inf") @@ -715,14 +711,14 @@ def _score_batch( all pairs in the batch. Args: - batch: list of (doc_idx, choice_idx, full_text, ctx_token_count) + batch: list of (doc_idx, choice_idx, full_text, choice_len) Returns: list of (doc_idx, choice_idx, (logprob_sum, is_greedy)) """ if len(batch) == 1: - di, ci, full_text, ctx_len = batch[0] - return [(di, ci, self._score_single(full_text, ctx_len))] + di, ci, full_text, choice_len = batch[0] + return [(di, ci, self._score_single(full_text, choice_len))] prompts = [full_text for _, _, full_text, _ in batch] @@ -761,7 +757,7 @@ def _score_batch( choices_by_idx = {c.index: c for c in response.choices} results = [] - for seq_idx, (di, ci, _, ctx_token_count) in enumerate(batch): + for seq_idx, (di, ci, _, choice_len) in enumerate(batch): choice = choices_by_idx.get(seq_idx) if choice is None: logger.warning( @@ -773,12 +769,8 @@ def _score_batch( logprobs_obj = choice.logprobs token_logprobs = logprobs_obj.token_logprobs or [] - start = ctx_token_count - n_choice = len(token_logprobs) - 1 - ctx_token_count - if n_choice <= 0 and ctx_token_count > 0: - start -= 1 - n_choice = len(token_logprobs) - 1 - start - end = start + max(n_choice, 0) + end = len(token_logprobs) - 1 # exclude the one generated token + start = max(0, end - choice_len) cont_logprobs = [lp for lp in token_logprobs[start:end] if lp is not None] logprob_sum = sum(cont_logprobs) if cont_logprobs else float("-inf") @@ -846,11 +838,8 @@ def _score_batch_work(batch_items: list[tuple]) -> list[tuple]: """Tokenize a batch then score all pairs in one API call.""" prepared = [] for di, ci, context, choice in batch_items: - # ctx_len = total - choice avoids the BPE boundary ±1 error when - # tokenizing context alone (e.g. "Answer:" + " Yes" merges into - # "Answer:▁Yes", making context_len overestimate by 1). - total_len = self._count_tokens(context + choice) choice_len = self._count_tokens(choice) + total_len = self._count_tokens(context + choice) prompt = context + choice # Left-truncate (OLMES-style): drop early few-shot examples when # the combined prompt exceeds the model's context window. @@ -862,8 +851,7 @@ def _score_batch_work(batch_items: list[tuple]) -> list[tuple]: f"Prompt too long ({total_len} tokens > {max_input} max); left-truncating." ) prompt = self._left_truncate_tokens(prompt, max_input) - total_len = max_input - prepared.append((di, ci, prompt, total_len - choice_len)) + prepared.append((di, ci, prompt, choice_len)) return self._score_batch(prepared) # Chunk work into batches — each batch becomes a single API call with a list diff --git a/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py b/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py index bbd631278..dba27632c 100644 --- a/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py +++ b/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py @@ -91,6 +91,7 @@ def _hellaswag_adapter(line): return { + "activity_label": line.get("activity_label", ""), "ctx_a": line["ctx_a"], "ctx_b": line["ctx_b"], "continuations": line["endings"], diff --git a/src/lighteval/tasks/tasks/arc.py b/src/lighteval/tasks/tasks/arc.py index 05ca9ae30..afad00779 100644 --- a/src/lighteval/tasks/tasks/arc.py +++ b/src/lighteval/tasks/tasks/arc.py @@ -74,6 +74,7 @@ def arc_mcf_prompt(line, task_name: str = None): evaluation_splits=["test"], few_shots_split=None, few_shots_select="random_sampling_from_train", + generation_size=-1, metrics=_CF_METRICS, stop_sequence=["\n"], version=0, @@ -88,6 +89,7 @@ def arc_mcf_prompt(line, task_name: str = None): evaluation_splits=["test"], few_shots_split=None, few_shots_select="random_sampling_from_train", + generation_size=-1, metrics=_CF_METRICS, stop_sequence=["\n"], version=0, diff --git a/src/lighteval/tasks/tasks/hellaswag.py b/src/lighteval/tasks/tasks/hellaswag.py index 8a6ea4db0..14d9462fc 100644 --- a/src/lighteval/tasks/tasks/hellaswag.py +++ b/src/lighteval/tasks/tasks/hellaswag.py @@ -50,10 +50,12 @@ def harness_preprocess(text): def hellaswag_mcf_prompt(line, task_name: str = None): """MCF variant: labeled options in prompt, score label tokens via logprobs.""" - query = "The following are multiple choice questions (with answers) about common sense.\n\n" - query += f"Question: {line['activity_label']}: {line['ctx_a']} {line['ctx_b'].capitalize()}\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["endings"])]) - query += "Answer:" + ctx = line["ctx_a"] + " " + line["ctx_b"].capitalize() + query = harness_preprocess(line["activity_label"] + ": " + ctx) + query += "".join( + [f"\n{key}. {harness_preprocess(choice)}" for key, choice in zip(ascii_uppercase, line["endings"])] + ) + query += "\nAnswer:" gold_ix = int(line["label"]) if line["label"] != "" else -1 return Doc( @@ -61,7 +63,6 @@ def hellaswag_mcf_prompt(line, task_name: str = None): query=query, choices=[" " + i for i in ascii_uppercase[: len(line["endings"])]], gold_index=gold_ix, - instruction="The following are multiple choice questions (with answers) about common sense.\n\n", ) def hellaswag_cf_prompt(line, task_name: str = None): @@ -110,7 +111,6 @@ def hellaswag_cf_prompt(line, task_name: str = None): stop_sequence=["\n"], version=0, ) - # CF variant: completion-style, logprob on full answer text + BPB on gold choice hellaswag_cf = LightevalTaskConfig( name="hellaswag:cf", @@ -121,6 +121,7 @@ def hellaswag_cf_prompt(line, task_name: str = None): evaluation_splits=["validation"], few_shots_split=None, few_shots_select=None, + generation_size=-1, metrics=_CF_METRICS, stop_sequence=["\n"], version=0, diff --git a/src/lighteval/tasks/tasks/jeopardy.py b/src/lighteval/tasks/tasks/jeopardy.py index de2dc0d41..67cf40360 100644 --- a/src/lighteval/tasks/tasks/jeopardy.py +++ b/src/lighteval/tasks/tasks/jeopardy.py @@ -94,7 +94,7 @@ def jeopardy_mc_mcf_prompt(line, task_name: str = None): return Doc( task_name=task_name, query=f"Category: {category}\nQuestion: {question}\n{options}\nAnswer:", - choices=labels, + choices=[" " + l for l in labels], gold_index=gold, ) diff --git a/src/lighteval/tasks/tasks/mmlu.py b/src/lighteval/tasks/tasks/mmlu.py index a1472daed..dc31502c5 100644 --- a/src/lighteval/tasks/tasks/mmlu.py +++ b/src/lighteval/tasks/tasks/mmlu.py @@ -208,9 +208,11 @@ def mmlu_redux_cf_prompt(line, task_name: str = None): else line["answer"] ) + query = f"The following are multiple choice questions about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}\nAnswer:" + return Doc( task_name=task_name, - query=f"The following are multiple choice questions about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}\nAnswer:", + query=query, choices=[" " + c for c in line["choices"]], gold_index=gold_ix, fewshot_sorting_class=line["choices"][gold_ix], diff --git a/src/lighteval/tasks/tasks/winogrande.py b/src/lighteval/tasks/tasks/winogrande.py index 79b70fbde..0e5f7b361 100644 --- a/src/lighteval/tasks/tasks/winogrande.py +++ b/src/lighteval/tasks/tasks/winogrande.py @@ -46,8 +46,8 @@ def winogrande_cf_prompt(line, task_name: str = None): end_of_target = end_of_target.strip() return Doc( task_name=task_name, - query=query, - choices=[f"{line['option1']} {end_of_target}", f"{line['option2']} {end_of_target}"], + query=query.rstrip(), + choices=[f" {line['option1']} {end_of_target}", f" {line['option2']} {end_of_target}"], gold_index=int(line["answer"]) - 1 if line["answer"] != "" else -1, ) From 7a2e52c6003472c7ccf1130d2d4461b54a2e0783 Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Fri, 1 May 2026 15:52:24 +0200 Subject: [PATCH 4/8] update ruler --- src/lighteval/metrics/metrics_sample.py | 58 ++++-- src/lighteval/tasks/tasks/ruler.py | 196 ++++++++++++-------- tests/unit/metrics/test_ruler_metrics.py | 79 ++++++++ tests/unit/tasks/test_ruler.py | 226 +++++++++++++++++++++++ 4 files changed, 467 insertions(+), 92 deletions(-) create mode 100644 tests/unit/metrics/test_ruler_metrics.py create mode 100644 tests/unit/tasks/test_ruler.py diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 4863d7239..08d932846 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -27,6 +27,7 @@ import inspect import logging import os +import re from abc import ABC, abstractmethod from typing import Callable, Literal, Union @@ -46,6 +47,7 @@ from lighteval.metrics.normalizations import ( LogProbNormalization, LogProbTokenNorm, + helm_normalizer, normalize_log_probs, remove_braces, remove_braces_and_strip, @@ -1562,14 +1564,11 @@ class RulerStringMatchAll(SampleLevelComputation): """ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: - pred = ( - model_response.final_text[0].strip().lower() - if model_response.final_text - else "" - ) + pred = model_response.final_text[0] if model_response.final_text else "" golds = doc.get_golds() if not golds: return 0.0 + pred = pred.strip().lower() return sum(g.lower() in pred for g in golds) / len(golds) @@ -1582,15 +1581,38 @@ class RulerStringMatchAny(SampleLevelComputation): """ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: - pred = ( - model_response.final_text[0].strip().lower() - if model_response.final_text - else "" - ) + pred = model_response.final_text[0] if model_response.final_text else "" golds = doc.get_golds() if not golds: return 0.0 - return float(any(g.lower() in pred for g in golds)) + return float(self._has_match(pred, golds)) + + @staticmethod + def _parse_output(output: str, prefix: str = "Answer:") -> str | None: + def _lstrip_prefix(text: str, sub: str) -> str: + return re.sub(f"^{re.escape(sub)}", "", text, flags=re.IGNORECASE) + + patterns = [ + re.compile(f"(?:{re.escape(prefix)})(.*)(?:\\n|$)", flags=re.IGNORECASE), + re.compile(r"(?:^)(.*)(?:\n|$)"), + ] + for pattern in patterns: + match = pattern.search(output) + if match is not None: + return _lstrip_prefix(match[1].strip(), prefix).strip() + return None + + @staticmethod + def _normalized_substring_match(pred: str, gold: str) -> bool: + return helm_normalizer(gold) in helm_normalizer(pred) + + @classmethod + def _has_match(cls, pred: str, golds: list[str]) -> bool: + parsed_pred = cls._parse_output(pred) + candidates = [pred] + if parsed_pred is not None: + candidates.append(parsed_pred) + return any(cls._normalized_substring_match(candidate, gold) for candidate in candidates for gold in golds) class RulerStringMatch(SampleLevelComputation): @@ -1601,15 +1623,17 @@ class RulerStringMatch(SampleLevelComputation): """ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float: - pred = ( - model_response.final_text[0].strip().lower() - if model_response.final_text - else "" - ) + pred = model_response.final_text[0] if model_response.final_text else "" + # Prepend gen_prefix (answer prefix) so scoring matches olmo's behaviour: + # olmo does: output = doc["prepend_text"] + output + gen_prefix = (doc.specific or {}).get("gen_prefix", "") + if gen_prefix: + pred = gen_prefix + pred golds = doc.get_golds() if not golds: return 0.0 task_name = doc.task_name or "" if ":qa_" in task_name: - return float(any(g.lower() in pred for g in golds)) + return float(RulerStringMatchAny._has_match(pred, golds)) + pred = pred.strip().lower() return sum(g.lower() in pred for g in golds) / len(golds) diff --git a/src/lighteval/tasks/tasks/ruler.py b/src/lighteval/tasks/tasks/ruler.py index 8f141ab12..7b6648861 100644 --- a/src/lighteval/tasks/tasks/ruler.py +++ b/src/lighteval/tasks/tasks/ruler.py @@ -90,7 +90,8 @@ DEFAULT_LENGTHS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] -NUM_SAMPLES = 500 +# NUM_SAMPLES = 500 +NUM_SAMPLES = 100 RANDOM_SEED = 42 # --------------------------------------------------------------------------- @@ -147,7 +148,13 @@ def _get_cache_dir(tokenizer_path: str) -> Path: # --------------------------------------------------------------------------- NIAH_NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}." -NIAH_TEMPLATE = ( +NIAH_TEMPLATE_SINGLE = ( + "A special magic {type_needle_v} is hidden within the following text. " + "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" + "{context}\n" + "What is the special magic {type_needle_v} for {query} mentioned in the provided text?" +) +NIAH_TEMPLATE_MULTI = ( "Some special magic {type_needle_v} are hidden within the following text. " "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n" "{context}\n" @@ -321,7 +328,7 @@ def _niah_generate_samples( type_needle_v: str, template: str, num_samples: int = NUM_SAMPLES, - tokens_to_generate: int = 128, + tokens_to_generate: int = 50, num_needle_v: int = 1, num_needle_k: int = 1, num_needle_q: int = 1, @@ -361,7 +368,7 @@ def _gen_prefix(tnv, nq, nv, query): random_seed=random_seed, ) gen_prefix = _gen_prefix(tnv_base, num_needle_q, num_needle_v, query) - prompt = input_text + " " + gen_prefix + prompt = _build_runtime_prompt(input_text, gen_prefix) total_tokens = len(tokenizer(prompt).input_ids) if total_tokens + tokens_to_generate > budget: num_haystack -= incremental @@ -400,7 +407,7 @@ def _gen_prefix(tnv, nq, nv, query): random_seed=sample_seed, ) gen_prefix = _gen_prefix(tnv_base, num_needle_q, num_needle_v, query) - prompt = input_text + " " + gen_prefix + prompt = _build_runtime_prompt(input_text, gen_prefix) length = len(tokenizer(prompt).input_ids) + tokens_to_generate assert length <= budget break @@ -434,7 +441,7 @@ def _gen_prefix(tnv, nq, nv, query): # --------------------------------------------------------------------------- VT_CONFIG = { - "tokens_to_generate": 30, + "tokens_to_generate": 50, "template": ( "Memorize and track the chain(s) of variable assignment hidden in the " "following text.\n\n{context}\n" @@ -446,6 +453,7 @@ def _gen_prefix(tnv, nq, nv, query): ), } VT_TEMPLATE = VT_CONFIG["template"] + VT_CONFIG["answer_prefix"] +VT_ANSWER_PREFIX_BASE = " Answer: According to the chain(s) of variable assignment" def _vt_generate_chains( @@ -506,6 +514,47 @@ def _vt_randomize_icl(icl_example: str) -> str: return icl_example +def _split_prompt_on_answer_prefix( + prompt: str, answer_prefix: str, *, strip_input: bool = False +) -> tuple[str, str]: + prefix_index = prompt.rfind(answer_prefix) + if prefix_index == -1: + raise ValueError("Answer prefix not found in prompt") + + input_text = prompt[:prefix_index] + if strip_input: + input_text = input_text.strip() + gen_prefix = prompt[prefix_index:].strip() + return input_text, gen_prefix + + +def _runtime_prompt_budget_length( + tokenizer, input_text: str, gen_prefix: str, tokens_to_generate: int +) -> int: + prompt = _build_runtime_prompt(input_text, gen_prefix) + return len(tokenizer(prompt).input_ids) + tokens_to_generate + + +def _vt_build_cached_sample(input_text: str, icl_prompt: str) -> tuple[str, str]: + cutoff = input_text.index(VT_CONFIG["template"][:20]) + full_prompt = input_text[:cutoff] + icl_prompt + "\n\n" + input_text[cutoff:] + return _split_prompt_on_answer_prefix(full_prompt, VT_ANSWER_PREFIX_BASE) + + +def _cwe_build_cached_sample(input_example: str, input_text: str) -> tuple[str, str]: + input_text_clean, gen_prefix = _split_prompt_on_answer_prefix( + input_text, CWE_CONFIG["answer_prefix"] + ) + full_input = (input_example + "\n" + input_text_clean).strip() + return full_input, gen_prefix + + +def _fwe_build_cached_sample(input_text: str) -> tuple[str, str]: + return _split_prompt_on_answer_prefix( + input_text, FWE_CONFIG["answer_prefix"], strip_input=True + ) + + def _vt_generate_samples( tokenizer, max_seq_length: int, @@ -513,7 +562,7 @@ def _vt_generate_samples( incremental: int = 10, num_chains: int = 1, num_hops: int = 4, - tokens_to_generate: int = 30, + tokens_to_generate: int = VT_CONFIG["tokens_to_generate"], ) -> list[dict]: budget = max_seq_length @@ -542,15 +591,16 @@ def _vt_generate_samples( icl_str = ( icl_text_raw + " " + icl_out ) # full ICL text: question + gen_prefix + answers - example_tokens = len(tokenizer(icl_str + "\n\n").input_ids) - # --- Step 2: Find num_noises for main examples --- num_noises = incremental total_tokens = 0 - while total_tokens + tokens_to_generate + example_tokens < budget: + while total_tokens + tokens_to_generate < budget: input_text, _ = _vt_generate_input_output(num_noises, num_chains, num_hops) - total_tokens = len(tokenizer(input_text).input_ids) - if total_tokens + tokens_to_generate + example_tokens > budget: + cached_input, cached_gen_prefix = _vt_build_cached_sample(input_text, icl_str) + total_tokens = len( + tokenizer(_build_runtime_prompt(cached_input, cached_gen_prefix)).input_ids + ) + if total_tokens + tokens_to_generate > budget: num_noises -= incremental break num_noises += incremental @@ -571,10 +621,11 @@ def _vt_generate_samples( input_text, answer = _vt_generate_input_output( used_noises, num_chains, num_hops ) - length = ( - len(tokenizer(input_text).input_ids) - + tokens_to_generate - + example_tokens + cached_input, cached_gen_prefix = _vt_build_cached_sample( + input_text, _vt_randomize_icl(icl_str) + ) + length = _runtime_prompt_budget_length( + tokenizer, cached_input, cached_gen_prefix, tokens_to_generate ) assert length <= budget break @@ -587,30 +638,14 @@ def _vt_generate_samples( if answer is None: continue - # Insert ICL example between any model template prefix and the task template - cutoff = input_text.index(VT_CONFIG["template"][:20]) - input_text = ( - input_text[:cutoff] - + _vt_randomize_icl(icl_str) - + "\n\n" - + input_text[cutoff:] - ) - - # Split off gen_prefix (the answer_prefix at the end of VT_TEMPLATE) - gen_prefix_index = input_text.rfind( - " Answer: According to the chain(s) of variable assignment" - ) - gen_prefix = input_text[gen_prefix_index:].strip() - input_text = input_text[:gen_prefix_index] - write_jsons.append( { "index": index, - "input": input_text, + "input": cached_input, "outputs": answer, "length": length, "max_length": max_seq_length, - "gen_prefix": gen_prefix, + "gen_prefix": cached_gen_prefix, } ) @@ -624,7 +659,7 @@ def _vt_generate_samples( # --------------------------------------------------------------------------- CWE_CONFIG = { - "tokens_to_generate": 120, + "tokens_to_generate": 100, "template": ( "Below is a numbered list of words. In these words, some appear more often than others. " "Memorize the ones that appear most often.\n{context}\n" @@ -692,7 +727,7 @@ def _cwe_generate_samples( max_seq_length: int, num_samples: int = NUM_SAMPLES, incremental: int = 10, - tokens_to_generate: int = 120, + tokens_to_generate: int = CWE_CONFIG["tokens_to_generate"], ) -> list[dict]: words = _get_cwe_words() budget = max_seq_length @@ -703,14 +738,9 @@ def _cwe_generate_samples( input_example, input_text, answer = _cwe_generate_input_output( num_words, max_seq_length, words ) + full_input, gen_prefix = _cwe_build_cached_sample(input_example, input_text) total_tokens = len( - tokenizer( - input_example - + "\n" - + input_text - + " " - + " ".join(f"{i+1}. {w}" for i, w in enumerate(answer)) - ).input_ids + tokenizer(_build_runtime_prompt(full_input, gen_prefix)).input_ids ) if total_tokens + tokens_to_generate > budget: num_words -= incremental @@ -737,7 +767,12 @@ def _cwe_generate_samples( input_example, input_text, answer = _cwe_generate_input_output( used_words, max_seq_length, words ) - length = len(tokenizer(input_text).input_ids) + tokens_to_generate + full_input, gen_prefix = _cwe_build_cached_sample( + input_example, input_text + ) + length = _runtime_prompt_budget_length( + tokenizer, full_input, gen_prefix, tokens_to_generate + ) assert length <= budget break except Exception: @@ -749,13 +784,10 @@ def _cwe_generate_samples( if answer is None: continue - gen_prefix_idx = input_text.rfind(CWE_CONFIG["answer_prefix"]) - gen_prefix = input_text[gen_prefix_idx:].strip() - input_text = input_text[:gen_prefix_idx] write_jsons.append( { "index": index, - "input": input_text.strip(), + "input": full_input, "outputs": answer, "length": length, "max_length": max_seq_length, @@ -776,7 +808,7 @@ def _cwe_generate_samples( "tokens_to_generate": 50, "template": ( "Read the following coded text and track the frequency of each coded word. " - "Find the three most frequently appeared coded words. {context}\n" + "Find the three most frequently appeared coded words.\n{context}\n" "Question: Do not provide any explanation. Please ignore the dots '....'. " "What are the three most frequently appeared words in the above coded text?" ), @@ -875,14 +907,12 @@ def _fwe_generate_samples( alpha=alpha, sample_seed=sample_seed, ) - length = len(tokenizer(input_text).input_ids) + tokens_to_generate + input_text_clean, gen_prefix = _fwe_build_cached_sample(input_text) + length = _runtime_prompt_budget_length( + tokenizer, input_text_clean, gen_prefix, tokens_to_generate + ) assert length <= budget - # Strip answer prefix from input - ans_prefix_idx = input_text.rfind(FWE_CONFIG["answer_prefix"]) - gen_prefix = input_text[ans_prefix_idx:].strip() - input_text_clean = input_text[:ans_prefix_idx].strip() - write_jsons.append( { "index": index, @@ -904,7 +934,7 @@ def _fwe_generate_samples( # --------------------------------------------------------------------------- QA_CONFIG = { - "tokens_to_generate": 32, + "tokens_to_generate": 50, "template": ( "Answer the question based on the given documents. " "Only give me the answer and do not output any other words.\n\n" @@ -1033,7 +1063,7 @@ def _qa_generate_samples( qas: list[dict], max_seq_length: int, num_samples: int = NUM_SAMPLES, - tokens_to_generate: int = 32, + tokens_to_generate: int = QA_CONFIG["tokens_to_generate"], incremental: int = 10, ) -> list[dict]: budget = max_seq_length @@ -1043,7 +1073,7 @@ def _qa_generate_samples( total_tokens = 0 while total_tokens + tokens_to_generate < budget: input_text, _ = _qa_generate_input_output(0, num_docs, qas=qas, docs=docs) - prompt = input_text + " " + gen_prefix + prompt = _build_runtime_prompt(input_text, gen_prefix) total_tokens = len(tokenizer(prompt).input_ids) if total_tokens + tokens_to_generate > budget: num_docs -= incremental @@ -1069,7 +1099,7 @@ def _qa_generate_samples( input_text, answer = _qa_generate_input_output( index, used_docs, qas=qas, docs=docs ) - prompt = input_text + " " + gen_prefix + prompt = _build_runtime_prompt(input_text, gen_prefix) length = len(tokenizer(prompt).input_ids) + tokens_to_generate assert length <= budget break @@ -1110,7 +1140,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_haystack="repeat", type_needle_k="words", type_needle_v="numbers", - template=NIAH_TEMPLATE, + template=NIAH_TEMPLATE_SINGLE, ) elif subset == "niah_single_2": _ensure_nltk() @@ -1121,7 +1151,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_haystack="essay", type_needle_k="words", type_needle_v="numbers", - template=NIAH_TEMPLATE, + template=NIAH_TEMPLATE_SINGLE, ) elif subset == "niah_single_3": _ensure_nltk() @@ -1132,7 +1162,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_haystack="essay", type_needle_k="words", type_needle_v="uuids", - template=NIAH_TEMPLATE, + template=NIAH_TEMPLATE_SINGLE, ) elif subset == "niah_multikey_1": _ensure_nltk() @@ -1144,7 +1174,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_needle_k="words", type_needle_v="numbers", num_needle_k=4, - template=NIAH_TEMPLATE, + template=NIAH_TEMPLATE_SINGLE, ) elif subset == "niah_multikey_2": return _niah_generate_samples( @@ -1155,7 +1185,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_needle_k="words", type_needle_v="numbers", num_needle_k=4, - template=NIAH_TEMPLATE, + template=NIAH_TEMPLATE_SINGLE, ) elif subset == "niah_multikey_3": return _niah_generate_samples( @@ -1166,7 +1196,8 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_needle_k="uuids", type_needle_v="uuids", num_needle_k=4, - template=NIAH_TEMPLATE, + template=NIAH_TEMPLATE_SINGLE, + tokens_to_generate=100, ) elif subset == "niah_multiquery": _ensure_nltk() @@ -1178,7 +1209,8 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_needle_k="words", type_needle_v="numbers", num_needle_q=4, - template=NIAH_TEMPLATE, + template=NIAH_TEMPLATE_MULTI, + tokens_to_generate=100, ) elif subset == "niah_multivalue": _ensure_nltk() @@ -1190,7 +1222,8 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_needle_k="words", type_needle_v="numbers", num_needle_v=4, - template=NIAH_TEMPLATE, + template=NIAH_TEMPLATE_MULTI, + tokens_to_generate=300 if length == 4096 else 250, ) elif subset == "vt": return _vt_generate_samples(tokenizer, max_seq_length=length) @@ -1272,19 +1305,27 @@ def ensure_ruler_cache( # --------------------------------------------------------------------------- +def _build_runtime_prompt(input_text: str, gen_prefix: str = "") -> str: + """Build the exact prompt text used at evaluation time. + + RULER caches `input` and `gen_prefix` separately, then `ruler_prompt` + joins them with a newline before sending the prompt to the model. + """ + if gen_prefix: + return input_text + "\n" + gen_prefix + return input_text + + def ruler_prompt(line: dict, task_name: str = None) -> Doc: outputs = line["outputs"] - # Append gen_prefix so the model receives the completion-eliciting prefix - # (matches lm-eval-harness YAML: doc_to_text="{{input}}" + gen_prefix="{{gen_prefix}}") - query = line["input"] - gp = line.get("gen_prefix", "") - if gp: - query = query + " " + gp + gen_prefix = line.get("gen_prefix", "") + query = _build_runtime_prompt(line["input"], gen_prefix) return Doc( query=query, choices=outputs, gold_index=list(range(len(outputs))), task_name=task_name, + specific={"gen_prefix": gen_prefix} if gen_prefix else None, ) @@ -1315,11 +1356,16 @@ def get_ruler_tasks( # subset never triggers generation of all subsets. cache_path = cache_base / str(length) / subset metric = [Metrics.ruler_match] + _niah_gen_sizes = { + "niah_multikey_3": 100, + "niah_multiquery": 100, + "niah_multivalue": 300 if length == 4096 else 250, + } gen_size = ( - 128 + _niah_gen_sizes.get(subset, 50) if "niah" in subset else ( - 30 if subset == "vt" else 120 if subset == "cwe" else 50 + 50 if subset == "vt" else 100 if subset == "cwe" else 50 ) # fwe and qa ) tasks.append( diff --git a/tests/unit/metrics/test_ruler_metrics.py b/tests/unit/metrics/test_ruler_metrics.py new file mode 100644 index 000000000..0817057dd --- /dev/null +++ b/tests/unit/metrics/test_ruler_metrics.py @@ -0,0 +1,79 @@ +# MIT License +# +# Copyright (c) 2024 The HuggingFace Team +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import pytest + +from lighteval.metrics.metrics_sample import RulerStringMatch, RulerStringMatchAny +from lighteval.models.model_output import ModelResponse +from lighteval.tasks.requests import Doc + + +def make_doc(golds: list[str], task_name: str) -> Doc: + return Doc(query="", choices=golds, gold_index=list(range(len(golds))), task_name=task_name) + + +@pytest.mark.parametrize( + ("prediction", "golds"), + [ + ("eiffel tower", ["The Eiffel Tower"]), + ("Answer: eiffel tower", ["The Eiffel Tower"]), + ("ANSWER: the eiffel tower", ["eiffel tower"]), + ("the answer is london", ["Paris", "London"]), + ], +) +def test_ruler_qa_match_uses_normalized_substring_semantics(prediction, golds): + metric = RulerStringMatch() + doc = make_doc(golds, "ruler_512:qa_2|0") + + assert metric.compute(doc, ModelResponse(text=[prediction])) == 1.0 + + +def test_ruler_qa_parses_answer_prefix(): + assert RulerStringMatchAny._parse_output("Answer: Paris\nExtra") == "Paris" + + +def test_ruler_qa_returns_zero_when_no_gold_matches(): + metric = RulerStringMatch() + doc = make_doc(["Paris"], "ruler_512:qa_1|0") + + assert metric.compute(doc, ModelResponse(text=["London"])) == 0.0 + + +def test_ruler_non_qa_full_recall_scores_one(): + metric = RulerStringMatch() + doc = make_doc(["alpha", "beta"], "ruler_512:cwe|0") + + assert metric.compute(doc, ModelResponse(text=["alpha beta gamma"])) == 1.0 + + +def test_ruler_non_qa_partial_recall_scores_fraction(): + metric = RulerStringMatch() + doc = make_doc(["alpha", "beta", "gamma"], "ruler_512:vt|0") + + assert metric.compute(doc, ModelResponse(text=["alpha gamma"])) == pytest.approx(2 / 3) + + +def test_ruler_non_qa_empty_prediction_scores_zero(): + metric = RulerStringMatch() + doc = make_doc(["alpha", "beta"], "ruler_512:fwe|0") + + assert metric.compute(doc, ModelResponse(text=[""])) == 0.0 diff --git a/tests/unit/tasks/test_ruler.py b/tests/unit/tasks/test_ruler.py new file mode 100644 index 000000000..8f65bad4c --- /dev/null +++ b/tests/unit/tasks/test_ruler.py @@ -0,0 +1,226 @@ +# MIT License +# +# Copyright (c) 2024 The HuggingFace Team +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from lighteval.tasks.tasks import ruler + + +def test_ruler_prompt_appends_generation_prefix(): + doc = ruler.ruler_prompt( + { + "input": "Question text", + "gen_prefix": "Answer:", + "outputs": ["Paris"], + }, + "ruler_512:qa_1|0", + ) + + assert doc.query == "Question text\nAnswer:" + assert doc.choices == ["Paris"] + assert doc.gold_index == [0] + assert doc.task_name == "ruler_512:qa_1|0" + + +def test_runtime_prompt_helper_matches_ruler_prompt(): + line = { + "input": "Question text", + "gen_prefix": "Answer:", + "outputs": ["Paris"], + } + doc = ruler.ruler_prompt(line, "ruler_512:qa_1|0") + + assert ruler._build_runtime_prompt(line["input"], line["gen_prefix"]) == doc.query + + +def test_ruler_generation_sizes_match_reference_defaults(): + tasks = ruler.get_ruler_tasks( + "dummy-tokenizer", + lengths=[512], + subsets=["niah_single_1", "vt", "cwe", "fwe", "qa_2"], + ) + task_map = {task.name: task.generation_size for task in tasks} + + assert task_map["ruler_512:niah_single_1"] == 50 + assert task_map["ruler_512:vt"] == 50 + assert task_map["ruler_512:cwe"] == 100 + assert task_map["ruler_512:fwe"] == 50 + assert task_map["ruler_512:qa_2"] == 50 + + +def test_ruler_qa_prompt_matches_reference_template(): + prompt, answers = ruler._qa_generate_input_output( + 0, + 1, + qas=[{"query": "Where is the Eiffel Tower?", "outputs": ["Paris"], "context": [0]}], + docs=["Paris is the capital of France."], + ) + + assert prompt == ( + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "The following are given documents.\n\n" + "Document 1:\nParis is the capital of France.\n\n" + "Answer the question based on the given documents. " + "Only give me the answer and do not output any other words.\n\n" + "Question: Where is the Eiffel Tower?" + ) + assert answers == ["Paris"] + + +def test_ruler_uses_singular_niah_template_for_single_tasks(monkeypatch): + captured = {} + + def fake_niah_generate_samples(*args, **kwargs): + captured["template"] = kwargs["template"] + return [] + + monkeypatch.setattr(ruler, "_niah_generate_samples", fake_niah_generate_samples) + + ruler._generate_subset("niah_single_1", 128, tokenizer=None) + + assert captured["template"] == ruler.NIAH_TEMPLATE_SINGLE + + +def test_ruler_uses_plural_niah_template_for_multi_tasks(monkeypatch): + captured = {} + + def fake_niah_generate_samples(*args, **kwargs): + captured["template"] = kwargs["template"] + return [] + + monkeypatch.setattr(ruler, "_niah_generate_samples", fake_niah_generate_samples) + monkeypatch.setattr(ruler, "_ensure_nltk", lambda: None) + monkeypatch.setattr(ruler, "_get_essay_haystack", lambda: ["one", "two"]) + + ruler._generate_subset("niah_multivalue", 128, tokenizer=None) + + assert captured["template"] == ruler.NIAH_TEMPLATE_MULTI + + +def test_ruler_budget_helper_uses_newline_before_generation_prefix(): + prompt = ruler._build_runtime_prompt("Question text", "Answer:") + + assert prompt == "Question text\nAnswer:" + + +class _FakeTokenizer: + def __call__(self, text): + return type("Tokens", (), {"input_ids": list(range(len(text)))})() + + +def test_qa_length_matches_runtime_prompt_shape(): + input_text = "Question text" + gen_prefix = "Answer:" + tokens_to_generate = ruler.QA_CONFIG["tokens_to_generate"] + + expected = len(_FakeTokenizer()(ruler._build_runtime_prompt(input_text, gen_prefix)).input_ids) + tokens_to_generate + + assert expected == len("Question text\nAnswer:") + tokens_to_generate + + +def test_cwe_full_prompt_length_includes_icl_example(): + full_input = "Example block\nTask text" + gen_prefix = "Answer: The top 10 words that appear most often in the list are:" + tokens_to_generate = ruler.CWE_CONFIG["tokens_to_generate"] + + expected = len(_FakeTokenizer()(ruler._build_runtime_prompt(full_input, gen_prefix)).input_ids) + tokens_to_generate + + assert expected == len(full_input + "\n" + gen_prefix) + tokens_to_generate + + +def test_vt_cached_sample_uses_final_runtime_prompt_shape(): + task_prompt = ruler.VT_TEMPLATE.format( + context="VAR A = 12345", + query="12345", + num_v=5, + ) + cached_input, gen_prefix = ruler._vt_build_cached_sample(task_prompt, "ICL EXAMPLE") + length = ruler._runtime_prompt_budget_length( + _FakeTokenizer(), + cached_input, + gen_prefix, + ruler.VT_CONFIG["tokens_to_generate"], + ) + + assert cached_input == ( + "ICL EXAMPLE\n\n" + "Memorize and track the chain(s) of variable assignment hidden in the " + "following text.\n\nVAR A = 12345\n" + "Question: Find all variables that are assigned the value 12345 in the text above." + ) + assert gen_prefix == ( + "Answer: According to the chain(s) of variable assignment in the text above, " + "5 variables are assigned the value 12345, they are:" + ) + assert length == len(cached_input + "\n" + gen_prefix) + ruler.VT_CONFIG["tokens_to_generate"] + + +def test_fwe_length_matches_runtime_prompt_shape(): + full_prompt = ( + "Read the following coded text and track the frequency of each coded word. " + "Find the three most frequently appeared coded words.\nalpha beta gamma\n" + "Question: Do not provide any explanation. Please ignore the dots '....'. " + "What are the three most frequently appeared words in the above coded text?" + " Answer: According to the coded text above, the three most frequently appeared words are:" + ) + input_text, gen_prefix = ruler._fwe_build_cached_sample(full_prompt) + length = ruler._runtime_prompt_budget_length( + _FakeTokenizer(), + input_text, + gen_prefix, + ruler.FWE_CONFIG["tokens_to_generate"], + ) + + assert length == len(input_text + "\n" + gen_prefix) + ruler.FWE_CONFIG["tokens_to_generate"] + + +def test_non_qa_cached_lengths_match_retokenized_runtime_prompts(): + samples = [ + { + "input": "ICL EXAMPLE\n\nTask block", + "gen_prefix": "Answer: According to the chain(s) of variable assignment in the text above, 5 variables are assigned the value 12345, they are:", + "tokens_to_generate": ruler.VT_CONFIG["tokens_to_generate"], + "max_length": 500, + }, + { + "input": "Example block\nTask text", + "gen_prefix": "Answer: The top 10 words that appear most often in the list are:", + "tokens_to_generate": ruler.CWE_CONFIG["tokens_to_generate"], + "max_length": 500, + }, + { + "input": "Task text", + "gen_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:", + "tokens_to_generate": ruler.FWE_CONFIG["tokens_to_generate"], + "max_length": 500, + }, + ] + + for sample in samples: + length = ruler._runtime_prompt_budget_length( + _FakeTokenizer(), + sample["input"], + sample["gen_prefix"], + sample["tokens_to_generate"], + ) + + assert length == len(sample["input"] + "\n" + sample["gen_prefix"]) + sample["tokens_to_generate"] + assert length <= sample["max_length"] From c627beb22a49833b5d23abe0e0b21399bcd6393b Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Mon, 11 May 2026 15:30:35 +0200 Subject: [PATCH 5/8] fix stuff --- .../models/endpoints/litellm_model.py | 24 ++-- src/lighteval/tasks/lighteval_task.py | 22 ++-- src/lighteval/tasks/tasks/arc.py | 2 +- src/lighteval/tasks/tasks/coqa.py | 64 ++++++----- src/lighteval/tasks/tasks/drop_qa.py | 4 +- src/lighteval/tasks/tasks/hellaswag.py | 9 +- src/lighteval/tasks/tasks/humaneval.py | 16 ++- src/lighteval/tasks/tasks/math.py | 3 +- src/lighteval/tasks/tasks/mbpp.py | 20 ++-- src/lighteval/tasks/tasks/med.py | 10 +- src/lighteval/tasks/tasks/mmlu.py | 2 +- src/lighteval/tasks/tasks/mt_mbpp.py | 13 ++- src/lighteval/tasks/tasks/winogrande.py | 108 +++++++++++++----- 13 files changed, 196 insertions(+), 101 deletions(-) diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 65885d032..3cf03e31f 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -827,19 +827,29 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: [(0.0, False)] * len(doc.choices) for doc in doc_list ] - # Flat work list: (doc_idx, choice_idx, context, choice) + # Pre-compute token lengths once per context (not 4× per context via choices) + # to avoid redundant BPE tokenization in the thread pool. + context_lens: list[int] = [self._count_tokens(ctx) for ctx in context_list] + + # Pre-compute choice lengths once per unique choice text. + _unique_choices: dict[str, int] = {} + for doc in doc_list: + for choice in doc.choices: + if choice not in _unique_choices: + _unique_choices[choice] = self._count_tokens(choice) + + # Flat work list: (doc_idx, choice_idx, context, choice, context_len, choice_len) work = [ - (di, ci, context_list[di], choice) + (di, ci, context_list[di], choice, context_lens[di], _unique_choices[choice]) for di, doc in enumerate(doc_list) for ci, choice in enumerate(doc.choices) ] def _score_batch_work(batch_items: list[tuple]) -> list[tuple]: - """Tokenize a batch then score all pairs in one API call.""" + """Score a batch of (context+choice) pairs in a single batched echo call.""" prepared = [] - for di, ci, context, choice in batch_items: - choice_len = self._count_tokens(choice) - total_len = self._count_tokens(context + choice) + for di, ci, context, choice, ctx_len, ch_len in batch_items: + total_len = ctx_len + ch_len prompt = context + choice # Left-truncate (OLMES-style): drop early few-shot examples when # the combined prompt exceeds the model's context window. @@ -851,7 +861,7 @@ def _score_batch_work(batch_items: list[tuple]) -> list[tuple]: f"Prompt too long ({total_len} tokens > {max_input} max); left-truncating." ) prompt = self._left_truncate_tokens(prompt, max_input) - prepared.append((di, ci, prompt, choice_len)) + prepared.append((di, ci, prompt, ch_len)) return self._score_batch(prepared) # Chunk work into batches — each batch becomes a single API call with a list diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index ddbccb1ad..370e3a249 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -577,14 +577,20 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]: if doc is None or doc == []: continue - doc.id = str(ix) - - # Transfer task-level generation parameters to the document - doc.generation_grammar = self.generation_grammar - doc.generation_size = self.generation_size - doc.stop_sequences = self.stop_sequence - - docs.append(doc) + # Support multi-doc returns (list of Docs per row, e.g. multi-turn tasks) + if isinstance(doc, list): + for sub_ix, sub_doc in enumerate(doc): + sub_doc.id = f"{ix}_{sub_ix}" + sub_doc.generation_grammar = self.generation_grammar + sub_doc.generation_size = self.generation_size + sub_doc.stop_sequences = self.stop_sequence + docs.append(sub_doc) + else: + doc.id = str(ix) + doc.generation_grammar = self.generation_grammar + doc.generation_size = self.generation_size + doc.stop_sequences = self.stop_sequence + docs.append(doc) return docs diff --git a/src/lighteval/tasks/tasks/arc.py b/src/lighteval/tasks/tasks/arc.py index afad00779..13bd1ae1a 100644 --- a/src/lighteval/tasks/tasks/arc.py +++ b/src/lighteval/tasks/tasks/arc.py @@ -52,7 +52,7 @@ def arc_prompt(line, task_name: str = None): def arc_mcf_prompt(line, task_name: str = None): query = f"Question: {line['question']}\n" - query += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["choices"]["text"])]) + query += "".join([f" {key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["choices"]["text"])]) query += "Answer:" gold_ix = line["choices"]["label"].index(line["answerKey"]) diff --git a/src/lighteval/tasks/tasks/coqa.py b/src/lighteval/tasks/tasks/coqa.py index da59985ed..f1bd3bad6 100644 --- a/src/lighteval/tasks/tasks/coqa.py +++ b/src/lighteval/tasks/tasks/coqa.py @@ -43,45 +43,55 @@ def _build_query(story: str, questions: list, answers: list, turn_idx: int) -> s def coqa_gen_prompt(line, task_name: str = None): - """GenQA: evaluate all turns, build full preceding-QA context per turn.""" + """GenQA: evaluate ALL turns, build full preceding-QA context per turn. + + Returns a list of Docs (one per turn), matching OLMO's CoQA._process_doc_to_multi. + The lighteval_task.py extension (list-of-Docs support) handles multi-doc rows. + """ questions = line["questions"]["input_text"] answers = line["answers"]["input_text"] if not questions or not answers: return None - # Evaluate all turns; return last turn with full context (lighteval calls - # prompt_function once per row — we evaluate first turn as primary instance). - # For full multi-turn eval, the dataset would need to be pre-flattened. - turn_idx = 0 - gold_text = answers[turn_idx] - if not gold_text: - return None is_few_shots = line.get("__few_shots", False) - return Doc( - task_name=task_name, - query=_build_query(line["story"], questions, answers, turn_idx), - choices=[f"{' ' if is_few_shots else ''}{gold_text}"], - gold_index=0, - ) + docs = [] + for turn_idx in range(len(questions)): + gold_text = answers[turn_idx] + if not gold_text: + continue + docs.append(Doc( + task_name=task_name, + query=_build_query(line["story"], questions, answers, turn_idx), + choices=[f"{' ' if is_few_shots else ''}{gold_text}"], + gold_index=0, + )) + return docs if docs else None def coqa_bpb_prompt(line, task_name: str = None): - """BPB/CF: first turn with passage context.""" + """BPB: evaluate ALL turns of a story, returning one Doc per turn. + + Matches OLMO which creates one evaluation instance per (story, turn) pair. + The full preceding-QA context is built for each turn. + """ questions = line["questions"]["input_text"] answers = line["answers"]["input_text"] if not questions or not answers: return None - turn_idx = 0 - gold_text = answers[turn_idx] - if not gold_text: - return None - if not gold_text[0].isspace(): - gold_text = " " + gold_text - return Doc( - task_name=task_name, - query=_build_query(line["story"], questions, answers, turn_idx), - choices=[gold_text], - gold_index=0, - ) + + docs = [] + for turn_idx in range(len(questions)): + gold_text = answers[turn_idx] + if not gold_text: + continue + if not gold_text[0].isspace(): + gold_text = " " + gold_text + docs.append(Doc( + task_name=task_name, + query=_build_query(line["story"], questions, answers, turn_idx), + choices=[gold_text], + gold_index=0, + )) + return docs if docs else None TASKS_TABLE = [ diff --git a/src/lighteval/tasks/tasks/drop_qa.py b/src/lighteval/tasks/tasks/drop_qa.py index f62793987..838ae0952 100644 --- a/src/lighteval/tasks/tasks/drop_qa.py +++ b/src/lighteval/tasks/tasks/drop_qa.py @@ -76,7 +76,7 @@ def drop_bpb_prompt(line, task_name: str = None): few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, - stop_sequence=["\n"], + stop_sequence=["\n\n", "Passage:", "Question:"], metrics=[Metrics.target_bits_per_byte], version=0, ) @@ -90,7 +90,7 @@ def drop_bpb_prompt(line, task_name: str = None): few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=100, - stop_sequence=["\n"], + stop_sequence=["\n\n", "Passage:", "Question:"], metrics=[Metrics.drop], version=0, ) diff --git a/src/lighteval/tasks/tasks/hellaswag.py b/src/lighteval/tasks/tasks/hellaswag.py index 14d9462fc..940ef8d61 100644 --- a/src/lighteval/tasks/tasks/hellaswag.py +++ b/src/lighteval/tasks/tasks/hellaswag.py @@ -52,8 +52,9 @@ def hellaswag_mcf_prompt(line, task_name: str = None): """MCF variant: labeled options in prompt, score label tokens via logprobs.""" ctx = line["ctx_a"] + " " + line["ctx_b"].capitalize() query = harness_preprocess(line["activity_label"] + ": " + ctx) + query += "\nChoose the best continuation:" query += "".join( - [f"\n{key}. {harness_preprocess(choice)}" for key, choice in zip(ascii_uppercase, line["endings"])] + [f"\n {key}. {harness_preprocess(choice)}" for key, choice in zip(ascii_uppercase, line["endings"])] ) query += "\nAnswer:" @@ -88,7 +89,7 @@ def hellaswag_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, + few_shots_split="test", few_shots_select=None, generation_size=-1, metrics=_MCF_METRICS, @@ -104,7 +105,7 @@ def hellaswag_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, + few_shots_split="test", few_shots_select=None, generation_size=1, metrics=[Metrics.exact_match], @@ -119,7 +120,7 @@ def hellaswag_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, + few_shots_split="test", few_shots_select=None, generation_size=-1, metrics=_CF_METRICS, diff --git a/src/lighteval/tasks/tasks/humaneval.py b/src/lighteval/tasks/tasks/humaneval.py index 9d8a353fe..d61c5763a 100644 --- a/src/lighteval/tasks/tasks/humaneval.py +++ b/src/lighteval/tasks/tasks/humaneval.py @@ -13,14 +13,20 @@ from lighteval.tasks.requests import Doc +_ANSWER_PREFIX = "Here is the completed function:\n\n```python\n" + + def humaneval_bpb_prompt(line, task_name=None): - """Prompt = function signature + docstring; gold = canonical solution body.""" + """Prompt = function signature + answer_prefix; gold = canonical solution. + + Matches OLMO: query = prompt + answer_prefix so the gold continuation is + scored in the context of the opening code fence, not the raw docstring end. + Leading space is added to the gold to align with OLMO's perplexity_leading_space=True. + """ return Doc( task_name=task_name, - # `prompt` already ends at the opening of the function body (after the - # closing triple-quote of the docstring), ready for the solution to follow. - query=line["prompt"], - choices=[line["canonical_solution"]], + query=line["prompt"] + _ANSWER_PREFIX, + choices=[" " + line["canonical_solution"]], gold_index=0, ) diff --git a/src/lighteval/tasks/tasks/math.py b/src/lighteval/tasks/tasks/math.py index b335d8986..7fac4dc78 100644 --- a/src/lighteval/tasks/tasks/math.py +++ b/src/lighteval/tasks/tasks/math.py @@ -33,7 +33,8 @@ """.strip() # OLMo easy-suite style: minimal prompt, 4-shot from training data -MATH_BPB_PROMPT_TEMPLATE = "Problem: {prompt}\nSolution:" +# Format matches OLMO's minerva cot_style: "Problem:\n{problem}\n\nSolution:" +MATH_BPB_PROMPT_TEMPLATE = "Problem:\n{prompt}\n\nSolution:" _MATH_SUBSETS = [ 'algebra', diff --git a/src/lighteval/tasks/tasks/mbpp.py b/src/lighteval/tasks/tasks/mbpp.py index 32b041759..a23ebcaaf 100644 --- a/src/lighteval/tasks/tasks/mbpp.py +++ b/src/lighteval/tasks/tasks/mbpp.py @@ -18,19 +18,17 @@ def mbpp_bpb_prompt(line, task_name=None): - """Build a docstring prompt; gold = reference solution code.""" - tests = "\n".join(line.get("test_list", [])) - # Format as a Python docstring block so the model sees: - # """ - # - # - # """ - # def function_name(...): ← start of gold continuation - prompt = f'"""\n{line["prompt"]}\n{tests}\n"""\n' + """Prompt = task description + function header; gold = full code with leading space. + + Matches OLMO's default MBPP format: query ends at the function header colon + so the model predicts the entire function (signature + body) as the gold. + Leading space matches OLMO's perplexity_leading_space=True. + """ + query = line["prompt"] + line["code"].split(":")[0] + ":" return Doc( task_name=task_name, - query=prompt, - choices=[line["code"]], + query=query, + choices=[" " + line["code"]], gold_index=0, ) diff --git a/src/lighteval/tasks/tasks/med.py b/src/lighteval/tasks/tasks/med.py index 90b730bcd..2493fd3bc 100644 --- a/src/lighteval/tasks/tasks/med.py +++ b/src/lighteval/tasks/tasks/med.py @@ -39,11 +39,14 @@ def med_mcqa_mcf_prompt(line, task_name: str = None): - """MCF variant: labeled A/B/C/D options, score label tokens via logprobs.""" - query = f"Give a letter answer among A, B, C or D.\nQuestion: {line['question']}\n" + """MCF variant: labeled A/B/C/D options, score label tokens via logprobs. + + Matches OLMO's MedMCQAMC: plain Question + options, no instruction prefix. + """ + query = f"Question: {line['question']}\n" query += "".join( [ - f"{key}. {choice}\n" + f" {key}. {choice}\n" for key, choice in zip(ascii_uppercase, [line["opa"], line["opb"], line["opc"], line["opd"]]) ] ) @@ -53,7 +56,6 @@ def med_mcqa_mcf_prompt(line, task_name: str = None): query=query, choices=[" " + c for c in list(ascii_uppercase)[:4]], gold_index=line["cop"] - 1, - instruction="Give a letter answer among A, B, C or D.\n", ) diff --git a/src/lighteval/tasks/tasks/mmlu.py b/src/lighteval/tasks/tasks/mmlu.py index dc31502c5..c4b04c8f0 100644 --- a/src/lighteval/tasks/tasks/mmlu.py +++ b/src/lighteval/tasks/tasks/mmlu.py @@ -124,7 +124,7 @@ def mmlu_mcf_prompt(line, task_name: str = None): subject = line["subject"] query = f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}" query += "".join( - [f"\n{key}. {choice}" for key, choice in zip(ascii_uppercase, line["choices"])] + [f"\n {key}. {choice}" for key, choice in zip(ascii_uppercase, line["choices"])] ) query += "\nAnswer:" diff --git a/src/lighteval/tasks/tasks/mt_mbpp.py b/src/lighteval/tasks/tasks/mt_mbpp.py index 32abdfcdd..695907746 100644 --- a/src/lighteval/tasks/tasks/mt_mbpp.py +++ b/src/lighteval/tasks/tasks/mt_mbpp.py @@ -49,11 +49,18 @@ def mt_mbpp_bpb_prompt(line, task_name=None): - """Prompt = problem description (text); gold = code solution.""" + """Prompt = problem description + opening code fence; gold = code + closing fence. + + Matches OLMO: the query ends with ```{language}\n so the model scores the + code body and closing fence as the gold continuation. perplexity_leading_space + is implicitly False here (gold has no leading space). + """ + lang = line["language"] + code = line["code"].strip() return Doc( task_name=task_name, - query=line["text"], - choices=[line["code"]], + query=line["text"].strip() + f"\n```{lang}\n", + choices=[code + "\n```"], gold_index=0, ) diff --git a/src/lighteval/tasks/tasks/winogrande.py b/src/lighteval/tasks/tasks/winogrande.py index 0e5f7b361..be13d4210 100644 --- a/src/lighteval/tasks/tasks/winogrande.py +++ b/src/lighteval/tasks/tasks/winogrande.py @@ -22,18 +22,12 @@ https://arxiv.org/abs/1907.10641 """ -from lighteval.metrics.metrics import Metrics from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric +from lighteval.metrics.metrics import Metrics from lighteval.metrics.normalizations import LogProbCharNorm from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc -_CF_METRICS = [ - LogLikelihoodAccMetric(), - LogLikelihoodAccMetric(normalization=LogProbCharNorm()), - Metrics.target_bits_per_byte, -] - _MCF_METRICS = [ LogLikelihoodAccMetric(), LogLikelihoodAccMetric(normalization=LogProbCharNorm()), @@ -41,25 +35,38 @@ def winogrande_cf_prompt(line, task_name: str = None): - """CF variant: completion-style, score full answer texts via logprobs.""" - query, end_of_target = line["sentence"].split("_") - end_of_target = end_of_target.strip() + """Partial evaluation (Trinh & Le 2018), matching OLMO's implementation. + + Context = sentence with the gold option substituted in (prefix + option). + Gold = the sentence suffix after the blank, with a leading space. + BPB = logprob(suffix | prefix + gold_option) / bytes(suffix). + + Accuracy is not meaningful in this single-choice format (always 1.0), so + only BPB is reported. Use winogrande:mcf for accuracy scoring. + """ + sentence = line["sentence"] + blank_pos = sentence.index("_") + prefix = sentence[:blank_pos].rstrip() + suffix = sentence[blank_pos + 1:].strip() + + gold_ix = int(line["answer"]) - 1 if line["answer"].strip() else -1 + gold_option = [line["option1"], line["option2"]][gold_ix] + return Doc( task_name=task_name, - query=query.rstrip(), - choices=[f" {line['option1']} {end_of_target}", f" {line['option2']} {end_of_target}"], - gold_index=int(line["answer"]) - 1 if line["answer"] != "" else -1, + query=prefix + " " + gold_option, + choices=[" " + suffix], + gold_index=0, ) def winogrande_mcf_prompt(line, task_name: str = None): - """MCF variant: labeled A/B options, score label tokens via logprobs.""" - query, end_of_target = line["sentence"].split("_") - end_of_target = end_of_target.strip() - query = query.strip() - opt1 = f"{line['option1']} {end_of_target}" - opt2 = f"{line['option2']} {end_of_target}" - query = f"{query}\nA. {opt1}\nB. {opt2}\nAnswer:" + """MCF variant: Fill-in-the-blank format matching OLMO's WinograndeMC. + + Replaces _ with ___ in the sentence, lists options as A/B without the suffix. + """ + sentence_with_blank = line["sentence"].replace("_", "___") + query = f"Fill in the blank: {sentence_with_blank}\n A. {line['option1']}\n B. {line['option2']}\nAnswer:" gold_ix = int(line["answer"]) - 1 if line["answer"] != "" else -1 return Doc( task_name=task_name, @@ -69,7 +76,32 @@ def winogrande_mcf_prompt(line, task_name: str = None): ) -# CF variant: completion-style, logprob on full answer text + BPB on gold choice +def winogrande_rc_prompt(line, task_name: str = None): + """RC (rank choice) variant: standard CF approach for accuracy. + + Context = sentence prefix before blank. + Choices = full sentence completions with each option substituted in. + Approximates OLMO's RC_none (partial eval) within lighteval's Doc framework: + computes P(option+suffix | prefix) rather than P(suffix | prefix+option). + """ + sentence = line["sentence"] + blank_pos = sentence.index("_") + prefix = sentence[:blank_pos].rstrip() + suffix = " " + sentence[blank_pos + 1:].strip() + gold_ix = int(line["answer"]) - 1 if line["answer"].strip() else -1 + return Doc( + task_name=task_name, + query=prefix, + choices=[ + " " + line["option1"] + suffix, + " " + line["option2"] + suffix, + ], + gold_index=gold_ix, + ) + + +# CF variant: partial evaluation (OLMO-style), BPB only +# Accuracy is not reported because partial eval uses one fixed gold choice. winogrande_cf = LightevalTaskConfig( name="winogrande:cf", prompt_function=winogrande_cf_prompt, @@ -80,12 +112,12 @@ def winogrande_mcf_prompt(line, task_name: str = None): few_shots_split=None, few_shots_select="random_sampling", generation_size=-1, - metrics=_CF_METRICS, + metrics=[Metrics.target_bits_per_byte], stop_sequence=["\n"], version=0, ) -# MCF variant: labeled options, score label tokens via logprobs (TRUE MCF) +# MCF variant: "Fill in the blank:" format matching OLMO's WinograndeMC winogrande_mcf = LightevalTaskConfig( name="winogrande:mcf", prompt_function=winogrande_mcf_prompt, @@ -93,8 +125,8 @@ def winogrande_mcf_prompt(line, task_name: str = None): hf_subset="winogrande_xl", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, - few_shots_select="random_sampling", + few_shots_split="train", + few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_MCF_METRICS, stop_sequence=["\n"], @@ -109,16 +141,38 @@ def winogrande_mcf_prompt(line, task_name: str = None): hf_subset="winogrande_xl", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, - few_shots_select="random_sampling", + few_shots_split="train", + few_shots_select="random_sampling_from_train", generation_size=1, metrics=[Metrics.exact_match], stop_sequence=["\n"], version=0, ) +# RC variant: standard cloze-form for accuracy (approximates Table 46 RC_none). +# Scores P(option+suffix | prefix) rather than OLMO's P(suffix | prefix+option), +# but gives meaningful accuracy using lighteval's standard Doc framework. +winogrande_rc = LightevalTaskConfig( + name="winogrande:rc", + prompt_function=winogrande_rc_prompt, + hf_repo="allenai/winogrande", + hf_subset="winogrande_xl", + hf_avail_splits=["train", "test", "validation"], + evaluation_splits=["validation"], + few_shots_split="train", + few_shots_select="random_sampling_from_train", + generation_size=-1, + metrics=[ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), + ], + stop_sequence=["\n"], + version=0, +) + TASKS_TABLE = [ winogrande_cf, winogrande_mcf, winogrande_mcf_em, + winogrande_rc, ] From 9b1dc98ab81b43102c242b290e97967ada95b942 Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Mon, 11 May 2026 16:04:49 +0200 Subject: [PATCH 6/8] fix pr#2 comments --- src/lighteval/metrics/sample_preparator.py | 5 +- .../tasks/multilingual/tasks/flores200.py | 2 +- .../tasks/multilingual/tasks/global_mmlu.py | 2 +- .../tasks/multilingual/tasks/mgsm.py | 59 +++++++++++++++---- .../tasks/multilingual/tasks/mmlu_prox.py | 2 +- src/lighteval/tasks/tasks/mmlu_pro.py | 9 +-- src/lighteval/tasks/tasks/mmlu_prox.py | 2 +- src/lighteval/tasks/templates/hellaswag.py | 2 +- .../templates/utils/translation_literals.py | 10 +++- 9 files changed, 67 insertions(+), 26 deletions(-) diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py index a80a1e573..d84f28c7e 100644 --- a/src/lighteval/metrics/sample_preparator.py +++ b/src/lighteval/metrics/sample_preparator.py @@ -110,7 +110,10 @@ def prepare(doc: Doc, model_response: ModelResponse, **kwargs) -> "COMETCorpusMe """ source = (doc.specific or {}).get("source_text", "") golds = as_list(doc.get_golds()) - return COMETCorpusMetricInput(source=source, hyp=model_response.final_text, ref=golds) + preds = model_response.final_text + if len(preds) > 1: + logger.warning("Multiple predictions present, keeping only the first prediction (for COMET).") + return COMETCorpusMetricInput(source=source, hyp=preds[0], ref=golds) class LoglikelihoodPreparator(Preparator): diff --git a/src/lighteval/tasks/multilingual/tasks/flores200.py b/src/lighteval/tasks/multilingual/tasks/flores200.py index 3d536bb3b..c3e8d7e70 100644 --- a/src/lighteval/tasks/multilingual/tasks/flores200.py +++ b/src/lighteval/tasks/multilingual/tasks/flores200.py @@ -267,7 +267,7 @@ def _flores_pair_subset(lang_code: str) -> str: hf_avail_splits=("dev", "devtest"), evaluation_splits=("devtest",), few_shots_split="dev", - few_shots_select="random_sampling_from_train", + few_shots_select="random_sampling", ) for lang_code in flores_200_languages if lang_code != _ENGLISH_FLORES_CODE diff --git a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py index 6167a5018..dc594632c 100644 --- a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py +++ b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py @@ -140,7 +140,7 @@ def _mcf_adapter(line): } -_MMLU_CF_METRICS = [MMLUCategoryGroupingCF] +_MMLU_CF_METRICS = [MMLUCategoryGroupingCF, Metrics.target_bits_per_byte] _MMLU_MCF_METRICS = [MMLUCategoryGroupingMCF] diff --git a/src/lighteval/tasks/multilingual/tasks/mgsm.py b/src/lighteval/tasks/multilingual/tasks/mgsm.py index 1f03a0b26..3189314b0 100644 --- a/src/lighteval/tasks/multilingual/tasks/mgsm.py +++ b/src/lighteval/tasks/multilingual/tasks/mgsm.py @@ -3,7 +3,7 @@ Mgsm dataset: -juletxara/mgsm +CohereLabs/global-mgsm abstract: MGSM (Multilingual Grade School Math) is a multilingual benchmark testing @@ -16,8 +16,11 @@ non-ASCII digit systems like Japanese/Thai). languages: -bengali, french, german, japanese, russian, spanish, swahili, telugu, thai, -chinese +amharic, arabic, bengali, catalan, czech, welsh, german, greek, spanish, +basque, french, galician, gujarati, hausa, hungarian, japanese, khmer, +kannada, korean, kyrgyz, ganda, burmese, nepali, russian, sinhala, shona, +serbian, southern sotho, swahili, tamil, telugu, thai, urdu, uzbek, +vietnamese, wolof, xhosa, yoruba, chinese, zulu tags: math, multilingual, reasoning @@ -35,18 +38,48 @@ from lighteval.utils.language import Language -# Languages covered by juletxara/mgsm (English excluded — use gsm8k.py) +# Languages covered by CohereLabs/global-mgsm (English excluded — use gsm8k.py) _LANGUAGES = [ + Language.AMHARIC, + Language.ARABIC, + Language.BENGALI, + Language.CATALAN, + Language.CZECH, + Language.WELSH, + Language.GERMAN, + Language.GREEK, Language.SPANISH, + Language.BASQUE, Language.FRENCH, - Language.GERMAN, - Language.RUSSIAN, - Language.CHINESE, + Language.GALICIAN, + Language.GUJARATI, + Language.HAUSA, + Language.HUNGARIAN, Language.JAPANESE, - Language.THAI, + Language.KHMER, + Language.KANNADA, + Language.KOREAN, + Language.KIRGHIZ, + Language.GANDA, + Language.BURMESE, + Language.NEPALI, + Language.RUSSIAN, + Language.SINHALA, + Language.SHONA, + Language.SERBIAN, + Language.SOUTHERN_SOTHO, Language.SWAHILI, - Language.BENGALI, + Language.TAMIL, Language.TELUGU, + Language.THAI, + Language.URDU, + Language.UZBEK, + Language.VIETNAMESE, + Language.WOLOF, + Language.XHOSA, + Language.YORUBA, + Language.CHINESE, + Language.ZULU, ] @@ -57,19 +90,19 @@ language, lambda line: { "question": line["question"], - "choices": [str(line["answer_number"])], + "choices": [line["answer"]], }, ), - hf_repo="juletxara/mgsm", + hf_repo="CohereLabs/global-mgsm", hf_subset=standardize_tag(language.value), evaluation_splits=("test",), - few_shots_split="train", + few_shots_split=None, generation_size=512, metrics=[ Metrics.expr_gold_metric, MultilingualQuasiExactMatchMetric(language, "full"), ], - stop_sequence=["\n"], + stop_sequence=["Question:", "Answer:"], ) for language in _LANGUAGES ] diff --git a/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py b/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py index a21c3b5f2..c9f574e8d 100644 --- a/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py +++ b/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py @@ -116,7 +116,7 @@ def mmlu_prox_mcf_prompt(line, task_name: str = None): return None labels = list(ascii_uppercase[: len(options)]) query = f"Question: {line['question'].strip()}\n" - query += "".join([f"{lbl}. {opt}\n" for lbl, opt in zip(labels, options)]) + query += "".join([f" {lbl}. {opt}\n" for lbl, opt in zip(labels, options)]) query += "Answer:" return Doc( task_name=task_name, diff --git a/src/lighteval/tasks/tasks/mmlu_pro.py b/src/lighteval/tasks/tasks/mmlu_pro.py index b443d577b..2ede01f50 100644 --- a/src/lighteval/tasks/tasks/mmlu_pro.py +++ b/src/lighteval/tasks/tasks/mmlu_pro.py @@ -49,13 +49,13 @@ def mmlu_pro_mcf_prompt_function(line, task_name: str = None): query = f"Answer the following multiple choice question.\n\nQuestion: {line['question'].strip()}" - query += "".join([f"\n{key}. {choice}" for key, choice in zip(ascii_uppercase, line["options"])]) - query += "\nAnswer: " + query += "".join([f"\n {key}. {choice}" for key, choice in zip(ascii_uppercase, line["options"])]) + query += "\nAnswer:" return Doc( task_name=task_name, query=query, - choices=ascii_uppercase[: len(line["options"])], + choices=[" " + c for c in ascii_uppercase[:len(line["options"])]], gold_index=line["answer_index"], ) @@ -78,7 +78,7 @@ def mmlu_pro_prompt_function(line, task_name: str = None): return Doc( task_name=task_name, query=query, - choices=ascii_uppercase[: len(choices)], + choices=list(ascii_uppercase[: len(line["options"])]), gold_index=line["answer_index"], instruction=query, ) @@ -135,6 +135,7 @@ def record_to_sample(record): metrics=[ LogLikelihoodAccMetric(), LogLikelihoodAccMetric(normalization=LogProbCharNorm()), + Metrics.target_bits_per_byte, ], ) diff --git a/src/lighteval/tasks/tasks/mmlu_prox.py b/src/lighteval/tasks/tasks/mmlu_prox.py index 08c04703b..318b5fc7f 100644 --- a/src/lighteval/tasks/tasks/mmlu_prox.py +++ b/src/lighteval/tasks/tasks/mmlu_prox.py @@ -73,7 +73,7 @@ def mmlu_prox_mcf_prompt(line, task_name: str = None): return None labels = list(ascii_uppercase[: len(options)]) query = f"Question: {line['question'].strip()}\n" - query += "".join([f"{lbl}. {opt}\n" for lbl, opt in zip(labels, options)]) + query += "".join([f" {lbl}. {opt}\n" for lbl, opt in zip(labels, options)]) query += "Answer:" return Doc( task_name=task_name, diff --git a/src/lighteval/tasks/templates/hellaswag.py b/src/lighteval/tasks/templates/hellaswag.py index 2beb09574..c1aa25135 100644 --- a/src/lighteval/tasks/templates/hellaswag.py +++ b/src/lighteval/tasks/templates/hellaswag.py @@ -56,7 +56,7 @@ def hellaswag_preprocess( text = re.sub("\\[.*?\\]", "", text) text = text.replace(" ", " ") if truncate_dots: - text = text.replace(r"\.+", r"\.") + text = re.sub(r"\.+", ".", text) if strip_text: text = text.strip() return text diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py index 99405c45b..b0098bb0b 100644 --- a/src/lighteval/tasks/templates/utils/translation_literals.py +++ b/src/lighteval/tasks/templates/utils/translation_literals.py @@ -760,6 +760,8 @@ def __getattribute__(self, name: str) -> str: Language.KIRGHIZ: TranslationLiterals(language=Language.KIRGHIZ), Language.KOREAN: TranslationLiterals( language=Language.KOREAN, + question_word="질문", + answer="답", confirmation_word="맞죠", yes="예", no="아니오", @@ -773,7 +775,11 @@ def __getattribute__(self, name: str) -> str: Language.LIGURIAN: TranslationLiterals(language=Language.LIGURIAN), Language.LIMBURGISH: TranslationLiterals(language=Language.LIMBURGISH), Language.LINGALA: TranslationLiterals(language=Language.LINGALA), - Language.LITHUANIAN: TranslationLiterals(language=Language.LITHUANIAN), + Language.LITHUANIAN: TranslationLiterals( + language=Language.LITHUANIAN, + question_word="klausimas", + answer="atsakymas", + ), Language.LOMBARD: TranslationLiterals(language=Language.LOMBARD), Language.LUBA_KASAI: TranslationLiterals(language=Language.LUBA_KASAI), Language.LUO: TranslationLiterals(language=Language.LUO), @@ -1436,8 +1442,6 @@ def __getattribute__(self, name: str) -> str: Language.HEBREW, Language.IGBO, Language.KIRGHIZ, - Language.KOREAN, - Language.LITHUANIAN, Language.MALAGASY, Language.MALAY, Language.NEPALI, From d444a2c0015152734f84ca2c204660322f3d8161 Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Mon, 11 May 2026 16:26:40 +0200 Subject: [PATCH 7/8] improve mgsm --- .../tasks/multilingual/tasks/mgsm.py | 316 +++++++- .../templates/utils/translation_literals.py | 753 +++++++++++++++++- 2 files changed, 1019 insertions(+), 50 deletions(-) diff --git a/src/lighteval/tasks/multilingual/tasks/mgsm.py b/src/lighteval/tasks/multilingual/tasks/mgsm.py index 3189314b0..8dac452f1 100644 --- a/src/lighteval/tasks/multilingual/tasks/mgsm.py +++ b/src/lighteval/tasks/multilingual/tasks/mgsm.py @@ -9,8 +9,12 @@ MGSM (Multilingual Grade School Math) is a multilingual benchmark testing mathematical reasoning across languages, derived from GSM8K. -Refactored to use unified :gen suffix consistent with English gsm8k.py. +Uses CoT-style prompts matching English gsm8k.py ("Think step by step / +Problem: / Solution:"), with language-specific translations for all non-English +languages in CohereLabs/global-mgsm. + English is excluded — use gsm8k.py for English evaluation. + Reports both expr_gold_metric (math expression parser) and MultilingualQuasiExactMatchMetric (language-aware fuzzy match, handles non-ASCII digit systems like Japanese/Thai). @@ -34,11 +38,12 @@ from lighteval.metrics.dynamic_metrics import MultilingualQuasiExactMatchMetric from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig -from lighteval.tasks.templates.qa import get_qa_prompt_function +from lighteval.tasks.requests import Doc from lighteval.utils.language import Language -# Languages covered by CohereLabs/global-mgsm (English excluded — use gsm8k.py) +# Languages covered by CohereLabs/global-mgsm. +# English is intentionally excluded — use gsm8k.py for English evaluation. _LANGUAGES = [ Language.AMHARIC, Language.ARABIC, @@ -83,26 +88,311 @@ ] +# Each entry is: +# instruction, problem_label, solution_label +# +# The instruction is intentionally close in meaning across languages: +# "Solve the following math problem. Think step by step before giving +# the final answer." +# +# NOTE: +# Some low-resource translations should ideally receive native-speaker review +# before being treated as final benchmark wording. +_MGSM_PROMPT_PARTS: dict[Language, tuple[str, str, str]] = { + Language.AMHARIC: ( + "የሚከተለውን የሂሳብ ችግር ፍታ። የመጨረሻውን መልስ ከመስጠትህ በፊት ደረጃ በደረጃ አስብ።", + "ችግር:", + "መፍትሔ:", + ), + Language.ARABIC: ( + "حل المسألة الرياضية التالية. فكّر خطوة بخطوة قبل تقديم الإجابة النهائية.", + "المسألة:", + "الحل:", + ), + Language.BENGALI: ( + "নিম্নলিখিত গণিত সমস্যাটি সমাধান করো। চূড়ান্ত উত্তর দেওয়ার আগে ধাপে ধাপে চিন্তা করো।", + "সমস্যা:", + "সমাধান:", + ), + Language.CATALAN: ( + "Resol el següent problema de matemàtiques. Pensa pas a pas abans de donar la resposta final.", + "Problema:", + "Solució:", + ), + Language.CZECH: ( + "Vyřeš následující matematickou úlohu. Než uvedeš konečnou odpověď, přemýšlej krok za krokem.", + "Úloha:", + "Řešení:", + ), + Language.WELSH: ( + "Datrysa'r broblem fathemateg ganlynol. Meddylia gam wrth gam cyn rhoi'r ateb terfynol.", + "Problem:", + "Datrysiad:", + ), + Language.GERMAN: ( + "Löse die folgende Mathematikaufgabe. Denke Schritt für Schritt nach, bevor du die endgültige Antwort gibst.", + "Aufgabe:", + "Lösung:", + ), + Language.GREEK: ( + "Λύσε το παρακάτω μαθηματικό πρόβλημα. Σκέψου βήμα προς βήμα πριν δώσεις την τελική απάντηση.", + "Πρόβλημα:", + "Λύση:", + ), + Language.SPANISH: ( + "Resuelve el siguiente problema de matemáticas. Piensa paso a paso antes de dar la respuesta final.", + "Problema:", + "Solución:", + ), + Language.BASQUE: ( + "Ebatzi honako matematikako problema hau. Pentsatu pausoz pauso azken erantzuna eman aurretik.", + "Problema:", + "Ebazpena:", + ), + Language.FRENCH: ( + "Résous le problème de mathématiques suivant. Réfléchis étape par étape avant de donner la réponse finale.", + "Problème :", + "Solution:", + ), + Language.GALICIAN: ( + "Resolve o seguinte problema matemático. Pensa paso a paso antes de dar a resposta final.", + "Problema:", + "Solución:", + ), + Language.GUJARATI: ( + "નીચે આપેલ ગણિતનો પ્રશ્ન ઉકેલો. અંતિમ જવાબ આપતા પહેલાં પગલું દર પગલું વિચારો.", + "પ્રશ્ન:", + "ઉકેલ:", + ), + Language.HAUSA: ( + "Warware matsalar lissafi mai zuwa. Yi tunani mataki-mataki kafin ka bayar da amsar ƙarshe.", + "Matsala:", + "Magani:", + ), + Language.HUNGARIAN: ( + "Oldd meg a következő matematikai feladatot. Gondolkodj lépésről lépésre, mielőtt megadnád a végső választ.", + "Feladat:", + "Megoldás:", + ), + Language.JAPANESE: ( + "次の数学の問題を解きなさい。最終的な答えを出す前に、順を追って考えなさい。", + "問題:", + "解答:", + ), + Language.KHMER: ( + "ដោះស្រាយលំហាត់គណិតវិទ្យាខាងក្រោម។ សូមគិតជាជំហានៗ មុននឹងផ្តល់ចម្លើយចុងក្រោយ។", + "លំហាត់:", + "ដំណោះស្រាយ:", + ), + Language.KANNADA: ( + "ಕೆಳಗಿನ ಗಣಿತ ಸಮಸ್ಯೆಯನ್ನು ಪರಿಹರಿಸಿ. ಅಂತಿಮ ಉತ್ತರವನ್ನು ನೀಡುವ ಮೊದಲು ಹಂತ ಹಂತವಾಗಿ ಯೋಚಿಸಿ.", + "ಸಮಸ್ಯೆ:", + "ಪರಿಹಾರ:", + ), + Language.KOREAN: ( + "다음 수학 문제를 풀어라. 최종 답을 제시하기 전에 단계별로 생각하라.", + "문제:", + "풀이:", + ), + Language.KIRGHIZ: ( + "Төмөнкү математикалык маселени чыгар. Акыркы жоопту берүүдөн мурун кадам сайын ойлон.", + "Маселе:", + "Чечим:", + ), + Language.GANDA: ( + "Gonjoola ekibuuzo kya okubala kino wammanga. Lowooza mutendera ku mutendera nga tonnawa ky’okuddamu ekisembayo.", + "Ekibuuzo:", + "Okugonjoola:", + ), + Language.BURMESE: ( + "အောက်ပါ သင်္ချာပုစ္ဆာကို ဖြေရှင်းပါ။ နောက်ဆုံးအဖြေ မပေးမီ အဆင့်လိုက် စဉ်းစားပါ။", + "ပုစ္ဆာ:", + "ဖြေရှင်းချက်:", + ), + Language.NEPALI: ( + "तलको गणित समस्या हल गर। अन्तिम उत्तर दिनुअघि चरणबद्ध रूपमा सोच।", + "समस्या:", + "समाधान:", + ), + Language.RUSSIAN: ( + "Реши следующую математическую задачу. Рассуждай пошагово, прежде чем дать окончательный ответ.", + "Задача:", + "Решение:", + ), + Language.SINHALA: ( + "පහත ගණිත ගැටලුව විසඳන්න. අවසාන පිළිතුර ලබාදීමට පෙර පියවරෙන් පියවර සිතන්න.", + "ගැටලුව:", + "විසඳුම:", + ), + Language.SHONA: ( + "Gadzirisa dambudziko remasvomhu rinotevera. Funga nhanho nenhanho usati wapa mhinduro yekupedzisira.", + "Dambudziko:", + "Mhinduro:", + ), + Language.SERBIAN: ( + "Реши следећи математички задатак. Размишљај корак по корак пре него што даш коначан одговор.", + "Задатак:", + "Решење:", + ), + Language.SOUTHERN_SOTHO: ( + "Rarolla bothata bona ba dipalo bo latelang. Nahana mohato ka mohato pele o fana ka karabo ya ho qetela.", + "Bothata:", + "Tharollo:", + ), + Language.SWAHILI: ( + "Tatua tatizo lifuatalo la hisabati. Fikiri hatua kwa hatua kabla ya kutoa jibu la mwisho.", + "Tatizo:", + "Suluhisho:", + ), + Language.TAMIL: ( + "பின்வரும் கணிதப் பிரச்சினையைத் தீர்க்கவும். இறுதி விடையை அளிப்பதற்கு முன் படிப்படியாக சிந்திக்கவும்.", + "பிரச்சினை:", + "தீர்வு:", + ), + Language.TELUGU: ( + "క్రింది గణిత సమస్యను పరిష్కరించండి. తుది సమాధానం చెప్పే ముందు దశలవారీగా ఆలోచించండి.", + "సమస్య:", + "పరిష్కారం:", + ), + Language.THAI: ( + "จงแก้โจทย์คณิตศาสตร์ต่อไปนี้ คิดทีละขั้นตอนก่อนให้คำตอบสุดท้าย", + "โจทย์:", + "วิธีทำ:", + ), + Language.URDU: ( + "درج ذیل ریاضی کا مسئلہ حل کریں۔ حتمی جواب دینے سے پہلے مرحلہ وار سوچیں۔", + "مسئلہ:", + "حل:", + ), + Language.UZBEK: ( + "Quyidagi matematika masalasini yeching. Yakuniy javobni berishdan oldin bosqichma-bosqich o‘ylang.", + "Masala:", + "Yechim:", + ), + Language.VIETNAMESE: ( + "Giải bài toán sau đây. Hãy suy nghĩ từng bước trước khi đưa ra câu trả lời cuối cùng.", + "Bài toán:", + "Lời giải:", + ), + Language.WOLOF: ( + "Saafara jafe-jafe xayma bii ci suuf. Xalaatal ndànk-ndànk, jéego bu nekk, bala nga joxe tontu bu mujj bi.", + "Jafe-jafe:", + "Saafara:", + ), + Language.XHOSA: ( + "Sombulula le ngxaki yezibalo ilandelayo. Cinga inyathelo nenyathelo phambi kokunika impendulo yokugqibela.", + "Ingxaki:", + "Isisombululo:", + ), + Language.YORUBA: ( + "Yanju iṣoro iṣiro atẹle yii. Ronu ni igbesẹ-nipasẹ-igbesẹ ṣaaju ki o to fun idahun ikẹhin.", + "Iṣoro:", + "Ojútùú:", + ), + Language.CHINESE: ( + "解答下面的数学题。在给出最终答案前,请一步一步思考。", + "问题:", + "解答:", + ), + Language.ZULU: ( + "Xazulula inkinga yezibalo elandelayo. Cabanga isinyathelo ngesinyathelo ngaphambi kokunikeza impendulo yokugcina.", + "Inkinga:", + "Isixazululo:", + ), +} + + +# Extra problem-label variants that models may generate. +# Keep this small: too many stop strings increase accidental truncation risk. +_MGSM_PROBLEM_LABEL_STOP_ALIASES: dict[Language, list[str]] = { + Language.FRENCH: ["Problème:"], # no-space variant + Language.CHINESE: ["问题:"], # ASCII-colon variant + Language.JAPANESE: ["問題:"], # ASCII-colon variant +} + + +def validate_mgsm_prompt_parts() -> None: + """Fail fast if a language is missing localized prompt parts.""" + missing_languages = [ + language for language in _LANGUAGES if language not in _MGSM_PROMPT_PARTS + ] + + if missing_languages: + missing = ", ".join(language.value for language in missing_languages) + raise ValueError(f"Missing MGSM prompt parts for: {missing}") + + +def mgsm_cot_template(language: Language) -> str: + """Return the localized MGSM CoT template for a language.""" + try: + instruction, problem_label, solution_label = _MGSM_PROMPT_PARTS[language] + except KeyError as exc: + raise ValueError( + f"Missing MGSM prompt parts for language: {language.value}" + ) from exc + + return f"{instruction}\n\n{problem_label}\n{{prompt}}\n\n{solution_label}" + + +def mgsm_stop_sequences(language: Language) -> list[str]: + """Return stop sequences for a language. + + Stop when the model appears to start a new problem. + + We keep English stops because models sometimes switch back to English even + when the prompt is localized. We also add only the current language's + problem label and a few high-value aliases to avoid over-stopping. + """ + try: + _, problem_label, _ = _MGSM_PROMPT_PARTS[language] + except KeyError as exc: + raise ValueError( + f"Missing MGSM prompt parts for language: {language.value}" + ) from exc + + stop_sequences = [ + "Question:", + "Problem:", + problem_label, + *_MGSM_PROBLEM_LABEL_STOP_ALIASES.get(language, []), + ] + + # Preserve order while removing duplicates. + return list(dict.fromkeys(stop_sequences)) + + +def mgsm_cot_prompt(language: Language): + """Build a CoT prompt function for MGSM in the given language.""" + template = mgsm_cot_template(language) + + def prompt_fn(line, task_name: str = None): + return Doc( + task_name=task_name, + query=template.format(prompt=line["question"]), + choices=[line["answer"]], + gold_index=0, + ) + + return prompt_fn + + +validate_mgsm_prompt_parts() + + TASKS_TABLE = [ LightevalTaskConfig( name=f"mgsm:{language.value}:gen", - prompt_function=get_qa_prompt_function( - language, - lambda line: { - "question": line["question"], - "choices": [line["answer"]], - }, - ), + prompt_function=mgsm_cot_prompt(language), hf_repo="CohereLabs/global-mgsm", hf_subset=standardize_tag(language.value), evaluation_splits=("test",), few_shots_split=None, - generation_size=512, + generation_size=1024, metrics=[ Metrics.expr_gold_metric, MultilingualQuasiExactMatchMetric(language, "full"), ], - stop_sequence=["Question:", "Answer:"], + stop_sequence=mgsm_stop_sequences(language), ) for language in _LANGUAGES -] +] \ No newline at end of file diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py index b0098bb0b..f42cc4765 100644 --- a/src/lighteval/tasks/templates/utils/translation_literals.py +++ b/src/lighteval/tasks/templates/utils/translation_literals.py @@ -76,9 +76,55 @@ def __getattribute__(self, name: str) -> str: TRANSLATION_LITERALS: dict[Language, TranslationLiterals] = { Language.ACEHNESE: TranslationLiterals(language=Language.ACEHNESE), - Language.AFRIKAANS: TranslationLiterals(language=Language.AFRIKAANS), + Language.AFRIKAANS: TranslationLiterals( + language=Language.AFRIKAANS, + question_word="vraag", + answer="antwoord", + confirmation_word="nè", + yes="ja", + no="nee", + also="ook", + cause_word="omdat", + effect_word="daarom", + or_word="of", + and_word="en", + true="waar", + false="vals", + neither="nie een van die twee nie", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.AKAN: TranslationLiterals(language=Language.AKAN), - Language.ALBANIAN: TranslationLiterals(language=Language.ALBANIAN), + Language.ALBANIAN: TranslationLiterals( + language=Language.ALBANIAN, + question_word="pyetje", + answer="përgjigje", + confirmation_word="apo jo", + yes="po", + no="jo", + also="gjithashtu", + cause_word="sepse", + effect_word="prandaj", + or_word="ose", + and_word="dhe", + true="e vërtetë", + false="e rreme", + neither="asnjëra", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.AMHARIC: TranslationLiterals(language=Language.AMHARIC), Language.ARABIC: TranslationLiterals( language=Language.ARABIC, @@ -109,7 +155,30 @@ def __getattribute__(self, name: str) -> str: Language.ASTURIAN: TranslationLiterals(language=Language.ASTURIAN), Language.AWADHI: TranslationLiterals(language=Language.AWADHI), Language.AYACUCHO_QUECHUA: TranslationLiterals(language=Language.AYACUCHO_QUECHUA), - Language.AZERBAIJANI: TranslationLiterals(language=Language.AZERBAIJANI), + Language.AZERBAIJANI: TranslationLiterals( + language=Language.AZERBAIJANI, + question_word="sual", + answer="cavab", + confirmation_word="elə deyilmi", + yes="bəli", + no="xeyr", + also="həmçinin", + cause_word="çünki", + effect_word="buna görə", + or_word="və ya", + and_word="və", + true="doğru", + false="yanlış", + neither="heç biri", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.BALINESE: TranslationLiterals(language=Language.BALINESE), Language.BAMBARA: TranslationLiterals(language=Language.BAMBARA), Language.BANJAR: TranslationLiterals(language=Language.BANJAR), @@ -209,7 +278,30 @@ def __getattribute__(self, name: str) -> str: ), Language.BHOJPURI: TranslationLiterals(language=Language.BHOJPURI), Language.BIHARI: TranslationLiterals(language=Language.BIHARI), # Deprecated - Language.BOSNIAN: TranslationLiterals(language=Language.BOSNIAN), + Language.BOSNIAN: TranslationLiterals( + language=Language.BOSNIAN, + question_word="pitanje", + answer="odgovor", + confirmation_word="zar ne", + yes="da", + no="ne", + also="također", + cause_word="jer", + effect_word="zato", + or_word="ili", + and_word="i", + true="tačno", + false="netačno", + neither="nijedno", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.BRETON: TranslationLiterals(language=Language.BRETON), Language.BUGINESE: TranslationLiterals(language=Language.BUGINESE), Language.BULGARIAN: TranslationLiterals( @@ -519,7 +611,30 @@ def __getattribute__(self, name: str) -> str: semicolon=";", ), Language.GANDA: TranslationLiterals(language=Language.GANDA), - Language.GEORGIAN: TranslationLiterals(language=Language.GEORGIAN), + Language.GEORGIAN: TranslationLiterals( + language=Language.GEORGIAN, + question_word="კითხვა", + answer="პასუხი", + confirmation_word="მართალია", + yes="დიახ", + no="არა", + also="ასევე", + cause_word="რადგან", + effect_word="ამიტომ", + or_word="ან", + and_word="და", + true="ჭეშმარიტი", + false="მცდარი", + neither="არცერთი", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.GERMAN: TranslationLiterals( language=Language.GERMAN, question_word="frage", @@ -597,10 +712,56 @@ def __getattribute__(self, name: str) -> str: cause_word="poukisa", effect_word="donk sa", ), - Language.HAITIAN_CREOLE: TranslationLiterals(language=Language.HAITIAN_CREOLE), + Language.HAITIAN_CREOLE: TranslationLiterals( + language=Language.HAITIAN_CREOLE, + question_word="kesyon", + answer="repons", + confirmation_word="se vre", + yes="wi", + no="non", + also="tou", + cause_word="paske", + effect_word="kidonk", + or_word="oswa", + and_word="ak", + true="vre", + false="fo", + neither="okenn nan yo", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.HALH_MONGOLIAN: TranslationLiterals(language=Language.HALH_MONGOLIAN), Language.HAUSA: TranslationLiterals(language=Language.HAUSA), - Language.HEBREW: TranslationLiterals(language=Language.HEBREW), + Language.HEBREW: TranslationLiterals( + language=Language.HEBREW, + question_word="שאלה", + answer="תשובה", + confirmation_word="נכון", + yes="כן", + no="לא", + also="גם", + cause_word="כי", + effect_word="לכן", + or_word="או", + and_word="ו", + true="אמת", + false="שקר", + neither="אף אחד מהם", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.HINDI: TranslationLiterals( language=Language.HINDI, question_word="सवाल", @@ -696,7 +857,30 @@ def __getattribute__(self, name: str) -> str: colon=":", semicolon=";", ), - Language.IRISH: TranslationLiterals(language=Language.IRISH), + Language.IRISH: TranslationLiterals( + language=Language.IRISH, + question_word="ceist", + answer="freagra", + confirmation_word="nach ea", + yes="sea", + no="ní hea", + also="freisin", + cause_word="mar", + effect_word="dá bhrí sin", + or_word="nó", + and_word="agus", + true="fíor", + false="bréagach", + neither="ceachtar acu", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.ITALIAN: TranslationLiterals( language=Language.ITALIAN, question_word="domanda", @@ -728,9 +912,10 @@ def __getattribute__(self, name: str) -> str: yes="はい", no="いいえ", also="また", - cause_word="なので", - effect_word="なぜなら", + cause_word="なぜなら", + effect_word="だから", or_word="または", + and_word="そして", true="正解", false="不正解", neither="どちらでもない", @@ -749,29 +934,163 @@ def __getattribute__(self, name: str) -> str: Language.KABUVERDIANU: TranslationLiterals(language=Language.KABUVERDIANU), Language.KABYLE: TranslationLiterals(language=Language.KABYLE), Language.KAMBA: TranslationLiterals(language=Language.KAMBA), - Language.KANNADA: TranslationLiterals(language=Language.KANNADA), + Language.KANNADA: TranslationLiterals( + language=Language.KANNADA, + question_word="ಪ್ರಶ್ನೆ", + answer="ಉತ್ತರ", + confirmation_word="ಅಲ್ಲವೆ", + yes="ಹೌದು", + no="ಇಲ್ಲ", + also="ಹಾಗೂ", + cause_word="ಏಕೆಂದರೆ", + effect_word="ಆದ್ದರಿಂದ", + or_word="ಅಥವಾ", + and_word="ಮತ್ತು", + true="ಸತ್ಯ", + false="ಸುಳ್ಳು", + neither="ಎರಡೂ ಅಲ್ಲ", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + indices=["ಅ", "ಆ", "ಇ", "ಈ", "ಉ", "ಊ"], + ), Language.KASHMIRI: TranslationLiterals(language=Language.KASHMIRI), - Language.KAZAKH: TranslationLiterals(language=Language.KAZAKH), + Language.KAZAKH: TranslationLiterals( + language=Language.KAZAKH, + question_word="сұрақ", + answer="жауап", + confirmation_word="солай ма", + yes="иә", + no="жоқ", + also="сондай-ақ", + cause_word="өйткені", + effect_word="сондықтан", + or_word="немесе", + and_word="және", + true="дұрыс", + false="жалған", + neither="ешқайсысы", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + indices=["А", "Ә", "Б", "В", "Г", "Ғ"], + ), Language.KHMER: TranslationLiterals(language=Language.KHMER), Language.KIKONGO: TranslationLiterals(language=Language.KIKONGO), Language.KIKUYU: TranslationLiterals(language=Language.KIKUYU), Language.KIMBUNDU: TranslationLiterals(language=Language.KIMBUNDU), Language.KINYARWANDA: TranslationLiterals(language=Language.KINYARWANDA), - Language.KIRGHIZ: TranslationLiterals(language=Language.KIRGHIZ), + Language.KIRGHIZ: TranslationLiterals( + language=Language.KIRGHIZ, + question_word="суроо", + answer="жооп", + confirmation_word="туурабы", + yes="ооба", + no="жок", + also="ошондой эле", + cause_word="анткени", + effect_word="ошондуктан", + or_word="же", + and_word="жана", + true="чын", + false="жалган", + neither="эч бири", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + indices=["А", "Б", "В", "Г", "Д", "Е"], + ), Language.KOREAN: TranslationLiterals( language=Language.KOREAN, question_word="질문", - answer="답", + answer="답변", confirmation_word="맞죠", yes="예", no="아니오", + also="또한", + cause_word="왜냐하면", + effect_word="그러므로", + or_word="또는", + and_word="그리고", + true="참", + false="거짓", + neither="둘 다 아님", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", ), Language.KURDISH: TranslationLiterals(language=Language.KURDISH), - Language.KYRGYZ: TranslationLiterals(language=Language.KYRGYZ), + Language.KYRGYZ: TranslationLiterals( + language=Language.KYRGYZ, + question_word="суроо", + answer="жооп", + confirmation_word="туурабы", + yes="ооба", + no="жок", + also="ошондой эле", + cause_word="анткени", + effect_word="ошондуктан", + or_word="же", + and_word="жана", + true="чын", + false="жалган", + neither="эч бири", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + indices=["А", "Б", "В", "Г", "Д", "Е"], + ), Language.LAO: TranslationLiterals(language=Language.LAO), Language.LATGALIAN: TranslationLiterals(language=Language.LATGALIAN), Language.LATIN: TranslationLiterals(language=Language.LATIN), - Language.LATVIAN: TranslationLiterals(language=Language.LATVIAN), + Language.LATVIAN: TranslationLiterals( + language=Language.LATVIAN, + question_word="jautājums", + answer="atbilde", + confirmation_word="vai ne", + yes="jā", + no="nē", + also="arī", + cause_word="jo", + effect_word="tāpēc", + or_word="vai", + and_word="un", + true="patiesi", + false="aplami", + neither="neviens", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.LIGURIAN: TranslationLiterals(language=Language.LIGURIAN), Language.LIMBURGISH: TranslationLiterals(language=Language.LIMBURGISH), Language.LINGALA: TranslationLiterals(language=Language.LINGALA), @@ -779,6 +1098,25 @@ def __getattribute__(self, name: str) -> str: language=Language.LITHUANIAN, question_word="klausimas", answer="atsakymas", + confirmation_word="ar ne", + yes="taip", + no="ne", + also="taip pat", + cause_word="nes", + effect_word="todėl", + or_word="arba", + and_word="ir", + true="teisinga", + false="neteisinga", + neither="nei vienas", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", ), Language.LOMBARD: TranslationLiterals(language=Language.LOMBARD), Language.LUBA_KASAI: TranslationLiterals(language=Language.LUBA_KASAI), @@ -810,20 +1148,158 @@ def __getattribute__(self, name: str) -> str: Language.MAGAHI: TranslationLiterals(language=Language.MAGAHI), Language.MAITHILI: TranslationLiterals(language=Language.MAITHILI), Language.MALAGASY: TranslationLiterals(language=Language.MALAGASY), - Language.MALAY: TranslationLiterals(language=Language.MALAY), - Language.MALAYALAM: TranslationLiterals(language=Language.MALAYALAM), - Language.MALTESE: TranslationLiterals(language=Language.MALTESE), + Language.MALAY: TranslationLiterals( + language=Language.MALAY, + question_word="soalan", + answer="jawapan", + confirmation_word="bukan", + yes="ya", + no="tidak", + also="juga", + cause_word="kerana", + effect_word="oleh itu", + or_word="atau", + and_word="dan", + true="benar", + false="salah", + neither="kedua-duanya tidak", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), + Language.MALAYALAM: TranslationLiterals( + language=Language.MALAYALAM, + question_word="ചോദ്യം", + answer="ഉത്തരം", + confirmation_word="അല്ലേ", + yes="അതെ", + no="അല്ല", + also="കൂടാതെ", + cause_word="കാരണം", + effect_word="അതുകൊണ്ട്", + or_word="അല്ലെങ്കിൽ", + and_word="കൂടാതെ", + true="ശരി", + false="തെറ്റ്", + neither="ഇരുവരും അല്ല", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + indices=["അ", "ആ", "ഇ", "ഈ", "ഉ", "ഊ"], + ), + Language.MALTESE: TranslationLiterals( + language=Language.MALTESE, + question_word="mistoqsija", + answer="tweġiba", + confirmation_word="hux hekk", + yes="iva", + no="le", + also="ukoll", + cause_word="għax", + effect_word="għalhekk", + or_word="jew", + and_word="u", + true="veru", + false="falz", + neither="l-ebda waħda", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.MAORI: TranslationLiterals(language=Language.MAORI), - Language.MARATHI: TranslationLiterals(language=Language.MARATHI), + Language.MARATHI: TranslationLiterals( + language=Language.MARATHI, + question_word="प्रश्न", + answer="उत्तर", + confirmation_word="बरोबर ना", + yes="हो", + no="नाही", + also="तसेच", + cause_word="कारण", + effect_word="म्हणून", + or_word="किंवा", + and_word="आणि", + true="सत्य", + false="असत्य", + neither="दोन्ही नाही", + full_stop="।", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + indices=["क", "ख", "ग", "घ", "ङ", "च"], + ), Language.MEITEI: TranslationLiterals(language=Language.MEITEI), Language.MESOPOTAMIAN_ARABIC: TranslationLiterals(language=Language.MESOPOTAMIAN_ARABIC), Language.MINANGKABAU: TranslationLiterals(language=Language.MINANGKABAU), Language.MIZO: TranslationLiterals(language=Language.MIZO), - Language.MODERN_STANDARD_ARABIC: TranslationLiterals(language=Language.MODERN_STANDARD_ARABIC), + Language.MODERN_STANDARD_ARABIC: TranslationLiterals( + language=Language.MODERN_STANDARD_ARABIC, + question_word="سؤال", + answer="إجابة", + confirmation_word="صحيح", + yes="نعم", + no="لا", + also="كذلك", + cause_word="لأن", + effect_word="لذلك", + true="صحيح", + false="خاطئ", + neither="لا هذا ولا ذاك", + or_word="أو", + and_word="و", + full_stop=".", + comma="،", + question_mark="؟", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + indices=["أ", "ب", "ج", "د", "هـ", "و", "ز", "ح"], + ), Language.MOROCCAN_ARABIC: TranslationLiterals(language=Language.MOROCCAN_ARABIC), Language.MOSSI: TranslationLiterals(language=Language.MOSSI), Language.NAJDI_ARABIC: TranslationLiterals(language=Language.NAJDI_ARABIC), - Language.NEPALI: TranslationLiterals(language=Language.NEPALI), + Language.NEPALI: TranslationLiterals( + language=Language.NEPALI, + question_word="प्रश्न", + answer="उत्तर", + confirmation_word="होइन र", + yes="हो", + no="होइन", + also="पनि", + cause_word="किनभने", + effect_word="त्यसैले", + or_word="वा", + and_word="र", + true="सत्य", + false="असत्य", + neither="दुवै होइन", + full_stop="।", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + indices=["क", "ख", "ग", "घ", "ङ", "च"], + ), Language.NIGERIAN_FULFULDE: TranslationLiterals(language=Language.NIGERIAN_FULFULDE), Language.NORTHERN_KURDISH: TranslationLiterals(language=Language.NORTHERN_KURDISH), Language.NORTHERN_SOTHO: TranslationLiterals(language=Language.NORTHERN_SOTHO), @@ -853,17 +1329,109 @@ def __getattribute__(self, name: str) -> str: colon=":", semicolon=";", ), - Language.NORWEGIAN_BOKMAL: TranslationLiterals(language=Language.NORWEGIAN_BOKMAL), - Language.NORWEGIAN_NYNORSK: TranslationLiterals(language=Language.NORWEGIAN_NYNORSK), + Language.NORWEGIAN_BOKMAL: TranslationLiterals( + language=Language.NORWEGIAN_BOKMAL, + question_word="spørsmål", + answer="svar", + confirmation_word="ikke sant", + yes="ja", + no="nei", + also="også", + cause_word="fordi", + effect_word="derfor", + or_word="eller", + and_word="og", + true="sant", + false="usant", + neither="ingen av delene", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), + Language.NORWEGIAN_NYNORSK: TranslationLiterals( + language=Language.NORWEGIAN_NYNORSK, + question_word="spørsmål", + answer="svar", + confirmation_word="ikkje sant", + yes="ja", + no="nei", + also="òg", + cause_word="fordi", + effect_word="difor", + or_word="eller", + and_word="og", + true="sant", + false="usant", + neither="ingen av delane", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.NUER: TranslationLiterals(language=Language.NUER), Language.NYANJA: TranslationLiterals(language=Language.NYANJA), - Language.OCCITAN: TranslationLiterals(language=Language.OCCITAN), + Language.OCCITAN: TranslationLiterals( + language=Language.OCCITAN, + question_word="question", + answer="responsa", + confirmation_word="vertat", + yes="òc", + no="non", + also="tanben", + cause_word="perque", + effect_word="donc", + or_word="o", + and_word="e", + true="vertat", + false="fals", + neither="cap dels dos", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.ODIA: TranslationLiterals(language=Language.ODIA), Language.ORIYA: TranslationLiterals(language=Language.ORIYA), Language.PANGASINAN: TranslationLiterals(language=Language.PANGASINAN), Language.PAPIAMENTO: TranslationLiterals(language=Language.PAPIAMENTO), Language.PASHTO: TranslationLiterals(language=Language.PASHTO), - Language.PERSIAN: TranslationLiterals(language=Language.PERSIAN), + Language.PERSIAN: TranslationLiterals( + language=Language.PERSIAN, + question_word="پرسش", + answer="پاسخ", + confirmation_word="درست است", + yes="بله", + no="نه", + also="همچنین", + cause_word="زیرا", + effect_word="بنابراین", + or_word="یا", + and_word="و", + true="درست", + false="نادرست", + neither="هیچ‌کدام", + full_stop=".", + comma="،", + question_mark="؟", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon="؛", + ), Language.PLATEAU_MALAGASY: TranslationLiterals(language=Language.PLATEAU_MALAGASY), Language.POLISH: TranslationLiterals( language=Language.POLISH, @@ -1076,7 +1644,30 @@ def __getattribute__(self, name: str) -> str: colon=":", semicolon=";", ), - Language.SLOVENIAN: TranslationLiterals(language=Language.SLOVENIAN), + Language.SLOVENIAN: TranslationLiterals( + language=Language.SLOVENIAN, + question_word="vprašanje", + answer="odgovor", + confirmation_word="kajne", + yes="da", + no="ne", + also="tudi", + cause_word="ker", + effect_word="zato", + or_word="ali", + and_word="in", + true="resnično", + false="napačno", + neither="nobeno", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.SOMALI: TranslationLiterals(language=Language.SOMALI), Language.SORANI: TranslationLiterals(language=Language.SORANI), Language.SOUTHERN_PASHTO: TranslationLiterals(language=Language.SOUTHERN_PASHTO), @@ -1107,8 +1698,54 @@ def __getattribute__(self, name: str) -> str: colon=":", semicolon=";", ), - Language.STANDARD_LATVIAN: TranslationLiterals(language=Language.STANDARD_LATVIAN), - Language.STANDARD_MALAY: TranslationLiterals(language=Language.STANDARD_MALAY), + Language.STANDARD_LATVIAN: TranslationLiterals( + language=Language.STANDARD_LATVIAN, + question_word="jautājums", + answer="atbilde", + confirmation_word="vai ne", + yes="jā", + no="nē", + also="arī", + cause_word="jo", + effect_word="tāpēc", + or_word="vai", + and_word="un", + true="patiesi", + false="aplami", + neither="neviens", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), + Language.STANDARD_MALAY: TranslationLiterals( + language=Language.STANDARD_MALAY, + question_word="soalan", + answer="jawapan", + confirmation_word="bukan", + yes="ya", + no="tidak", + also="juga", + cause_word="kerana", + effect_word="oleh itu", + or_word="atau", + and_word="dan", + true="benar", + false="salah", + neither="kedua-duanya tidak", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.STANDARD_TIBETAN: TranslationLiterals(language=Language.STANDARD_TIBETAN), Language.SUNDANESE: TranslationLiterals(language=Language.SUNDANESE), Language.SWAHILI: TranslationLiterals( @@ -1143,8 +1780,9 @@ def __getattribute__(self, name: str) -> str: no="nej", also="också", cause_word="eftersom", - effect_word="därför att", + effect_word="därför", or_word="eller", + and_word="och", true="sant", false="falskt", neither="ingendera", @@ -1421,9 +2059,55 @@ def __getattribute__(self, name: str) -> str: ), Language.WAR: TranslationLiterals(language=Language.WAR), Language.WARAY: TranslationLiterals(language=Language.WARAY), - Language.WELSH: TranslationLiterals(language=Language.WELSH), + Language.WELSH: TranslationLiterals( + language=Language.WELSH, + question_word="cwestiwn", + answer="ateb", + confirmation_word="on'd yw e", + yes="ydy", + no="nac ydy", + also="hefyd", + cause_word="oherwydd", + effect_word="felly", + or_word="neu", + and_word="a", + true="gwir", + false="gau", + neither="dim un o'r ddau", + full_stop=".", + comma=",", + question_mark="?", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon=";", + ), Language.WESTERN_FRISIAN: TranslationLiterals(language=Language.WESTERN_FRISIAN), - Language.WESTERN_PERSIAN: TranslationLiterals(language=Language.WESTERN_PERSIAN), + Language.WESTERN_PERSIAN: TranslationLiterals( + language=Language.WESTERN_PERSIAN, + question_word="پرسش", + answer="پاسخ", + confirmation_word="درست است", + yes="بله", + no="نه", + also="همچنین", + cause_word="زیرا", + effect_word="بنابراین", + or_word="یا", + and_word="و", + true="درست", + false="نادرست", + neither="هیچ‌کدام", + full_stop=".", + comma="،", + question_mark="؟", + exclamation_mark="!", + word_space=" ", + sentence_space=" ", + colon=":", + semicolon="؛", + ), Language.WEST_CENTRAL_OROMO: TranslationLiterals(language=Language.WEST_CENTRAL_OROMO), Language.WOLOF: TranslationLiterals(language=Language.WOLOF), Language.XHOSA: TranslationLiterals(language=Language.XHOSA), @@ -1439,14 +2123,9 @@ def __getattribute__(self, name: str) -> str: for _language in [ Language.AMHARIC, Language.HAUSA, - Language.HEBREW, Language.IGBO, - Language.KIRGHIZ, Language.MALAGASY, - Language.MALAY, - Language.NEPALI, Language.NYANJA, - Language.PERSIAN, Language.SHONA, Language.SINHALA, Language.SOMALI, From f30649b07bde502439ad7a6447f1aef5be5fe262 Mon Sep 17 00:00:00 2001 From: rakkit <26144573+rakkit@users.noreply.github.com> Date: Tue, 19 May 2026 16:09:22 +0200 Subject: [PATCH 8/8] fix bpb and cf for swarm --- .../metrics/harness_compatibility/drop.py | 2 +- src/lighteval/metrics/metrics.py | 2 +- .../metrics/utils/linguistic_tokenizers.py | 4 +- .../models/endpoints/litellm_model.py | 19 +- .../tasks/multilingual/tasks/mgsm.py | 2 +- src/lighteval/tasks/tasks/arc.py | 12 +- src/lighteval/tasks/tasks/basic_skills.py | 135 ++++++++- src/lighteval/tasks/tasks/commonsenseqa.py | 8 +- src/lighteval/tasks/tasks/gsm8k.py | 2 +- src/lighteval/tasks/tasks/hellaswag.py | 12 +- src/lighteval/tasks/tasks/lambada.py | 64 +++- src/lighteval/tasks/tasks/med.py | 8 +- .../tasks/tasks/natural_questions.py | 5 +- src/lighteval/tasks/tasks/ruler.py | 279 +++++++++++++----- src/lighteval/tasks/tasks/squad.py | 6 +- src/lighteval/tasks/tasks/winogrande.py | 21 +- 16 files changed, 436 insertions(+), 145 deletions(-) diff --git a/src/lighteval/metrics/harness_compatibility/drop.py b/src/lighteval/metrics/harness_compatibility/drop.py index 93ff0028b..4a40f52a4 100644 --- a/src/lighteval/metrics/harness_compatibility/drop.py +++ b/src/lighteval/metrics/harness_compatibility/drop.py @@ -89,7 +89,7 @@ def _get_metrics(self, predicted: List[str], gold: List[str]): pred_normalized_spans, pred_bags = self._answer_to_bags(predicted) gold_normalized_spans, gold_bags = self._answer_to_bags(gold) - if set(pred_normalized_spans) == set(gold_normalized_spans) and len(gold_normalized_spans) == len( + if set(pred_normalized_spans) == set(gold_normalized_spans) and len(pred_normalized_spans) == len( gold_normalized_spans ): exact_match = 1.0 diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py index 41053de70..b707b4fdd 100644 --- a/src/lighteval/metrics/metrics.py +++ b/src/lighteval/metrics/metrics.py @@ -329,7 +329,7 @@ class Metrics(Enum): metric_name=["em", "f1"], sample_level_fn=DropMetrics(), category=SamplingMethod.GENERATIVE, - corpus_level_fn={"em": max, "f1": max}, + corpus_level_fn={"em": np.mean, "f1": np.mean}, higher_is_better={"em": True, "f1": True}, ) exact_match = SampleLevelMetric( diff --git a/src/lighteval/metrics/utils/linguistic_tokenizers.py b/src/lighteval/metrics/utils/linguistic_tokenizers.py index a5670a268..120596cec 100644 --- a/src/lighteval/metrics/utils/linguistic_tokenizers.py +++ b/src/lighteval/metrics/utils/linguistic_tokenizers.py @@ -175,10 +175,10 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]: # If you know a better tokenizer or better proxy language, please submit a PR TOKENIZER_FACTORY: dict[Language, Callable[[], WordTokenizer]] = { Language.ENGLISH: lambda: SpaCyTokenizer("en"), - Language.KOREAN: lambda: SpaCyTokenizer("ko"), + Language.KOREAN: lambda: StanzaTokenizer("ko"), Language.GERMAN: lambda: SpaCyTokenizer("de"), Language.FRENCH: lambda: SpaCyTokenizer("fr"), - Language.CZECH: lambda: SpaCyTokenizer("cz"), + Language.CZECH: lambda: SpaCyTokenizer("cs"), Language.DANISH: lambda: SpaCyTokenizer("da"), Language.DUTCH: lambda: SpaCyTokenizer("nl"), Language.ESTONIAN: lambda: SpaCyTokenizer("et"), diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 3cf03e31f..e7def5a55 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -67,13 +67,16 @@ def _env_flag(name: str, default: bool) -> bool: # SLURM jobs run concurrently. import tempfile as _tempfile - _cache_dir = os.environ.get("LITELLM_CACHE_DIR") - if _cache_dir: - _cache_dir = os.path.expandvars(os.path.expanduser(_cache_dir)) - os.makedirs(_cache_dir, exist_ok=True) + if os.environ.get("LITELLM_DISABLE_CACHE", "0").strip().lower() in ("1", "true", "yes"): + litellm.disable_cache() else: - _cache_dir = _tempfile.mkdtemp(prefix="litellm_cache_") - litellm.cache = Cache(type=LiteLLMCacheType.DISK, disk_cache_dir=_cache_dir) + _cache_dir = os.environ.get("LITELLM_CACHE_DIR") + if _cache_dir: + _cache_dir = os.path.expandvars(os.path.expanduser(_cache_dir)) + os.makedirs(_cache_dir, exist_ok=True) + else: + _cache_dir = _tempfile.mkdtemp(prefix="litellm_cache_") + litellm.cache = Cache(type=LiteLLMCacheType.DISK, disk_cache_dir=_cache_dir) else: from unittest.mock import Mock @@ -882,9 +885,11 @@ def _score_batch_work(batch_items: list[tuple]) -> list[tuple]: for doc, context, pairs in zip(doc_list, context_list, scored): logprobs = [lp for lp, _ in pairs] argmax = [g for _, g in pairs] + output_tokens = [[0] * _unique_choices[choice] for choice in doc.choices] results.append( ModelResponse( - input=context, logprobs=logprobs, argmax_logits_eq_gold=argmax + input=context, logprobs=logprobs, argmax_logits_eq_gold=argmax, + output_tokens=output_tokens, ) ) diff --git a/src/lighteval/tasks/multilingual/tasks/mgsm.py b/src/lighteval/tasks/multilingual/tasks/mgsm.py index 8dac452f1..cec3c6108 100644 --- a/src/lighteval/tasks/multilingual/tasks/mgsm.py +++ b/src/lighteval/tasks/multilingual/tasks/mgsm.py @@ -390,7 +390,7 @@ def prompt_fn(line, task_name: str = None): generation_size=1024, metrics=[ Metrics.expr_gold_metric, - MultilingualQuasiExactMatchMetric(language, "full"), + # MultilingualQuasiExactMatchMetric(language, "full"), ], stop_sequence=mgsm_stop_sequences(language), ) diff --git a/src/lighteval/tasks/tasks/arc.py b/src/lighteval/tasks/tasks/arc.py index 13bd1ae1a..4ec4981ba 100644 --- a/src/lighteval/tasks/tasks/arc.py +++ b/src/lighteval/tasks/tasks/arc.py @@ -72,7 +72,7 @@ def arc_mcf_prompt(line, task_name: str = None): hf_subset="ARC-Challenge", hf_avail_splits=["train", "validation", "test"], evaluation_splits=["test"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_CF_METRICS, @@ -87,7 +87,7 @@ def arc_mcf_prompt(line, task_name: str = None): hf_subset="ARC-Easy", hf_avail_splits=["train", "validation", "test"], evaluation_splits=["test"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_CF_METRICS, @@ -103,7 +103,7 @@ def arc_mcf_prompt(line, task_name: str = None): hf_subset="ARC-Challenge", hf_avail_splits=["train", "validation", "test"], evaluation_splits=["test"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_MCF_METRICS, @@ -118,7 +118,7 @@ def arc_mcf_prompt(line, task_name: str = None): hf_subset="ARC-Easy", hf_avail_splits=["train", "validation", "test"], evaluation_splits=["test"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_MCF_METRICS, @@ -134,7 +134,7 @@ def arc_mcf_prompt(line, task_name: str = None): hf_subset="ARC-Challenge", hf_avail_splits=["train", "validation", "test"], evaluation_splits=["test"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=1, metrics=[Metrics.exact_match], @@ -149,7 +149,7 @@ def arc_mcf_prompt(line, task_name: str = None): hf_subset="ARC-Easy", hf_avail_splits=["train", "validation", "test"], evaluation_splits=["test"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=1, metrics=[Metrics.exact_match], diff --git a/src/lighteval/tasks/tasks/basic_skills.py b/src/lighteval/tasks/tasks/basic_skills.py index 7dff5241f..c413d342a 100644 --- a/src/lighteval/tasks/tasks/basic_skills.py +++ b/src/lighteval/tasks/tasks/basic_skills.py @@ -19,13 +19,24 @@ paper: """ +import hashlib +import random +import threading +from string import ascii_uppercase + +from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric from lighteval.metrics.metrics import Metrics +from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc -# Assumed dataset fields: "question" (str), "answer" (str) -# Adjust if the actual allenai/basic-skills schema differs. +# Dataset fields: "question" (str), "answer" (str) — no built-in choices. +# Distractors for :cf and :mcf are built incrementally from the answers that +# flow through the prompt function (including few-shot examples, which are +# processed first). No separate dataset load is needed. +_BASIC_SKILLS_ACCUM: dict = {} # {subset: set of answers} +_ACCUM_LOCK = threading.Lock() # HF config names in allenai/basic-skills (only validation split available) _BASIC_SKILLS_SUBSETS = [ @@ -38,18 +49,21 @@ ] -def basic_skills_prompt(line, task_name: str = None): - """Greedy variant: generate answer text, compare with exact match.""" - return Doc( - task_name=task_name, - query=f"Question: {line['question']}\nAnswer:", - choices=[" " + line["answer"]], - gold_index=0, - ) +def _get_distractor_pool(subset: str, gold: str) -> list: + """Return available distractors (excluding *gold*), then accumulate *gold*. + + The pool is thread-safe and grows as examples pass through the prompt function. + Early examples in a subset may have fewer than 3 distractors until the pool + accumulates enough unique answers from prior few-shot or test examples. + """ + with _ACCUM_LOCK: + current = list(_BASIC_SKILLS_ACCUM.get(subset, set())) + _BASIC_SKILLS_ACCUM.setdefault(subset, set()).add(gold) + return [a for a in current if a != gold] def basic_skills_bpb_prompt(line, task_name: str = None): - """BPB variant: CF-style prompt with gold answer as single choice.""" + """BPB variant: single-choice, score gold continuation only (no acc ranking).""" gold_text = line["answer"] if not gold_text: return None @@ -63,13 +77,66 @@ def basic_skills_bpb_prompt(line, task_name: str = None): ) +def basic_skills_cf_prompt(line, task_name: str = None): + """CF variant: gold answer vs up to 3 distractors (RC_per-token acc + BPB). + + Subset is inferred from task_name (e.g. 'basic_skills:arithmetic:cf' → 'arithmetic'). + Distractors are drawn from a module-level pool that accumulates answers from + examples as they are processed — no separate dataset load required. + """ + if not line.get("answer") or not line["answer"].strip(): + return None + subset = task_name.split(":")[1] if task_name else _BASIC_SKILLS_SUBSETS[0] + gold = line["answer"] + seed = int(hashlib.md5((line["question"] + gold).encode()).hexdigest(), 16) % (2**32) + rng = random.Random(seed) + candidates = _get_distractor_pool(subset, gold) + distractors = rng.sample(candidates, min(3, len(candidates))) + all_choices = distractors + [gold] + rng.shuffle(all_choices) + gold_ix = all_choices.index(gold) + return Doc( + task_name=task_name, + query=f"Question: {line['question']}\nAnswer:", + choices=[" " + c for c in all_choices], + gold_index=gold_ix, + ) + + +def basic_skills_mcf_prompt(line, task_name: str = None): + """MCF variant: labeled A/B/C/D options with gold answer (MC Acc). + + Uses a different hash seed from :cf to produce an independent shuffle. + """ + if not line.get("answer") or not line["answer"].strip(): + return None + subset = task_name.split(":")[1] if task_name else _BASIC_SKILLS_SUBSETS[0] + gold = line["answer"] + seed = int(hashlib.md5((line["question"] + gold).encode()).hexdigest(), 16) % (2**32) + rng = random.Random(seed + 1) + candidates = _get_distractor_pool(subset, gold) + distractors = rng.sample(candidates, min(3, len(candidates))) + all_choices = distractors + [gold] + rng.shuffle(all_choices) + gold_ix = all_choices.index(gold) + query = f"Question: {line['question']}\n" + query += "".join(f" {key}. {choice}\n" for key, choice in zip(ascii_uppercase, all_choices)) + query += "Answer:" + return Doc( + task_name=task_name, + query=query, + choices=[" " + key for key in ascii_uppercase[:len(all_choices)]], + gold_index=gold_ix, + ) + + TASKS_TABLE = [] for _subset in _BASIC_SKILLS_SUBSETS: - # CF variant: RC per-token normalization + BPB on gold (merged) + # BPB variant: single-choice gold continuation, BPB only TASKS_TABLE.append( LightevalTaskConfig( - name=f"basic_skills:{_subset}:cf", + name=f"basic_skills:{_subset}:bpb", prompt_function=basic_skills_bpb_prompt, hf_repo="allenai/basic-skills", hf_subset=_subset, @@ -78,8 +145,48 @@ def basic_skills_bpb_prompt(line, task_name: str = None): few_shots_split="validation", few_shots_select="random_sampling", generation_size=-1, + metrics=[Metrics.target_bits_per_byte], + stop_sequence=["\n"], + version=0, + ) + ) + + # CF variant: multi-choice RC_per-token acc + BPB (gold vs up to 3 distractors) + TASKS_TABLE.append( + LightevalTaskConfig( + name=f"basic_skills:{_subset}:cf", + prompt_function=basic_skills_cf_prompt, + hf_repo="allenai/basic-skills", + hf_subset=_subset, + hf_avail_splits=["validation"], + evaluation_splits=["validation"], + few_shots_split="validation", + few_shots_select="random_sampling", + generation_size=-1, + metrics=[ + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbTokenNorm()), + ], + stop_sequence=["\n"], + version=0, + ) + ) + + # MCF variant: labeled A/B/C/D, MC acc (gold vs up to 3 distractors) + TASKS_TABLE.append( + LightevalTaskConfig( + name=f"basic_skills:{_subset}:mcf", + prompt_function=basic_skills_mcf_prompt, + hf_repo="allenai/basic-skills", + hf_subset=_subset, + hf_avail_splits=["validation"], + evaluation_splits=["validation"], + few_shots_split="validation", + few_shots_select="random_sampling", + generation_size=-1, metrics=[ - Metrics.target_bits_per_byte, + LogLikelihoodAccMetric(), + LogLikelihoodAccMetric(normalization=LogProbCharNorm()), ], stop_sequence=["\n"], version=0, diff --git a/src/lighteval/tasks/tasks/commonsenseqa.py b/src/lighteval/tasks/tasks/commonsenseqa.py index 0aa872082..03f6d1ff5 100644 --- a/src/lighteval/tasks/tasks/commonsenseqa.py +++ b/src/lighteval/tasks/tasks/commonsenseqa.py @@ -77,8 +77,8 @@ def commonsenseqa_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, - few_shots_select=None, + few_shots_split="train", + few_shots_select="random_sampling_from_train", generation_size=1, metrics=[Metrics.exact_match], stop_sequence=["\n"], @@ -93,7 +93,7 @@ def commonsenseqa_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_MCF_METRICS, @@ -109,7 +109,7 @@ def commonsenseqa_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_CF_METRICS, diff --git a/src/lighteval/tasks/tasks/gsm8k.py b/src/lighteval/tasks/tasks/gsm8k.py index 83e5390df..b0c124154 100644 --- a/src/lighteval/tasks/tasks/gsm8k.py +++ b/src/lighteval/tasks/tasks/gsm8k.py @@ -86,7 +86,7 @@ def gsm8k_prompt(line, task_name: str = None): hf_subset="main", hf_avail_splits=["train", "test"], evaluation_splits=["test"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=512, metrics=[ diff --git a/src/lighteval/tasks/tasks/hellaswag.py b/src/lighteval/tasks/tasks/hellaswag.py index 940ef8d61..be560e73c 100644 --- a/src/lighteval/tasks/tasks/hellaswag.py +++ b/src/lighteval/tasks/tasks/hellaswag.py @@ -89,8 +89,8 @@ def hellaswag_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split="test", - few_shots_select=None, + few_shots_split="train", + few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_MCF_METRICS, stop_sequence=["\n"], @@ -105,8 +105,8 @@ def hellaswag_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split="test", - few_shots_select=None, + few_shots_split="train", + few_shots_select="random_sampling_from_train", generation_size=1, metrics=[Metrics.exact_match], stop_sequence=["\n"], @@ -120,8 +120,8 @@ def hellaswag_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split="test", - few_shots_select=None, + few_shots_split="train", + few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_CF_METRICS, stop_sequence=["\n"], diff --git a/src/lighteval/tasks/tasks/lambada.py b/src/lighteval/tasks/tasks/lambada.py index b4e7aa2f3..778b87202 100644 --- a/src/lighteval/tasks/tasks/lambada.py +++ b/src/lighteval/tasks/tasks/lambada.py @@ -21,12 +21,29 @@ https://arxiv.org/abs/1606.06031 """ +import hashlib +import random + from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric from lighteval.metrics.metrics import Metrics from lighteval.metrics.normalizations import LogProbCharNorm from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc +# Distractor pool for lambada:cf — loaded lazily on first prompt call. +# Contains all last words from the test split; used to sample 3 distractors +# per example via a deterministic hash of the passage text. +_LAMBADA_LAST_WORDS = None + + +def _get_lambada_last_words(): + global _LAMBADA_LAST_WORDS + if _LAMBADA_LAST_WORDS is None: + import datasets as _hf_datasets + ds = _hf_datasets.load_dataset("cimec/lambada", "plain_text", split="test") + _LAMBADA_LAST_WORDS = [text.rsplit(" ", 1)[1] for text in ds["text"]] + return _LAMBADA_LAST_WORDS + def lambada_prompt(line, task_name: str = None): """Standard LAMBADA prompt: context as query, last word as gold continuation.""" @@ -39,6 +56,29 @@ def lambada_prompt(line, task_name: str = None): ) +def lambada_multichoice_prompt(line, task_name: str = None): + """RC_per-char prompt: gold last word vs 3 distractors sampled from the test set. + + Distractors are chosen deterministically by hashing the passage text, so + results are reproducible without storing a precomputed mapping file. + """ + query, gold = line["text"].rsplit(" ", 1) + last_words = _get_lambada_last_words() + seed = int(hashlib.md5(line["text"].encode()).hexdigest(), 16) % (2**32) + rng = random.Random(seed) + pool = [w for w in last_words if w != gold] + distractors = rng.sample(pool, 3) + all_choices = distractors + [gold] + rng.shuffle(all_choices) + gold_ix = all_choices.index(gold) + return Doc( + task_name=task_name, + query=query, + choices=[" " + c for c in all_choices], + gold_index=gold_ix, + ) + + def lambada_cloze_prompt(line, task_name: str = None): """Cloze-style LAMBADA prompt with fill-in-the-blank indicator.""" query, choice = line["text"].rsplit(" ", 1) @@ -50,10 +90,26 @@ def lambada_cloze_prompt(line, task_name: str = None): ) -# CF variant: RC per-char (rank choice with per-character normalization) + BPB on gold +# BPB variant: single-choice, score gold continuation only (no acc ranking) +lambada_bpb = LightevalTaskConfig( + name="lambada:bpb", + prompt_function=lambada_prompt, + hf_repo="cimec/lambada", + hf_subset="plain_text", + hf_avail_splits=["train", "test", "validation"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=-1, + metrics=[Metrics.target_bits_per_byte], + stop_sequence=["\n"], + version=0, +) + +# CF variant: multi-choice RC per-char acc (gold last word vs 3 sampled distractors) + BPB lambada_cf = LightevalTaskConfig( name="lambada:cf", - prompt_function=lambada_prompt, + prompt_function=lambada_multichoice_prompt, hf_repo="cimec/lambada", hf_subset="plain_text", hf_avail_splits=["train", "test", "validation"], @@ -62,8 +118,7 @@ def lambada_cloze_prompt(line, task_name: str = None): few_shots_select=None, generation_size=-1, metrics=[ - LogLikelihoodAccMetric(normalization=LogProbCharNorm()), - Metrics.target_bits_per_byte, + LogLikelihoodAccMetric(normalization=LogProbCharNorm()) ], stop_sequence=["\n"], version=0, @@ -103,6 +158,7 @@ def lambada_cloze_prompt(line, task_name: str = None): ) TASKS_TABLE = [ + lambada_bpb, lambada_cf, lambada_standard_cloze, lambada_openai_cloze, diff --git a/src/lighteval/tasks/tasks/med.py b/src/lighteval/tasks/tasks/med.py index 2493fd3bc..fa4f2f19f 100644 --- a/src/lighteval/tasks/tasks/med.py +++ b/src/lighteval/tasks/tasks/med.py @@ -115,8 +115,8 @@ def med_mcqa_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, - few_shots_select=None, + few_shots_split="train", + few_shots_select="random_sampling_from_train", generation_size=1, metrics=[Metrics.exact_match], stop_sequence=["\n"], @@ -131,7 +131,7 @@ def med_mcqa_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_MCF_METRICS, @@ -179,7 +179,7 @@ def med_mcqa_cf_prompt(line, task_name: str = None): hf_subset="default", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, + few_shots_split="train", few_shots_select="random_sampling_from_train", generation_size=-1, metrics=_CF_METRICS, diff --git a/src/lighteval/tasks/tasks/natural_questions.py b/src/lighteval/tasks/tasks/natural_questions.py index 6ded9c40d..5a39fa8ff 100644 --- a/src/lighteval/tasks/tasks/natural_questions.py +++ b/src/lighteval/tasks/tasks/natural_questions.py @@ -30,11 +30,12 @@ def nq_gen_prompt(line, task_name: str = None): if not answers: return None is_few_shots = line.get("__few_shots", False) + prefix = " " if is_few_shots else "" return Doc( task_name=task_name, query=f"Question: {line['question']}\nAnswer:", - choices=[f"{' ' if is_few_shots else ''}{answers[0]}"], - gold_index=0, + choices=[f"{prefix}{ans}" for ans in answers], + gold_index=list(range(len(answers))), ) diff --git a/src/lighteval/tasks/tasks/ruler.py b/src/lighteval/tasks/tasks/ruler.py index 7b6648861..caa3a953d 100644 --- a/src/lighteval/tasks/tasks/ruler.py +++ b/src/lighteval/tasks/tasks/ruler.py @@ -88,12 +88,24 @@ "qa_2", ] -DEFAULT_LENGTHS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] +DEFAULT_LENGTHS = [2048, 4096, 8192, 16384, 32768, 65536, 131072] # NUM_SAMPLES = 500 NUM_SAMPLES = 100 RANDOM_SEED = 42 +MIN_RULER_LENGTH = 2048 +# Reserve 1 token so vllm/inference-time special tokens (e.g. BOS) don't push +# prompt + gen_size over the model's context window. +PROMPT_SAFETY_MARGIN = 1 + + +def _validate_lengths(lengths: list[int]) -> list[int]: + bad = [x for x in lengths if x < MIN_RULER_LENGTH] + if bad: + raise ValueError(f"RULER lengths must be >= {MIN_RULER_LENGTH}; got {bad}") + return lengths + # --------------------------------------------------------------------------- # Tokenizer helpers # --------------------------------------------------------------------------- @@ -135,10 +147,26 @@ def _get_model_arch(tokenizer_path: str) -> str: return f"{safe}_{h}" +def _get_tokenizer_hash(tokenizer_path: str) -> str: + candidates = [ + os.path.join(tokenizer_path, "tokenizer.json"), + os.path.join(tokenizer_path, "tokenizer.model"), + os.path.join(tokenizer_path, "vocab.json"), + os.path.join(tokenizer_path, "tokenizer_config.json"), + ] + hasher = hashlib.md5() + for path in candidates: + if os.path.isfile(path): + hasher.update(path.encode()) + hasher.update(str(os.path.getmtime(path)).encode()) + return hasher.hexdigest()[:8] + + def _get_cache_dir(tokenizer_path: str) -> Path: hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface")) arch = _get_model_arch(tokenizer_path) - return Path(hf_home) / "lighteval" / "ruler" / arch + tok_hash = _get_tokenizer_hash(tokenizer_path) + return Path(hf_home) / "lighteval" / "ruler" / f"{arch}_{tok_hash}" # --------------------------------------------------------------------------- @@ -334,13 +362,11 @@ def _niah_generate_samples( num_needle_q: int = 1, random_seed: int = RANDOM_SEED, ) -> list[dict]: - budget = max_seq_length + budget = max_seq_length - PROMPT_SAFETY_MARGIN num_needle_k = max(num_needle_k, num_needle_q) if type_haystack == "essay": incremental = 500 - elif type_haystack in ("repeat", "needle"): - incremental = 25 if max_seq_length >= 4096 else 5 else: incremental = 25 @@ -383,9 +409,16 @@ def _gen_prefix(tnv, nq, nv, query): num_haystack = max(incremental, num_haystack) write_jsons = [] - for index in tqdm( - range(num_samples), desc=f"Generating NIAH samples | {max_seq_length}" - ): + index = 0 + pbar = tqdm(total=num_samples, desc=f"Generating NIAH samples | {max_seq_length}") + MAX_ATTEMPTS = num_samples * 10 + while len(write_jsons) < num_samples: + if index >= MAX_ATTEMPTS: + logger.warning( + f"Reached max attempts ({MAX_ATTEMPTS}) for NIAH at length {max_seq_length}, " + f"got {len(write_jsons)}/{num_samples} samples" + ) + break # Per-sample seeding for reproducibility and diversity sample_seed = random_seed + index random.seed(sample_seed) @@ -394,7 +427,7 @@ def _gen_prefix(tnv, nq, nv, query): input_text = answer = query = gen_prefix = length = None while True: try: - input_text, answer, query = _niah_generate_input_output( + cand_text, cand_answer, cand_query = _niah_generate_input_output( used_haystack, haystack, type_haystack=type_haystack, @@ -406,10 +439,15 @@ def _gen_prefix(tnv, nq, nv, query): num_needle_q=num_needle_q, random_seed=sample_seed, ) - gen_prefix = _gen_prefix(tnv_base, num_needle_q, num_needle_v, query) - prompt = _build_runtime_prompt(input_text, gen_prefix) - length = len(tokenizer(prompt).input_ids) + tokens_to_generate - assert length <= budget + cand_prefix = _gen_prefix(tnv_base, num_needle_q, num_needle_v, cand_query) + cand_prompt = _build_runtime_prompt(cand_text, cand_prefix) + cand_length = len(tokenizer(cand_prompt).input_ids) + tokens_to_generate + if cand_length > budget: + raise RuntimeError(f"length {cand_length} exceeds budget {budget}") + # Commit only after validation passes + input_text, answer, query = cand_text, cand_answer, cand_query + gen_prefix = cand_prefix + length = cand_length break except Exception: if used_haystack > incremental: @@ -418,6 +456,7 @@ def _gen_prefix(tnv, nq, nv, query): break if answer is None: + index += 1 continue write_jsons.append( @@ -430,6 +469,9 @@ def _gen_prefix(tnv, nq, nv, query): "gen_prefix": gen_prefix, } ) + pbar.update(1) + index += 1 + pbar.close() return write_jsons @@ -448,12 +490,12 @@ def _gen_prefix(tnv, nq, nv, query): "Question: Find all variables that are assigned the value {query} in the text above." ), "answer_prefix": ( - " Answer: According to the chain(s) of variable assignment in the text above, " - "{num_v} variables are assigned the value {query}, they are: " + "Answer: According to the chain(s) of variable assignment in the text above, " + "{num_v} variables are assigned the value {query}, they are:" ), } VT_TEMPLATE = VT_CONFIG["template"] + VT_CONFIG["answer_prefix"] -VT_ANSWER_PREFIX_BASE = " Answer: According to the chain(s) of variable assignment" +VT_ANSWER_PREFIX_BASE = "Answer: According to the chain(s) of variable assignment" def _vt_generate_chains( @@ -564,7 +606,7 @@ def _vt_generate_samples( num_hops: int = 4, tokens_to_generate: int = VT_CONFIG["tokens_to_generate"], ) -> list[dict]: - budget = max_seq_length + budget = max_seq_length - PROMPT_SAFETY_MARGIN # --- Step 1: Generate ICL example within a 500-token budget (matches reference) --- # Reference: get_dataset() calls sys_vartrack_w_noise_random(max_seq_length=500, incremental=5) @@ -608,26 +650,38 @@ def _vt_generate_samples( num_noises = max(incremental, num_noises) write_jsons = [] - for index in tqdm( - range(num_samples), desc=f"Generating VT samples | {max_seq_length}" - ): + index = 0 + pbar = tqdm(total=num_samples, desc=f"Generating VT samples | {max_seq_length}") + MAX_ATTEMPTS = num_samples * 10 + while len(write_jsons) < num_samples: + if index >= MAX_ATTEMPTS: + logger.warning( + f"Reached max attempts ({MAX_ATTEMPTS}) for VT at length {max_seq_length}, " + f"got {len(write_jsons)}/{num_samples} samples" + ) + break # Per-sample seeding for reproducibility and diversity random.seed(RANDOM_SEED + index) np.random.seed(RANDOM_SEED + index) used_noises = num_noises - input_text = answer = length = None + input_text = answer = cached_input = cached_gen_prefix = length = None while True: try: - input_text, answer = _vt_generate_input_output( + cand_text, cand_answer = _vt_generate_input_output( used_noises, num_chains, num_hops ) - cached_input, cached_gen_prefix = _vt_build_cached_sample( - input_text, _vt_randomize_icl(icl_str) + cand_cached_input, cand_cached_gen_prefix = _vt_build_cached_sample( + cand_text, _vt_randomize_icl(icl_str) ) - length = _runtime_prompt_budget_length( - tokenizer, cached_input, cached_gen_prefix, tokens_to_generate + cand_length = _runtime_prompt_budget_length( + tokenizer, cand_cached_input, cand_cached_gen_prefix, tokens_to_generate ) - assert length <= budget + if cand_length > budget: + raise RuntimeError(f"length {cand_length} exceeds budget {budget}") + # Commit only after validation passes + input_text, answer = cand_text, cand_answer + cached_input, cached_gen_prefix = cand_cached_input, cand_cached_gen_prefix + length = cand_length break except Exception: if used_noises > incremental: @@ -636,6 +690,7 @@ def _vt_generate_samples( break if answer is None: + index += 1 continue write_jsons.append( @@ -648,6 +703,9 @@ def _vt_generate_samples( "gen_prefix": cached_gen_prefix, } ) + pbar.update(1) + index += 1 + pbar.close() return write_jsons @@ -665,7 +723,7 @@ def _vt_generate_samples( "Memorize the ones that appear most often.\n{context}\n" "Question: What are the 10 most common words in the above list?" ), - "answer_prefix": " Answer: The top 10 words that appear most often in the list are:", + "answer_prefix": "Answer: The top 10 words that appear most often in the list are:", } CWE_TEMPLATE = CWE_CONFIG["template"] + CWE_CONFIG["answer_prefix"] _CWE_RNG = random.Random(RANDOM_SEED) @@ -708,12 +766,8 @@ def _cwe_get_example( def _cwe_generate_input_output(num_words: int, max_seq_length: int, words: list[str]): - if max_seq_length < 4096: - ctx_ex, ans_ex = _cwe_get_example(20, words, 3, 1, 10) - context, answer = _cwe_get_example(num_words, words, 6, 1, 10) - else: - ctx_ex, ans_ex = _cwe_get_example(40, words, 10, 3, 10) - context, answer = _cwe_get_example(num_words, words, 30, 3, 10) + ctx_ex, ans_ex = _cwe_get_example(40, words, 10, 3, 10) + context, answer = _cwe_get_example(num_words, words, 30, 3, 10) input_example = CWE_TEMPLATE.format(context=ctx_ex, query="") + " ".join( [f"{i + 1}. {w}" for i, w in enumerate(ans_ex)] @@ -730,7 +784,7 @@ def _cwe_generate_samples( tokens_to_generate: int = CWE_CONFIG["tokens_to_generate"], ) -> list[dict]: words = _get_cwe_words() - budget = max_seq_length + budget = max_seq_length - PROMPT_SAFETY_MARGIN num_words = incremental total_tokens = 0 @@ -753,27 +807,40 @@ def _cwe_generate_samples( num_words = max(incremental, num_words) write_jsons = [] - for index in tqdm( - range(num_samples), desc=f"Generating CWE samples | {max_seq_length}" - ): + write_jsons = [] + index = 0 + pbar = tqdm(total=num_samples, desc=f"Generating CWE samples | {max_seq_length}") + MAX_ATTEMPTS = num_samples * 10 + while len(write_jsons) < num_samples: + if index >= MAX_ATTEMPTS: + logger.warning( + f"Reached max attempts ({MAX_ATTEMPTS}) for CWE at length {max_seq_length}, " + f"got {len(write_jsons)}/{num_samples} samples" + ) + break # Per-sample seeding for reproducibility and diversity random.seed(RANDOM_SEED + index) np.random.seed(RANDOM_SEED + index) _CWE_RNG.seed(RANDOM_SEED + index) used_words = num_words - input_example = input_text = answer = length = None + input_text = answer = full_input = gen_prefix = length = None while True: try: - input_example, input_text, answer = _cwe_generate_input_output( + cand_example, cand_text, cand_answer = _cwe_generate_input_output( used_words, max_seq_length, words ) - full_input, gen_prefix = _cwe_build_cached_sample( - input_example, input_text + cand_full_input, cand_gen_prefix = _cwe_build_cached_sample( + cand_example, cand_text ) - length = _runtime_prompt_budget_length( - tokenizer, full_input, gen_prefix, tokens_to_generate + cand_length = _runtime_prompt_budget_length( + tokenizer, cand_full_input, cand_gen_prefix, tokens_to_generate ) - assert length <= budget + if cand_length > budget: + raise RuntimeError(f"length {cand_length} exceeds budget {budget}") + # Commit only after validation passes + input_text, answer = cand_text, cand_answer + full_input, gen_prefix = cand_full_input, cand_gen_prefix + length = cand_length break except Exception: if used_words > incremental: @@ -782,6 +849,7 @@ def _cwe_generate_samples( break if answer is None: + index += 1 continue write_jsons.append( @@ -794,6 +862,9 @@ def _cwe_generate_samples( "gen_prefix": gen_prefix, } ) + pbar.update(1) + index += 1 + pbar.close() return write_jsons @@ -812,7 +883,7 @@ def _cwe_generate_samples( "Question: Do not provide any explanation. Please ignore the dots '....'. " "What are the three most frequently appeared words in the above coded text?" ), - "answer_prefix": " Answer: According to the coded text above, the three most frequently appeared words are:", + "answer_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:", } FWE_TEMPLATE = FWE_CONFIG["template"] + FWE_CONFIG["answer_prefix"] FWE_SEED = RANDOM_SEED @@ -876,42 +947,67 @@ def _fwe_generate_samples( alpha: float = 2.0, tokens_to_generate: int = 50, ) -> list[dict]: - budget = max_seq_length + budget = max_seq_length - PROMPT_SAFETY_MARGIN input_max_len = budget - tokens_to_generate vs = max_seq_length // 50 if vocab_size == -1 else vocab_size + fwe_incremental = input_max_len // 32 _, _, num_example_words = _fwe_generate_input_output( input_max_len, tokenizer, coded_wordlen=coded_wordlen, vocab_size=vs, - incremental=input_max_len // 32, + incremental=fwe_incremental, alpha=alpha, ) write_jsons = [] - for index in tqdm( - range(num_samples), desc=f"Generating FWE samples | {max_seq_length}" - ): + index = 0 + pbar = tqdm(total=num_samples, desc=f"Generating FWE samples | {max_seq_length}") + MAX_ATTEMPTS = num_samples * 10 + while len(write_jsons) < num_samples: + if index >= MAX_ATTEMPTS: + logger.warning( + f"Reached max attempts ({MAX_ATTEMPTS}) for FWE at length {max_seq_length}, " + f"got {len(write_jsons)}/{num_samples} samples" + ) + break # Per-sample seeding for reproducibility and diversity sample_seed = RANDOM_SEED + index random.seed(sample_seed) np.random.seed(sample_seed) - input_text, answer, _ = _fwe_generate_input_output( - input_max_len, - tokenizer, - num_words=num_example_words, - coded_wordlen=coded_wordlen, - vocab_size=vs, - incremental=input_max_len // 32, - alpha=alpha, - sample_seed=sample_seed, - ) - input_text_clean, gen_prefix = _fwe_build_cached_sample(input_text) - length = _runtime_prompt_budget_length( - tokenizer, input_text_clean, gen_prefix, tokens_to_generate - ) - assert length <= budget + used_words = num_example_words + input_text_clean = answer = gen_prefix = length = None + while True: + try: + cand_text, cand_answer, _ = _fwe_generate_input_output( + input_max_len, + tokenizer, + num_words=used_words, + coded_wordlen=coded_wordlen, + vocab_size=vs, + incremental=fwe_incremental, + alpha=alpha, + sample_seed=sample_seed, + ) + cand_clean, cand_prefix = _fwe_build_cached_sample(cand_text) + cand_length = _runtime_prompt_budget_length( + tokenizer, cand_clean, cand_prefix, tokens_to_generate + ) + if cand_length > budget: + raise RuntimeError(f"length {cand_length} exceeds budget {budget}") + input_text_clean, answer = cand_clean, cand_answer + gen_prefix, length = cand_prefix, cand_length + break + except Exception: + if used_words > fwe_incremental: + used_words -= fwe_incremental + else: + break + + if answer is None: + index += 1 + continue write_jsons.append( { @@ -923,6 +1019,9 @@ def _fwe_generate_samples( "gen_prefix": gen_prefix, } ) + pbar.update(1) + index += 1 + pbar.close() return write_jsons @@ -1066,7 +1165,7 @@ def _qa_generate_samples( tokens_to_generate: int = QA_CONFIG["tokens_to_generate"], incremental: int = 10, ) -> list[dict]: - budget = max_seq_length + budget = max_seq_length - PROMPT_SAFETY_MARGIN gen_prefix = QA_CONFIG["answer_prefix"] num_docs = incremental @@ -1086,9 +1185,16 @@ def _qa_generate_samples( num_docs = max(incremental, num_docs) write_jsons = [] - for index in tqdm( - range(num_samples), desc=f"Generating QA samples | {max_seq_length}" - ): + index = 0 + pbar = tqdm(total=num_samples, desc=f"Generating QA samples | {max_seq_length}") + MAX_ATTEMPTS = num_samples * 10 + while len(write_jsons) < num_samples: + if index >= MAX_ATTEMPTS: + logger.warning( + f"Reached max attempts ({MAX_ATTEMPTS}) for QA at length {max_seq_length}, " + f"got {len(write_jsons)}/{num_samples} samples" + ) + break # Per-sample seeding for reproducibility and diversity random.seed(RANDOM_SEED + index) np.random.seed(RANDOM_SEED + index) @@ -1096,12 +1202,16 @@ def _qa_generate_samples( input_text = answer = length = None while True: try: - input_text, answer = _qa_generate_input_output( + cand_text, cand_answer = _qa_generate_input_output( index, used_docs, qas=qas, docs=docs ) - prompt = _build_runtime_prompt(input_text, gen_prefix) - length = len(tokenizer(prompt).input_ids) + tokens_to_generate - assert length <= budget + cand_prompt = _build_runtime_prompt(cand_text, gen_prefix) + cand_length = len(tokenizer(cand_prompt).input_ids) + tokens_to_generate + if cand_length > budget: + raise RuntimeError(f"length {cand_length} exceeds budget {budget}") + # Commit only after validation passes + input_text, answer = cand_text, cand_answer + length = cand_length break except Exception: if used_docs > incremental: @@ -1110,6 +1220,7 @@ def _qa_generate_samples( break if answer is None: + index += 1 continue write_jsons.append( @@ -1122,6 +1233,9 @@ def _qa_generate_samples( "gen_prefix": gen_prefix, } ) + pbar.update(1) + index += 1 + pbar.close() return write_jsons @@ -1223,7 +1337,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]: type_needle_v="numbers", num_needle_v=4, template=NIAH_TEMPLATE_MULTI, - tokens_to_generate=300 if length == 4096 else 250, + tokens_to_generate=300 if length <= 4096 else 250, ) elif subset == "vt": return _vt_generate_samples(tokenizer, max_seq_length=length) @@ -1254,6 +1368,7 @@ def ensure_ruler_cache( tokenizer_name: str, lengths: list[int] | None = None, subsets: list[str] | None = None, + allow_partial: bool = False, ) -> Path: """Generate and cache RULER data for the given tokenizer. @@ -1263,7 +1378,7 @@ def ensure_ruler_cache( Returns the cache base directory. """ - lengths = lengths or DEFAULT_LENGTHS + lengths = _validate_lengths(lengths or DEFAULT_LENGTHS) subsets = subsets or SUBSETS cache_base = _get_cache_dir(tokenizer_name) cache_base.mkdir(parents=True, exist_ok=True) @@ -1282,6 +1397,7 @@ def ensure_ruler_cache( ) tokenizer = _load_tokenizer(tokenizer_name) + failed = [] for length, subset in missing: logger.info(f"Generating RULER data: length={length} subset={subset} ...") try: @@ -1295,8 +1411,14 @@ def ensure_ruler_cache( ) logger.info(f"Saved {len(samples)} samples to {subset_dir}") except Exception as e: + failed.append((length, subset, str(e))) logger.warning(f"Skipping subset {subset} at length {length}: {e}") + if failed and not allow_partial: + raise RuntimeError( + f"RULER cache generation failed for {len(failed)} subset(s): {failed}" + ) + return cache_base @@ -1344,7 +1466,7 @@ def get_ruler_tasks( Data is generated on demand in download_dataset_worker (keyed off TOKENIZER_PATH) so only the lengths actually evaluated are generated. """ - lengths = lengths or DEFAULT_LENGTHS + lengths = _validate_lengths(lengths or DEFAULT_LENGTHS) subsets = subsets or SUBSETS cache_base = _get_cache_dir(tokenizer_name) @@ -1359,7 +1481,6 @@ def get_ruler_tasks( _niah_gen_sizes = { "niah_multikey_3": 100, "niah_multiquery": 100, - "niah_multivalue": 300 if length == 4096 else 250, } gen_size = ( _niah_gen_sizes.get(subset, 50) @@ -1368,6 +1489,8 @@ def get_ruler_tasks( 50 if subset == "vt" else 100 if subset == "cwe" else 50 ) # fwe and qa ) + if subset == "niah_multivalue": + gen_size = 300 if length <= 4096 else 250 tasks.append( LightevalTaskConfig( name=f"ruler_{length}:{subset}", diff --git a/src/lighteval/tasks/tasks/squad.py b/src/lighteval/tasks/tasks/squad.py index 67c6a9895..6baf16d6b 100644 --- a/src/lighteval/tasks/tasks/squad.py +++ b/src/lighteval/tasks/tasks/squad.py @@ -31,12 +31,12 @@ def squad_gen_prompt(line, task_name: str = None): if not answers_text: return None is_few_shots = line.get("__few_shots", False) - gold_text = f"{' ' if is_few_shots else ''}{answers_text[0]}" + prefix = " " if is_few_shots else "" return Doc( task_name=task_name, query=f"Title: {line['title']}\n\nBackground: {line['context']}\n\nQuestion: {line['question']}\n\nAnswer:", - choices=[gold_text], - gold_index=0, + choices=[f"{prefix}{ans}" for ans in answers_text], + gold_index=list(range(len(answers_text))), ) diff --git a/src/lighteval/tasks/tasks/winogrande.py b/src/lighteval/tasks/tasks/winogrande.py index be13d4210..72c2099c2 100644 --- a/src/lighteval/tasks/tasks/winogrande.py +++ b/src/lighteval/tasks/tasks/winogrande.py @@ -42,7 +42,7 @@ def winogrande_cf_prompt(line, task_name: str = None): BPB = logprob(suffix | prefix + gold_option) / bytes(suffix). Accuracy is not meaningful in this single-choice format (always 1.0), so - only BPB is reported. Use winogrande:mcf for accuracy scoring. + only BPB is reported. Use winogrande:cf for accuracy scoring. """ sentence = line["sentence"] blank_pos = sentence.index("_") @@ -100,17 +100,17 @@ def winogrande_rc_prompt(line, task_name: str = None): ) -# CF variant: partial evaluation (OLMO-style), BPB only +# BPB variant: partial evaluation (OLMO-style), BPB only # Accuracy is not reported because partial eval uses one fixed gold choice. -winogrande_cf = LightevalTaskConfig( - name="winogrande:cf", +winogrande_bpb = LightevalTaskConfig( + name="winogrande:bpb", prompt_function=winogrande_cf_prompt, hf_repo="allenai/winogrande", hf_subset="winogrande_xl", hf_avail_splits=["train", "test", "validation"], evaluation_splits=["validation"], - few_shots_split=None, - few_shots_select="random_sampling", + few_shots_split="train", + few_shots_select="random_sampling_from_train", generation_size=-1, metrics=[Metrics.target_bits_per_byte], stop_sequence=["\n"], @@ -149,11 +149,11 @@ def winogrande_rc_prompt(line, task_name: str = None): version=0, ) -# RC variant: standard cloze-form for accuracy (approximates Table 46 RC_none). +# CF variant: standard cloze-form for accuracy (approximates Table 46 RC_none). # Scores P(option+suffix | prefix) rather than OLMO's P(suffix | prefix+option), # but gives meaningful accuracy using lighteval's standard Doc framework. -winogrande_rc = LightevalTaskConfig( - name="winogrande:rc", +winogrande_cf = LightevalTaskConfig( + name="winogrande:cf", prompt_function=winogrande_rc_prompt, hf_repo="allenai/winogrande", hf_subset="winogrande_xl", @@ -164,15 +164,14 @@ def winogrande_rc_prompt(line, task_name: str = None): generation_size=-1, metrics=[ LogLikelihoodAccMetric(), - LogLikelihoodAccMetric(normalization=LogProbCharNorm()), ], stop_sequence=["\n"], version=0, ) TASKS_TABLE = [ + winogrande_bpb, winogrande_cf, winogrande_mcf, winogrande_mcf_em, - winogrande_rc, ]