From 9cf5ef81f6a31c80cda9f03cb71670336f8c82ea Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Sat, 28 Mar 2026 16:00:49 +0100
Subject: [PATCH 1/8] Add multilingual task suite aligned with unified
:cf/:mcf/:gen conventions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Refactor existing multilingual tasks and add new ones to match the English
task conventions (BPB merged into CF, explicit few-shot counts, English
explicitly excluded from all multilingual task registrations).
Refactored:
- mlmm_arc_challenge: cf+mcf variants, 26 langs, hf_revision pinned
- global_mmlu: cf+mcf variants, 33 langs (English removed), both formulations
- mlmm_hellaswag: cf only, 33 langs, few_shots_split=train added
- mgsm: :gen suffix, generation_size=512, both expr_gold + multilingual_quasi_em
New tasks:
- global_mmlu_lite: CohereLabs/Global-MMLU-Lite, 17 langs, cf+mcf
- mmlu_prox (multilingual): li-lab/MMLU-ProX, 28 langs, 10-option cf+mcf
- mmlu_prox (English): li-lab/MMLU-ProX English subset in tasks/tasks/
- wmt24pp: google/wmt24pp, 24 lang pairs × 2 directions, 0-shot gen
Tooling:
- scripts/multilingual_aggregate.py: cross-language average post-processor
- TASK_NAMING_Multilingual.md: naming conventions and language inventory doc
add comet22 metrics
---
TASK_NAMING_Multilingual.md | 487 ++++++++++++++++++
pyproject.toml | 1 +
scripts/multilingual_aggregate.py | 301 +++++++++++
src/lighteval/logging/info_loggers.py | 5 +-
src/lighteval/metrics/dynamic_metrics.py | 121 ++++-
src/lighteval/metrics/metrics.py | 9 +
src/lighteval/metrics/metrics_corpus.py | 53 +-
src/lighteval/metrics/metrics_sample.py | 10 +-
src/lighteval/metrics/sample_preparator.py | 20 +
src/lighteval/metrics/utils/nltk_resources.py | 28 +
.../models/transformers/transformers_model.py | 7 +-
src/lighteval/tasks/lighteval_task.py | 311 ++++++++++-
.../tasks/multilingual/tasks/flores200.py | 80 +--
.../tasks/multilingual/tasks/global_mmlu.py | 275 +++++-----
.../multilingual/tasks/global_mmlu_lite.py | 131 +++++
.../tasks/multilingual/tasks/mgsm.py | 57 +-
.../multilingual/tasks/mlmm_arc_challenge.py | 169 +++---
.../multilingual/tasks/mlmm_hellaswag.py | 180 ++++---
.../tasks/multilingual/tasks/mmlu_prox.py | 170 ++++++
.../tasks/multilingual/tasks/wmt24pp.py | 114 ++++
.../tasks/multilingual/utils/translation.py | 103 ++++
src/lighteval/tasks/tasks/mmlu_prox.py | 133 +++++
src/lighteval/tasks/templates/translation.py | 6 +-
.../templates/utils/translation_literals.py | 24 +
24 files changed, 2448 insertions(+), 347 deletions(-)
create mode 100644 TASK_NAMING_Multilingual.md
create mode 100644 scripts/multilingual_aggregate.py
create mode 100644 src/lighteval/metrics/utils/nltk_resources.py
create mode 100644 src/lighteval/tasks/multilingual/tasks/global_mmlu_lite.py
create mode 100644 src/lighteval/tasks/multilingual/tasks/mmlu_prox.py
create mode 100644 src/lighteval/tasks/multilingual/tasks/wmt24pp.py
create mode 100644 src/lighteval/tasks/multilingual/utils/translation.py
create mode 100644 src/lighteval/tasks/tasks/mmlu_prox.py
diff --git a/TASK_NAMING_Multilingual.md b/TASK_NAMING_Multilingual.md
new file mode 100644
index 000000000..26b833aaa
--- /dev/null
+++ b/TASK_NAMING_Multilingual.md
@@ -0,0 +1,487 @@
+# Multilingual Task Naming Conventions
+
+Most multilingual tasks follow the 3-part naming rule: `{base}:{lang}:{suffix}`. Translation tasks
+are now English-centric exceptions: the public selector is `{base}:{lang}`, and it expands to the
+two internal directional leaf tasks `{base}:en_to_x:{lang}` and `{base}:x_to_en:{lang}`.
+
+**Key rule**: English is **explicitly excluded** from every multilingual task. Use the
+corresponding English task file instead (e.g., `arc.py`, `mmlu.py`, `hellaswag.py`, `gsm8k.py`).
+
+---
+
+## Naming pattern
+
+```
+{base}:{lang}:{suffix}|{n_shot}
+```
+
+- `{base}` — task name (e.g., `global_mmlu`, `mlmm_arc`, `mmlu_prox`)
+- `{lang}` — ISO 639-3 language value from `Language` enum (e.g., `deu`, `fra`, `zho`)
+- `{suffix}` — `cf`, `mcf`, `mcf_em`, or `gen` (written with the leading colon, e.g. `:cf`, in the rest of this document)
+- `{n_shot}` — appended at runtime (e.g., `|5` for 5-shot)
+
+**Group selectors** (auto-generated by `_task_superset_dict`):
+
+| Selector | Matches |
+|---|---|
+| `global_mmlu:cf\|5` | all 41 languages, CF variant |
+| `global_mmlu:fra:cf\|5` | French only, CF variant |
+| `global_mmlu\|5` | all 41 languages × CF + MCF |
+| `mlmm_arc:cf\|5` | all 26 languages, CF variant |
+| `mmlu_prox:cf\|5` | all 28 multilingual languages, CF variant |
+| `mlmm_hellaswag:cf\|5` | all 32 languages |
+| `mgsm:gen\|8` | all 10 languages |
+| `wmt24pp\|0` | all ~24 English-centric language slices (both directions each) |
+| `wmt24pp:de_DE\|0` | German slice only: en→de_DE + de_DE→en |
+| `flores200\|0` | all English-centric FLORES language slices |
+| `flores200:deu_Latn\|0` | German slice only: eng_Latn→deu_Latn + deu_Latn→eng_Latn |
+
+---
+
+## Suffix reference (same as English tasks)
+
+| Suffix | Metrics reported | Description |
+|--------|-----------------|-------------|
+| `:cf` | `acc`, `acc_norm` (char), `target_bpb` | Score full answer text via log p(choice\|context); BPB merged in |
+| `:mcf` | `acc`, `acc_norm` (char) | Score label tokens only (A, B, …); no BPB |
+| `:mcf_em` | `em` (exact match accuracy) | MCF prompt + greedy generation (1–5 tokens) + exact match on predicted label |
+| `:gen` | task-specific (F1, EM, expr_gold, quasi_em; or `chrf++`, `bleu`, `comet22` for translation) | Greedy generation |
+
+---
+
+## Task inventory
+
+### Multilingual ARC Challenge — `mlmm_arc`
+
+| Task pattern | Dataset | Eval | FS split | ICL | Metrics |
+|---|---|---|---|---|---|
+| `mlmm_arc:{lang}:cf\|5` | `jon-tow/okapi_arc_challenge` | test | train | 5 | acc, acc_norm, target_bpb |
+| `mlmm_arc:{lang}:mcf\|5` | same | test | train | 5 | acc, acc_norm |
+| `mlmm_arc:{lang}:mcf_em\|5` | same | test | train | 5 | em |
+
+**26 languages** (English excluded): arabic, bengali, catalan, chinese, croatian, danish, dutch,
+french, german, hindi, hungarian, indonesian, italian, kannada, malayalam, marathi, nepali,
+romanian, russian, serbian, slovak, spanish, tamil, telugu, ukrainian, vietnamese.
+
+**Language codes**:
+`ara`, `ben`, `cat`, `zho`, `hrv`, `dan`, `nld`, `fra`, `deu`, `hin`, `hun`, `ind`, `ita`,
+`kan`, `mal`, `mar`, `nep`, `ron`, `rus`, `srp`, `slk`, `spa`, `tam`, `tel`, `ukr`, `vie`.
+
+File: `multilingual/tasks/mlmm_arc_challenge.py`
+
+**Copy-pasteable examples:**
+```
+# All 26 languages, CF
+mlmm_arc:cf|5
+
+# All 26 languages, MCF
+mlmm_arc:mcf|5
+
+# All 26 languages, MCF greedy
+mlmm_arc:mcf_em|5
+
+# Single language
+mlmm_arc:deu:cf|5
+mlmm_arc:fra:mcf|5
+mlmm_arc:fra:mcf_em|5
+```
+
+---
+
+### Global MMLU — `global_mmlu`
+
+| Task pattern | Dataset | Eval | FS split | ICL | Metrics |
+|---|---|---|---|---|---|
+| `global_mmlu:{lang}:cf\|5` | `CohereForAI/Global-MMLU` | test | dev | 5 | see below |
+| `global_mmlu:{lang}:mcf\|5` | same | test | dev | 5 | see below |
+| `global_mmlu:{lang}:mcf_em\|5` | same | test | dev | 5 | em |
+
+**41 languages** (English excluded): amharic, arabic, bengali, chinese, czech, dutch, french,
+german, greek, hausa, hebrew, hindi, igbo, indonesian, italian, japanese, kirghiz, korean,
+lithuanian, malagasy, malay, nepali, nyanja, persian, polish, portuguese, romanian, russian,
+serbian, shona, sinhala, somali, spanish, swahili, swedish, tagalog, telugu, turkish,
+ukrainian, vietnamese, yoruba.
+
+**Language codes**:
+`amh`, `ara`, `ben`, `zho`, `ces`, `nld`, `fra`, `deu`, `ell`, `hau`, `heb`, `hin`, `ibo`,
+`ind`, `ita`, `jpn`, `kir`, `kor`, `lit`, `mlg`, `msa`, `nep`, `nya`, `fas`, `pol`, `por`,
+`ron`, `rus`, `srp`, `sna`, `sin`, `som`, `spa`, `swa`, `swe`, `tgl`, `tel`, `tur`, `ukr`,
+`vie`, `yor`.
+
+**All 57 MMLU subjects are evaluated** — no per-subject task splitting. Results are
+automatically routed to STEM / Humanities / Social / Other categories inside the metric.
+
+#### CF metrics (`global_mmlu:{lang}:cf`)
+
+| Metric key | Description |
+|---|---|
+| `acc` | overall accuracy across all subjects |
+| `acc_norm` | overall char-normalised accuracy |
+| `bpb` | overall bits-per-byte (lower = better) |
+| `acc_stem`, `acc_norm_stem`, `bpb_stem` | STEM category (20 subjects) |
+| `acc_humanities`, `acc_norm_humanities`, `bpb_humanities` | Humanities category (13 subjects) |
+| `acc_social`, `acc_norm_social`, `bpb_social` | Social Sciences category (12 subjects) |
+| `acc_other`, `acc_norm_other`, `bpb_other` | Other category (remaining subjects) |
+
+Per-category keys return `None` for samples outside the category; the corpus aggregator
+uses `mean_notnone` so they average only over samples that belong to the category.
+
+#### MCF metrics (`global_mmlu:{lang}:mcf`)
+
+| Metric key | Description |
+|---|---|
+| `acc` | overall accuracy |
+| `acc_norm` | overall char-normalised accuracy |
+| `acc_stem`, `acc_norm_stem` | STEM category |
+| `acc_humanities`, `acc_norm_humanities` | Humanities category |
+| `acc_social`, `acc_norm_social` | Social Sciences category |
+| `acc_other`, `acc_norm_other` | Other category |
+
+File: `multilingual/tasks/global_mmlu.py`
+Metrics class: `MMLUCategoryGroupingCF` / `MMLUCategoryGroupingMCF` in `metrics/dynamic_metrics.py`
+
+**Copy-pasteable examples:**
+```
+# All 41 languages, CF (with category-level metrics)
+global_mmlu:cf|5
+
+# All 41 languages, MCF
+global_mmlu:mcf|5
+
+# All 41 languages, MCF greedy
+global_mmlu:mcf_em|5
+
+# All 41 languages, both CF and MCF
+global_mmlu|5
+
+# Single language
+global_mmlu:fra:cf|5
+global_mmlu:zho:mcf|5
+global_mmlu:zho:mcf_em|5
+```
+
+---
+
+### Global MMLU Lite — `global_mmlu_lite`
+
+| Task pattern | Dataset | Eval | FS split | ICL | Metrics |
+|---|---|---|---|---|---|
+| `global_mmlu_lite:{lang}:cf\|5` | `CohereLabs/Global-MMLU-Lite` | test | dev | 5 | acc, acc_norm, target_bpb |
+| `global_mmlu_lite:{lang}:mcf\|5` | same | test | dev | 5 | acc, acc_norm |
+| `global_mmlu_lite:{lang}:mcf_em\|5` | same | test | dev | 5 | em |
+
+**17 languages** (English excluded): arabic, bengali, welsh, german, spanish, french, hindi,
+indonesian, italian, japanese, korean, burmese, portuguese, albanian, swahili, yoruba, chinese.
+
+**Language codes**:
+`ara`, `ben`, `cym`, `deu`, `spa`, `fra`, `hin`, `ind`, `ita`, `jpn`, `kor`, `mya`, `por`,
+`sqi`, `swa`, `yor`, `zho`.
+
+~400 test samples and ~215 dev samples per language across 43 subjects.
+
+File: `multilingual/tasks/global_mmlu_lite.py`
+
+**Copy-pasteable examples:**
+```
+# All 17 languages, CF
+global_mmlu_lite:cf|5
+
+# All 17 languages, MCF
+global_mmlu_lite:mcf|5
+
+# All 17 languages, MCF greedy
+global_mmlu_lite:mcf_em|5
+
+# Single language
+global_mmlu_lite:ara:cf|5
+global_mmlu_lite:ara:mcf_em|5
+```
+
+---
+
+### MMLU-ProX English — `mmlu_prox_eng`
+
+| Task | Dataset | Eval | FS split | ICL | Metrics |
+|---|---|---|---|---|---|
+| `mmlu_prox_eng:cf\|5` | `li-lab/MMLU-ProX` | test | validation | 5 | acc, acc_norm, target_bpb |
+| `mmlu_prox_eng:mcf\|5` | `li-lab/MMLU-ProX` | test | validation | 5 | acc, acc_norm |
+| `mmlu_prox_eng:mcf_em\|5` | `li-lab/MMLU-ProX` | test | validation | 5 | em |
+
+English subset (`hf_subset="en"`). Up to 10 answer options (labels A–J).
+
+**Language code**: `eng`
+
+File: `tasks/tasks/mmlu_prox.py` (2-part name — not captured by `mmlu_prox:cf` superset)
+
+**Copy-pasteable examples:**
+```
+mmlu_prox_eng:cf|5
+mmlu_prox_eng:mcf|5
+mmlu_prox_eng:mcf_em|5
+```
+
+---
+
+### MMLU-ProX Multilingual — `mmlu_prox`
+
+| Task pattern | Dataset | Eval | FS split | ICL | Metrics |
+|---|---|---|---|---|---|
+| `mmlu_prox:{lang}:cf\|5` | `li-lab/MMLU-ProX` | test | validation | 5 | acc, acc_norm, target_bpb |
+| `mmlu_prox:{lang}:mcf\|5` | same | test | validation | 5 | acc, acc_norm |
+| `mmlu_prox:{lang}:mcf_em\|5` | same | test | validation | 5 | em |
+
+**28 languages** (English excluded): afrikaans, arabic, bengali, czech, german, spanish, french,
+hindi, hungarian, indonesian, italian, japanese, korean, marathi, nepali, portuguese, russian,
+serbian, swahili, telugu, thai, ukrainian, urdu, vietnamese, wolof, yoruba, chinese, zulu.
+
+**Language codes**:
+`afr`, `ara`, `ben`, `ces`, `deu`, `spa`, `fra`, `hin`, `hun`, `ind`, `ita`, `jpn`, `kor`,
+`mar`, `nep`, `por`, `rus`, `srp`, `swa`, `tel`, `tha`, `ukr`, `urd`, `vie`, `wol`, `yor`,
+`zho`, `zul`.
+
+Up to 10 options per question (labels A–J).
+
+File: `multilingual/tasks/mmlu_prox.py`
+
+**Copy-pasteable examples:**
+```
+# All 28 languages, CF
+mmlu_prox:cf|5
+
+# All 28 languages, MCF
+mmlu_prox:mcf|5
+
+# All 28 languages, MCF greedy
+mmlu_prox:mcf_em|5
+
+# Single language
+mmlu_prox:deu:cf|5
+mmlu_prox:deu:mcf_em|5
+```
+
+---
+
+### Multilingual HellaSwag — `mlmm_hellaswag`
+
+| Task pattern | Dataset | Eval | FS split | ICL | Metrics |
+|---|---|---|---|---|---|
+| `mlmm_hellaswag:{lang}:cf\|5` | `alexandrainst/m_hellaswag` (most) / `jon-tow/okapi_hellaswag` (zh) | val/validation | None | 5 | acc, acc_norm, target_bpb |
+| `mlmm_hellaswag:{lang}:mcf\|5` | same | val/validation | None | 5 | acc, acc_norm |
+| `mlmm_hellaswag:{lang}:mcf_em\|5` | same | val/validation | None | 5 | em |
+
+HellaSwag has 4 candidate continuations; the MCF formulation labels them A–D and scores (or generates) the label token.
+
+**32 languages** (English excluded): arabic, armenian, basque, bengali, catalan, chinese,
+croatian, danish, dutch, french, german, gujarati, hindi, hungarian, icelandic, indonesian,
+italian, kannada, malayalam, marathi, nepali, portuguese, romanian, russian,
+serbian, slovak, spanish, swedish, tamil, telugu, ukrainian, vietnamese.
+
+**Language codes**:
+`ara`, `hye`, `eus`, `ben`, `cat`, `zho`, `hrv`, `dan`, `nld`, `fra`, `deu`, `guj`, `hin`,
+`hun`, `isl`, `ind`, `ita`, `kan`, `mal`, `mar`, `nep`, `por`, `ron`, `rus`, `srp`, `slk`,
+`spa`, `swe`, `tam`, `tel`, `ukr`, `vie`.
+
+File: `multilingual/tasks/mlmm_hellaswag.py`
+
+**Copy-pasteable examples:**
+```
+# All 32 languages, CF
+mlmm_hellaswag:cf|5
+
+# All 32 languages, MCF
+mlmm_hellaswag:mcf|5
+
+# All 32 languages, MCF greedy
+mlmm_hellaswag:mcf_em|5
+
+# Single language
+mlmm_hellaswag:deu:cf|5
+mlmm_hellaswag:zho:mcf|5
+mlmm_hellaswag:zho:mcf_em|5
+```
+
+---
+
+### Multilingual GSM — `mgsm`
+
+| Task pattern | Dataset | Eval | FS split | ICL | Metrics |
+|---|---|---|---|---|---|
+| `mgsm:{lang}:gen\|8` | `juletxara/mgsm` | test | train | 8 | expr_gold_metric, multilingual_quasi_em |
+
+Both `expr_gold_metric` (math expression parser, for Arabic-numeral answers) and
+`MultilingualQuasiExactMatchMetric` (language-aware fuzzy match, for non-ASCII digit systems)
+are reported in the same pass.
+
+**10 languages** (English excluded): bengali, french, german, japanese, russian, spanish,
+swahili, telugu, thai, chinese.
+
+**Language codes**:
+`ben`, `fra`, `deu`, `jpn`, `rus`, `spa`, `swa`, `tel`, `tha`, `zho`.
+
+File: `multilingual/tasks/mgsm.py`
+
+**Copy-pasteable examples:**
+```
+# All 10 languages
+mgsm:gen|8
+
+# Single language
+mgsm:deu:gen|8
+mgsm:zho:gen|8
+```
+
+---
+
+### WMT24++ Translation — `wmt24pp`
+
+| Task pattern | Dataset | Eval | FS | ICL | Metrics |
+|---|---|---|---|---|---|
+| `wmt24pp:{lp}\|0` | `google/wmt24pp` | train | none | 0 | chrf++, bleu, comet22 |
+
+Running `wmt24pp:{lp}|0` expands to both internal directional tasks:
+`wmt24pp:en_to_x:{lp}|0` and `wmt24pp:x_to_en:{lp}|0`.
+
+`{lp}` is the language-pair code from the dataset's `lp` column (e.g., `de_DE`, `fr_FR`,
+`zh_CN`). Rows with `is_bad_source=True` are filtered out.
+
+**Available language codes** in `wmt24pp.py`:
+`de_DE`, `fr_FR`, `cs_CZ`, `es_MX`, `ru_RU`, `zh_CN`, `ja_JP`, `uk_UA`, `hi_IN`, `ar_EG`,
+`ko_KR`, `pt_BR`, `tr_TR`, `pl_PL`, `he_IL`, `nl_NL`, `it_IT`, `sv_SE`, `fi_FI`, `vi_VN`,
+`bn_IN`, `th_TH`, `id_ID`, `hu_HU`.
+
+File: `multilingual/tasks/wmt24pp.py`
+
+**Copy-pasteable examples:**
+```
+# All ~24 language slices, both directions each
+wmt24pp|0
+
+# Single language slice
+wmt24pp:de_DE|0
+wmt24pp:zh_CN|0
+```
+
+---
+
+### FLORES-200 Translation — `flores200`
+
+| Task pattern | Dataset | Eval | FS split | ICL | Metrics |
+|---|---|---|---|---|---|
+| `flores200:{lang}\|0` | `facebook/flores` | devtest | dev | 0 | chrf++, bleu, comet22 |
+
+Task names use FLORES-200 language codes: `{iso639-3}_{script}` (e.g., `deu_Latn`, `zho_Hans`,
+`arb_Arab`). English is implicit: running `flores200:{lang}|0` expands to both internal
+directional tasks `flores200:en_to_x:{lang}|0` and `flores200:x_to_en:{lang}|0`.
+
+**~200 languages** are listed in `flores_200_languages`, but only the English-centric bilingual
+slices are registered here.
+
+File: `multilingual/tasks/flores200.py`
+
+**Copy-pasteable examples:**
+```
+# German slice (eng_Latn ↔ deu_Latn)
+flores200:deu_Latn|0
+
+# Chinese (Simplified) slice
+flores200:zho_Hans|0
+
+# Arabic slice
+flores200:arb_Arab|0
+```
+
+Note: FLORES-200 uses two-part `{iso639-3}_{script}` language codes (e.g., `eng_Latn`), not the bare 3-letter ISO codes
+used by other tasks. The public selector has 2 components (`flores200:{lang}`), while the
+directional leaf tasks use 3 components so they still participate in the `a:b:c` superset and
+averaging behavior.
+
+---
+
+## Two-level averaging
+
+### Level 1 — per-language average (automatic)
+
+For any 3-component task name `a:b:c`, lighteval auto-generates an `a:c:_average` aggregate
+key by grouping over the middle component. For tasks with a single subset per language (all
+tasks here except the old subject-split global_mmlu), this average equals the task score.
+
+```
+# Auto-generated by lighteval when running group selectors:
+global_mmlu:cf:_average|5 # mean of global_mmlu:{lang}:cf across all langs
+mlmm_arc:cf:_average|5
+mmlu_prox:cf:_average|5
+mlmm_hellaswag:cf:_average|5
+mgsm:gen:_average|8
+wmt24pp:de_DE:_average|0 # mean of en_to_x + x_to_en for de_DE
+flores200:deu_Latn:_average|0 # mean of eng_Latn→deu_Latn + deu_Latn→eng_Latn
+```
+
+For `global_mmlu`, the per-language aggregate gives you `acc`, `acc_norm`, `bpb` overall
+plus the 4×3 (CF) or 4×2 (MCF) per-category variants, all averaged across all 57 subjects
+for that language.
+
+### Level 2 — cross-language average (post-hoc script)
+
+After evaluation, compute a mean across all languages:
+
+```bash
+python scripts/multilingual_aggregate.py results.json
+python scripts/multilingual_aggregate.py results.json --prefix global_mmlu
+```
+
+Output keys:
+```
+global_mmlu:cf:cross_lang_average|5 # mean of per-lang CF averages
+global_mmlu:mcf:cross_lang_average|5
+mlmm_arc:cf:cross_lang_average|5
+mlmm_hellaswag:cf:cross_lang_average|5
+mgsm:gen:cross_lang_average|8
+mmlu_prox:cf:cross_lang_average|5
+wmt24pp:de_DE:en_to_x|0 # alias for the raw directional task
+wmt24pp:de_DE:x_to_en|0 # alias for the raw directional task
+wmt24pp:de_DE:bidirectional_average|0 # alias for wmt24pp:de_DE:_average|0
+wmt24pp:en_to_x:cross_lang_average|0 # mean across all English→X directions
+wmt24pp:x_to_en:cross_lang_average|0 # mean across all X→English directions
+wmt24pp:bidirectional_cross_lang_average|0
+flores200:deu_Latn:bidirectional_average|0
+flores200:bidirectional_cross_lang_average|0
+```
+
+---
+
+## Quick reference: all copy-pasteable group commands
+
+```bash
+# --- Multiple-choice tasks (CF + BPB) ---
+global_mmlu:cf|5 # Global MMLU, all 41 langs, CF (acc+norm+bpb + 4 categories each)
+global_mmlu:mcf|5 # Global MMLU, all 41 langs, MCF (acc+norm + 4 categories each)
+global_mmlu:mcf_em|5 # Global MMLU, all 41 langs, MCF greedy (em)
+global_mmlu|5 # Global MMLU, all 41 langs, CF + MCF
+
+global_mmlu_lite:cf|5 # Global MMLU Lite, all 17 langs, CF
+global_mmlu_lite:mcf|5 # Global MMLU Lite, all 17 langs, MCF
+global_mmlu_lite:mcf_em|5 # Global MMLU Lite, all 17 langs, MCF greedy (em)
+
+mlmm_arc:cf|5 # ARC Challenge, all 26 langs, CF
+mlmm_arc:mcf|5 # ARC Challenge, all 26 langs, MCF
+mlmm_arc:mcf_em|5 # ARC Challenge, all 26 langs, MCF greedy (em)
+
+mmlu_prox:cf|5 # MMLU-ProX multilingual, all 28 langs, CF
+mmlu_prox:mcf|5 # MMLU-ProX multilingual, all 28 langs, MCF
+mmlu_prox:mcf_em|5 # MMLU-ProX multilingual, all 28 langs, MCF greedy (em)
+mmlu_prox_eng:cf|5 # MMLU-ProX English only
+mmlu_prox_eng:mcf|5
+mmlu_prox_eng:mcf_em|5
+
+# --- Reasoning / generation ---
+mlmm_hellaswag:cf|5 # HellaSwag, all 32 langs, CF
+mlmm_hellaswag:mcf|5 # HellaSwag, all 32 langs, MCF
+mlmm_hellaswag:mcf_em|5 # HellaSwag, all 32 langs, MCF greedy (em)
+mgsm:gen|8 # MGSM math, all 10 langs
+
+# --- Translation ---
+wmt24pp|0 # WMT24++, all English-centric language slices
+wmt24pp:de_DE|0 # WMT24++, German slice (both directions)
+flores200|0 # FLORES-200, all English-centric language slices
+flores200:deu_Latn|0 # FLORES-200, German slice (both directions)
+```
diff --git a/pyproject.toml b/pyproject.toml
index 576a93122..253f0ea35 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -125,6 +125,7 @@ multilingual = [
"spacy[ja,ko,th]>=3.8.0",
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
+ "unbabel-comet>=2.0.0", # COMET-22 metric for translation tasks
]
math = ["latex2sympy2_extended==1.0.6"]
wandb = ["wandb"]
diff --git a/scripts/multilingual_aggregate.py b/scripts/multilingual_aggregate.py
new file mode 100644
index 000000000..8347cd7a9
--- /dev/null
+++ b/scripts/multilingual_aggregate.py
@@ -0,0 +1,301 @@
+"""
+multilingual_aggregate.py — Aggregation helpers for multilingual task results.
+
+Usage:
+ python scripts/multilingual_aggregate.py results.json [--prefix PREFIX] [--output out.json]
+
+ results.json — lighteval output JSON file (the "results" field)
+ --prefix — optional task prefix filter, e.g. "global_mmlu" (default: all)
+ --output — write aggregate outputs to this JSON file (default: stdout)
+
+How it works
+------------
+For standard multilingual tasks, lighteval auto-generates per-language subtask
+averages such as:
+
+ global_mmlu_fra:cf:_average|5 ← mean over all 57 MMLU subjects for French
+ global_mmlu_deu:cf:_average|5 ← mean over all 57 MMLU subjects for German
+
+This script groups those into cross-language averages such as:
+
+ global_mmlu:cf:cross_lang_average|5
+
+For the English-centric translation tasks, the raw task names are direction-first:
+
+ wmt24pp:en_to_x:de_DE|0
+ wmt24pp:x_to_en:de_DE|0
+ wmt24pp:de_DE:_average|0
+
+This script additionally emits language-centered aliases and translation-specific
+cross-language averages such as:
+
+ wmt24pp:de_DE:en_to_x|0
+ wmt24pp:de_DE:x_to_en|0
+ wmt24pp:de_DE:bidirectional_average|0
+ wmt24pp:en_to_x:cross_lang_average|0
+ wmt24pp:x_to_en:cross_lang_average|0
+ wmt24pp:bidirectional_cross_lang_average|0
+"""
+
+import argparse
+import json
+import math
+import re
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+# All 3-letter ISO 639-3 language codes used in the Language enum.
+# Derived from lighteval.utils.language.Language; kept as a static set here
+# to avoid importing the full lighteval package.
+_LANG_CODES = {
+    "afr", "aka", "amh", "ara", "aze", "bam", "bel", "ben", "bos", "bul",
+    "cat", "ces", "ckb", "cmn", "cym", "dan", "deu", "ell", "eng", "epo",
+    "est", "eus", "ewe", "fas", "fin", "fra", "ful", "gla", "gle", "glg",
+    "grn", "guj", "hau", "heb", "hin", "hrv", "hun", "hye", "ibo", "ind",
+    "isl", "ita", "jpn", "kan", "kat", "kaz", "khm", "kin", "kir", "kon",
+    "kor", "lao", "lat", "lin", "lit", "ltz", "lug", "luo", "lvs", "mal",
+    "mar", "mkd", "mlt", "mri", "msa", "mya", "nep", "nld", "nor", "nob",
+    "nno", "nya", "ory", "pan", "pol", "por", "pus", "ron", "run", "rus",
+    "sag", "sin", "slk", "slv", "smo", "sna", "som", "sot", "spa", "sqi",
+    "srp", "ssw", "swa", "swe", "tam", "tel", "tgk", "tgl", "tha", "tir",
+    "ton", "tsn", "tso", "tuk", "tur", "twi", "uig", "ukr", "urd", "uzn",
+    "vie", "wol", "xho", "yor", "zho", "zsm", "zul",
+}
+
+# Pattern: task names whose first component ends with _{lang_code}
+# e.g. "global_mmlu_fra" → base="global_mmlu", lang="fra"
+# The alternation is sorted so the compiled pattern text is the same on every run
+# (set iteration order is not guaranteed to be stable between interpreter runs).
+_LANG_RE = re.compile(r"^(.+?)_(" + "|".join(sorted(_LANG_CODES)) + r")(:.+)?$")
+# Direction-first translation leaf task, e.g. "wmt24pp:en_to_x:de_DE".
+_TRANSLATION_DIRECTION_RE = re.compile(r"^(?P<base>[^:]+):(?P<direction>en_to_x|x_to_en):(?P<lang>[^:]+)$")
+# Per-language translation average, e.g. "wmt24pp:de_DE:_average".
+_TRANSLATION_AVERAGE_RE = re.compile(r"^(?P<base>[^:]+):(?P<lang>[^:]+):_average$")
+
+
+def _strip_lang(task_no_fewshot: str):
+    """
+    Split the ``_{lang}`` component out of a task name that has no ``|n`` suffix.
+
+    Returns ``(base_name, lang_code)``, where ``base_name`` is the input with the
+    ``_{lang}`` part removed, or ``(None, None)`` when ``_LANG_RE`` does not match.
+
+    Example:
+        "global_mmlu_fra:cf:_average"         → ("global_mmlu:cf:_average", "fra")
+        "mlmm_arc_deu:challenge:cf:_average"  → ("mlmm_arc:challenge:cf:_average", "deu")
+        "gsm8k"                               → (None, None)
+    """
+    match = _LANG_RE.match(task_no_fewshot)
+    if match is None:
+        return None, None
+    # The trailing ":..." group is optional; default it to "" so it concatenates cleanly.
+    stem, lang_code, tail = match.groups(default="")
+    return f"{stem}{tail}", lang_code
+
+
+def _avg(values):
+    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty."""
+    if not values:
+        return float("nan")
+    return sum(values) / len(values)
+
+
+def _propagate_se(se_values):
+    """SE of mean of independent estimates: sqrt(sum(SE_i^2)) / n (NaN for no inputs)."""
+    count = len(se_values)
+    if not count:
+        return float("nan")
+    sum_of_squares = sum(se ** 2 for se in se_values)
+    return math.sqrt(sum_of_squares) / count
+
+
+def _split_fewshot(task_name: str):
+    """Split ``task_name`` at its last ``|`` into ``(name, "|suffix")``; suffix is "" if absent."""
+    head, sep, tail = task_name.rpartition("|")
+    if not sep:
+        return task_name, ""
+    return head, sep + tail
+
+
+def _is_valid_number(value):
+    """Return True for int/float values (bools included) that are not a float NaN."""
+    if not isinstance(value, (int, float)):
+        return False
+    if isinstance(value, float) and math.isnan(value):
+        return False
+    return True
+
+
+def _append_numeric_metrics(metric_lists: dict[str, list[float]], metrics: dict):
+    """Append each valid numeric value of ``metrics`` to the list stored under its name."""
+    numeric_items = (
+        (name, value) for name, value in metrics.items() if _is_valid_number(value)
+    )
+    for name, value in numeric_items:
+        metric_lists[name].append(value)
+
+
+def _average_metric_lists(metric_lists: dict[str, list[float]]) -> dict:
+    """Collapse each metric's value list: ``*_stderr`` via ``_propagate_se``, others via ``_avg``."""
+    return {
+        name: (_propagate_se(values) if name.endswith("_stderr") else _avg(values))
+        for name, values in metric_lists.items()
+    }
+
+
+def compute_cross_lang_averages(results: dict, prefix: str | None = None) -> dict:
+    """
+    Compute cross-language averages from per-language ``:_average`` result entries.
+
+    Only keys containing ``:_average`` (and starting with ``prefix`` when given) whose
+    name carries a ``_{lang}`` component recognised by ``_strip_lang`` are used.
+    Entries that share the same language-stripped name and few-shot suffix are
+    averaged metric by metric; ``*_stderr`` metrics are combined with
+    ``_propagate_se`` rather than a plain mean. Groups with fewer than two
+    languages are dropped.
+
+    Returns a dict mapping cross-language-average task name → metric dict; each
+    metric dict also carries an ``_n_languages`` count.
+    """
+    # language-stripped task name (few-shot suffix kept) → metric name → values
+    groups: dict[str, dict[str, list[float]]] = defaultdict(lambda: defaultdict(list))
+
+    for task_name, metrics in results.items():
+        # Only per-language averages are aggregated, optionally filtered by prefix.
+        if ":_average" not in task_name:
+            continue
+        if prefix and not task_name.startswith(prefix):
+            continue
+
+        task_no_fs, fs_part = _split_fewshot(task_name)
+        base_name, _lang = _strip_lang(task_no_fs)
+        if base_name is None:
+            continue  # not a multilingual task we can aggregate
+
+        _append_numeric_metrics(groups[base_name + fs_part], metrics)
+
+    cross_lang: dict[str, dict] = {}
+    for group_key, metric_lists in groups.items():
+        # A group is empty when none of its metrics were valid numbers.
+        if not metric_lists:
+            continue
+        n_langs = max(len(vals) for vals in metric_lists.values())
+        if n_langs < 2:
+            continue  # only one language — no meaningful cross-lang average
+
+        # Rename key: replace ":_average" with ":cross_lang_average"
+        out_key = group_key.replace(":_average", ":cross_lang_average")
+        cross_lang[out_key] = _average_metric_lists(metric_lists)
+        cross_lang[out_key]["_n_languages"] = n_langs
+
+    return cross_lang
+
+
+def compute_translation_aggregates(results: dict, prefix: str | None = None) -> dict:
+    """
+    Build language-centred aliases and cross-language averages for translation tasks.
+
+    Directional leaf tasks (``{base}:en_to_x:{lang}`` / ``{base}:x_to_en:{lang}``) are
+    re-exposed as ``{base}:{lang}:{direction}``, and ``{base}:{lang}:_average`` entries
+    as ``{base}:{lang}:bidirectional_average``. Per-direction and bidirectional
+    cross-language means are added when at least two languages contribute;
+    ``*_stderr`` metrics are combined with ``_propagate_se``.
+
+    Only task names starting with ``prefix`` are considered when it is given.
+    Returns an empty dict when ``results`` holds no directional translation task.
+    """
+    # (base, lang, fewshot suffix) → direction → metrics of the raw directional task
+    directional_metrics = {}
+    # (base, lang, fewshot suffix) → metrics of the "{base}:{lang}:_average" entry
+    bidirectional_averages = {}
+
+    for task_name, metrics in results.items():
+        if prefix and not task_name.startswith(prefix):
+            continue
+
+        task_no_fs, fs_part = _split_fewshot(task_name)
+
+        directional_match = _TRANSLATION_DIRECTION_RE.match(task_no_fs)
+        if directional_match:
+            key = (
+                directional_match.group("base"),
+                directional_match.group("lang"),
+                fs_part,
+            )
+            directional_metrics.setdefault(key, {})[
+                directional_match.group("direction")
+            ] = metrics
+            continue
+
+        average_match = _TRANSLATION_AVERAGE_RE.match(task_no_fs)
+        if average_match:
+            key = (
+                average_match.group("base"),
+                average_match.group("lang"),
+                fs_part,
+            )
+            bidirectional_averages[key] = metrics
+
+    if not directional_metrics:
+        return {}
+
+    aliases = {}
+    directional_groups: dict[str, dict[str, list[float]]] = defaultdict(
+        lambda: defaultdict(list)
+    )
+    bidirectional_groups: dict[str, dict[str, list[float]]] = defaultdict(
+        lambda: defaultdict(list)
+    )
+
+    for key, per_direction in directional_metrics.items():
+        base, lang, fs_part = key
+
+        for direction, metrics in per_direction.items():
+            aliases[f"{base}:{lang}:{direction}{fs_part}"] = metrics
+            _append_numeric_metrics(
+                directional_groups[f"{base}:{direction}:cross_lang_average{fs_part}"],
+                metrics,
+            )
+
+        # Averages are only aliased for languages that also have a directional task.
+        average_metrics = bidirectional_averages.get(key)
+        if average_metrics is not None:
+            aliases[f"{base}:{lang}:bidirectional_average{fs_part}"] = average_metrics
+            _append_numeric_metrics(
+                bidirectional_groups[f"{base}:bidirectional_cross_lang_average{fs_part}"],
+                average_metrics,
+            )
+
+    # Directional groups first, then bidirectional ones, matching the output key order.
+    for metric_groups in (directional_groups, bidirectional_groups):
+        for out_key, metric_lists in metric_groups.items():
+            # A group exists as soon as it is looked up, even if no numeric metric was
+            # appended; default=0 keeps max() from raising on such an empty group.
+            n_langs = max((len(vals) for vals in metric_lists.values()), default=0)
+            if n_langs < 2:
+                continue
+            aliases[out_key] = _average_metric_lists(metric_lists)
+            aliases[out_key]["_n_languages"] = n_langs
+
+    return aliases
+
+
+def main():
+    """
+    CLI entry point: read a lighteval results JSON and emit the aggregate outputs.
+
+    Aggregates are written as JSON to ``--output`` when given, otherwise to stdout.
+    When nothing can be aggregated, a note is printed to stderr and the process
+    exits with status 0.
+    """
+    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
+    parser.add_argument("results_file", help="Path to lighteval results JSON file")
+    parser.add_argument("--prefix", default=None, help="Filter tasks by prefix (e.g. 'global_mmlu')")
+    parser.add_argument("--output", default=None, help="Write output to this JSON file (default: stdout)")
+    args = parser.parse_args()
+
+    # Read as UTF-8 explicitly instead of relying on the platform's locale encoding.
+    with open(args.results_file, encoding="utf-8") as f:
+        data = json.load(f)
+
+    # Support both the raw results dict and the full lighteval JSON format
+    results = data["results"] if "results" in data else data
+
+    aggregated = {}
+    aggregated.update(compute_cross_lang_averages(results, prefix=args.prefix))
+    aggregated.update(compute_translation_aggregates(results, prefix=args.prefix))
+
+    if not aggregated:
+        print("No aggregate outputs found.", file=sys.stderr)
+        sys.exit(0)
+
+    out = json.dumps(aggregated, indent=2, ensure_ascii=False)
+    if args.output:
+        # ensure_ascii=False may emit non-ASCII characters, so the file encoding must be fixed.
+        Path(args.output).write_text(out, encoding="utf-8")
+        print(f"Written {len(aggregated)} aggregate outputs to {args.output}", file=sys.stderr)
+    else:
+        print(out)
+
+if __name__ == "__main__":
+ main()
diff --git a/src/lighteval/logging/info_loggers.py b/src/lighteval/logging/info_loggers.py
index 7309ea552..3cbf5c693 100644
--- a/src/lighteval/logging/info_loggers.py
+++ b/src/lighteval/logging/info_loggers.py
@@ -364,9 +364,10 @@ def aggregate(self, task_dict: dict[str, LightevalTask], bootstrap_iters: int =
)
else:
stderr = get_stderr_function(aggregation=aggregation, number_experiments=bootstrap_iters)
- if stderr is not None and len(metric_values) > 1:
+ stderr_values = [v for v in metric_values if v is not None]
+ if stderr is not None and len(stderr_values) > 1:
try:
- self.metric_aggregated[task_name][f"{metric_name}_stderr"] = stderr(metric_values)
+ self.metric_aggregated[task_name][f"{metric_name}_stderr"] = stderr(stderr_values)
except OverflowError:
# Is this need or should we just pass?
self.metric_aggregated[task_name][f"{metric_name}_stderr"] = float("nan")
diff --git a/src/lighteval/metrics/dynamic_metrics.py b/src/lighteval/metrics/dynamic_metrics.py
index 66ed91c3a..4de1c33ad 100644
--- a/src/lighteval/metrics/dynamic_metrics.py
+++ b/src/lighteval/metrics/dynamic_metrics.py
@@ -33,6 +33,7 @@
Probability,
)
from lighteval.metrics.normalizations import (
+ LogProbCharNorm,
LogProbNormalization,
LogProbTokenNorm,
get_multilingual_normalizer,
@@ -45,10 +46,11 @@
get_extraction_regexes,
)
from lighteval.metrics.utils.math_comparison import compare_gold_target
-from lighteval.metrics.utils.metric_utils import SampleLevelComputation, SampleLevelMetric
+from lighteval.metrics.utils.metric_utils import SampleLevelComputation, SampleLevelMetric, SampleLevelMetricGrouping
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc, SamplingMethod
from lighteval.utils.language import Language
+from lighteval.utils.utils import as_list
from lighteval.utils.timeout import timeout
@@ -266,3 +268,120 @@ def compute(self, doc: Doc, model_response: ModelResponse) -> float:
for pred in extracted_predictions
]
)
+
+
+# ---------------------------------------------------------------------------
+# MMLU per-category accuracy metric
+# ---------------------------------------------------------------------------
+import math as _math
+
+_MMLU_STEM = frozenset([
+ "abstract_algebra", "anatomy", "astronomy", "college_biology", "college_chemistry",
+ "college_computer_science", "college_mathematics", "college_medicine", "college_physics",
+ "computer_security", "conceptual_physics", "electrical_engineering", "elementary_mathematics",
+ "high_school_biology", "high_school_chemistry", "high_school_computer_science",
+ "high_school_mathematics", "high_school_physics", "high_school_statistics", "machine_learning",
+])
+_MMLU_HUMANITIES = frozenset([
+ "formal_logic", "high_school_european_history", "high_school_us_history",
+ "high_school_world_history", "international_law", "jurisprudence", "logical_fallacies",
+ "moral_disputes", "moral_scenarios", "philosophy", "prehistory", "professional_law",
+ "world_religions",
+])
+_MMLU_SOCIAL = frozenset([
+ "econometrics", "high_school_geography", "high_school_government_and_politics",
+ "high_school_macroeconomics", "high_school_microeconomics", "high_school_psychology",
+ "human_sexuality", "professional_psychology", "public_relations", "security_studies",
+ "sociology", "us_foreign_policy",
+])
+# "Other" = any subject not in the above three sets
+
+_MMLU_CATS = ("stem", "humanities", "social", "other")
+
+
+def _subject_to_mmlu_cat(subject: str) -> str:
+    """Map an MMLU subject name (matched case-insensitively) to its category key.
+
+    Returns "stem", "humanities" or "social" when the lower-cased subject is a
+    member of the corresponding set, and "other" for every other subject.
+    """
+    key = subject.lower()
+    buckets = (
+        ("stem", _MMLU_STEM),
+        ("humanities", _MMLU_HUMANITIES),
+        ("social", _MMLU_SOCIAL),
+    )
+    for label, members in buckets:
+        if key in members:
+            return label
+    return "other"
+
+
+def _mean_notnone(vals):
+    """Average the entries of *vals* that are not None.
+
+    Returns NaN when no entry remains after dropping the None values.
+    """
+    kept = [v for v in vals if v is not None]
+    if not kept:
+        return float("nan")
+    return float(np.mean(kept))
+
+
+class _MMLUCategoryFn(SampleLevelComputation):
+    """Sample-level scorer for MMLU-style tasks with per-category routing.
+
+    For every document it produces ``acc`` and character-normalised ``acc_norm``
+    (plus bits-per-byte of the gold choice when ``include_bpb`` is set). Each
+    score is also stored under the key of the document's MMLU category
+    (stem/humanities/social/other, looked up from ``doc.specific["subject"]``);
+    the keys of the remaining categories are set to ``None``.
+    """
+
+    def __init__(self, include_bpb: bool = False):
+        self.include_bpb = include_bpb
+        self._acc_fn = LoglikelihoodAcc(logprob_normalization=None)
+        self._acc_norm_fn = LoglikelihoodAcc(logprob_normalization=LogProbCharNorm())
+
+    def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict:
+        """Return the overall and category-routed scores for a single sample."""
+        plain = self._acc_fn.compute(doc=doc, model_response=model_response)
+        normed = self._acc_norm_fn.compute(doc=doc, model_response=model_response)
+
+        specific = doc.specific or {}
+        own_cat = _subject_to_mmlu_cat(specific.get("subject", ""))
+
+        def _only_own(cat_name, value):
+            # A score is reported only under the sample's own category.
+            return value if cat_name == own_cat else None
+
+        scores: dict = {"acc": plain, "acc_norm": normed}
+        for cat_name in _MMLU_CATS:
+            scores[f"acc_{cat_name}"] = _only_own(cat_name, plain)
+            scores[f"acc_norm_{cat_name}"] = _only_own(cat_name, normed)
+
+        if not self.include_bpb:
+            return scores
+
+        # Bits per byte of the (first) gold choice; byte count is clamped to >= 1.
+        gold = as_list(doc.gold_index)[0]
+        gold_logprob = model_response.logprobs[gold]
+        n_bytes = max(len(doc.choices[gold].encode("utf-8")), 1)
+        bits_per_byte = -gold_logprob / _math.log(2) / n_bytes
+        scores["bpb"] = bits_per_byte
+        for cat_name in _MMLU_CATS:
+            scores[f"bpb_{cat_name}"] = _only_own(cat_name, bits_per_byte)
+        return scores
+
+
+def _make_mmlu_category_grouping(include_bpb: bool) -> SampleLevelMetricGrouping:
+    """Build the metric grouping that wraps ``_MMLUCategoryFn``.
+
+    The grouping lists ``acc``/``acc_norm`` overall and per MMLU category, and,
+    when *include_bpb* is true, ``bpb`` overall and per category. Overall
+    metrics are aggregated with ``np.mean``; per-category metrics with
+    ``_mean_notnone``. Accuracy metrics are marked higher-is-better, BPB
+    metrics lower-is-better.
+    """
+    per_cat_acc = [f"acc_{c}" for c in _MMLU_CATS]
+    per_cat_acc_norm = [f"acc_norm_{c}" for c in _MMLU_CATS]
+    per_cat_bpb = [f"bpb_{c}" for c in _MMLU_CATS]
+
+    acc_names = ["acc", "acc_norm", *per_cat_acc, *per_cat_acc_norm]
+    bpb_names = ["bpb", *per_cat_bpb] if include_bpb else []
+
+    aggregators: dict = {"acc": np.mean, "acc_norm": np.mean}
+    for name in per_cat_acc + per_cat_acc_norm:
+        aggregators[name] = _mean_notnone
+    if include_bpb:
+        aggregators["bpb"] = np.mean
+        for name in per_cat_bpb:
+            aggregators[name] = _mean_notnone
+
+    direction = dict.fromkeys(acc_names, True)
+    direction.update(dict.fromkeys(bpb_names, False))
+
+    return SampleLevelMetricGrouping(
+        metric_name=acc_names + bpb_names,
+        sample_level_fn=_MMLUCategoryFn(include_bpb=include_bpb),
+        corpus_level_fn=aggregators,
+        category=SamplingMethod.LOGPROBS,
+        higher_is_better=direction,
+    )
+
+
+MMLUCategoryGroupingCF = _make_mmlu_category_grouping(include_bpb=True)
+"""CF variant: reports acc, acc_norm, bpb — overall and per MMLU category (stem/humanities/social/other)."""
+
+MMLUCategoryGroupingMCF = _make_mmlu_category_grouping(include_bpb=False)
+"""MCF variant: reports acc, acc_norm — overall and per MMLU category (no BPB for label-token scoring)."""
diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py
index c1164b9a3..41053de70 100644
--- a/src/lighteval/metrics/metrics.py
+++ b/src/lighteval/metrics/metrics.py
@@ -32,6 +32,7 @@
from lighteval.metrics.harness_compatibility.drop import DropMetrics
from lighteval.metrics.harness_compatibility.truthful_qa import TruthfulqaMCMetrics
from lighteval.metrics.metrics_corpus import (
+ CorpusLevelCOMETMetric,
CorpusLevelF1Score,
CorpusLevelPerplexityMetric,
CorpusLevelTranslationMetric,
@@ -66,6 +67,7 @@
remove_braces_and_strip,
)
from lighteval.metrics.sample_preparator import (
+ COMETPreparator,
GenerativePreparator,
LoglikelihoodPreparator,
PerplexityPreparator,
@@ -290,6 +292,13 @@ class Metrics(Enum):
corpus_level_fn=CorpusLevelTranslationMetric("chrf++"),
higher_is_better=True,
)
+ comet22 = CorpusLevelMetric(
+ metric_name="comet22",
+ sample_level_fn=COMETPreparator(),
+ category=SamplingMethod.GENERATIVE,
+ corpus_level_fn=CorpusLevelCOMETMetric("Unbabel/wmt22-comet-da"),
+ higher_is_better=True,
+ )
copyright = SampleLevelMetricGrouping(
metric_name=[
"longest_common_prefix_length",
diff --git a/src/lighteval/metrics/metrics_corpus.py b/src/lighteval/metrics/metrics_corpus.py
index 92c2c574a..7b4b9b4cf 100644
--- a/src/lighteval/metrics/metrics_corpus.py
+++ b/src/lighteval/metrics/metrics_corpus.py
@@ -35,6 +35,7 @@
import sklearn.metrics
from lighteval.metrics.sample_preparator import (
+ COMETCorpusMetricInput,
GenerativeCorpusMetricInput,
LogprobCorpusMetricInput,
PerplexityCorpusMetricInput,
@@ -126,9 +127,6 @@ def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
def get_metric(self):
if self.metric_type == "bleu":
- import nltk
-
- nltk.download("punkt_tab")
return sacrebleu.BLEU(trg_lang=self.lang)
elif self.metric_type == "chrf":
return sacrebleu.CHRF()
@@ -190,3 +188,52 @@ def compute_corpus(self, items: list[PerplexityCorpusMetricInput]):
return math.exp(-sum(logprobs) / sum(weights))
if self.metric_type == "bits_per_byte":
return -sum(logprobs) / sum(weights) * 1 / math.log(2)
+
+
+class CorpusLevelCOMETMetric(CorpusLevelComputation):
+    """Corpus-level COMET-22 metric using the Unbabel/wmt22-comet-da model.
+
+    Requires `unbabel-comet`: pip install lighteval[multilingual] (or pip install unbabel-comet).
+
+    Each COMET checkpoint is loaded once per process (class-level cache keyed by
+    model name) and reused across all translation language pairs in the same
+    evaluation run.
+
+    On multi-GPU nodes, all available GPUs are used automatically via the `gpus` parameter
+    of unbabel-comet's predict() API. batch_size_per_gpu=256 is tuned for H100 96GB;
+    reduce to 64 for A100 40GB or 32 for 16GB GPUs.
+
+    Args:
+        model_name: Name of the COMET checkpoint passed to ``download_model``.
+        batch_size_per_gpu: Batch size handed to ``model.predict``.
+    """
+
+    # Class-level cache: model name -> loaded checkpoint. Keying by name keeps two
+    # metrics configured with different checkpoints from sharing one model.
+    _models: dict = {}
+
+    def __init__(self, model_name: str = "Unbabel/wmt22-comet-da", batch_size_per_gpu: int = 256):
+        self.model_name = model_name
+        self.batch_size_per_gpu = batch_size_per_gpu
+
+    def _load_model(self):
+        """Return the checkpoint for ``self.model_name``, loading it on first use.
+
+        Raises:
+            ImportError: If the `comet` package cannot be imported.
+        """
+        cache = CorpusLevelCOMETMetric._models
+        if self.model_name not in cache:
+            try:
+                from comet import download_model, load_from_checkpoint
+            except ImportError as exc:
+                raise ImportError(
+                    "COMET metric requires `unbabel-comet`. "
+                    "Install with: pip install lighteval[multilingual]"
+                ) from exc
+            cache[self.model_name] = load_from_checkpoint(download_model(self.model_name))
+        return cache[self.model_name]
+
+    def compute_corpus(self, items: list[COMETCorpusMetricInput]) -> float:
+        """Score all prepared examples and return the mean of the per-example scores.
+
+        Only the first reference of each example is used (an empty string when
+        there is none). Returns NaN for an empty list of items without loading
+        the model.
+        """
+        if not items:
+            # Nothing to score: skip the expensive model load and the mean of nothing.
+            return float("nan")
+
+        import torch
+
+        model = self._load_model()
+        data = [
+            {"src": i.source, "mt": i.hyp, "ref": i.ref[0] if i.ref else ""}
+            for i in items
+        ]
+        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+        output = model.predict(
+            data,
+            batch_size=self.batch_size_per_gpu,
+            gpus=num_gpus,
+            progress_bar=False,
+        )
+        return float(np.mean(output.scores))
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 5a1176b40..4863d7239 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -35,7 +35,6 @@
from huggingface_hub import HfApi
from nltk.metrics.distance import edit_distance
from nltk.tokenize import word_tokenize
-from nltk.tokenize.treebank import TreebankWordTokenizer
from nltk.translate.bleu_score import sentence_bleu
from pydantic import BaseModel
from scipy.stats import hypergeom
@@ -51,6 +50,7 @@
remove_braces,
remove_braces_and_strip,
)
+from lighteval.metrics.utils.nltk_resources import ensure_nltk_resource
from lighteval.metrics.utils.judge_utils import (
get_judge_prompt_simpleqa,
process_judge_response_simpleqa,
@@ -870,9 +870,7 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs):
Returns:
float: Score over the current sample's items.
"""
- import nltk
-
- nltk.download("punkt_tab")
+ ensure_nltk_resource("tokenizers/punkt_tab", "punkt_tab")
golds = doc.get_golds()
predictions = model_response.final_text
return np.mean([self._bleu_score(golds, p) for p in predictions])
@@ -889,7 +887,9 @@ def _bleu_score(self, gold: list[str], pred: str):
"""
weights = [1 if ix == self.n_gram else 0 for ix in range(1, 5)]
return sentence_bleu(
- [word_tokenize(g) for g in gold], word_tokenize(pred), weights=weights
+ [word_tokenize(g) for g in gold],
+ word_tokenize(pred),
+ weights=weights,
)
diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py
index 277dde993..a80a1e573 100644
--- a/src/lighteval/metrics/sample_preparator.py
+++ b/src/lighteval/metrics/sample_preparator.py
@@ -60,6 +60,13 @@ class PerplexityCorpusMetricInput(CorpusMetricInput):
weights: list[int]
+@dataclass
+class COMETCorpusMetricInput(CorpusMetricInput):
+    """One translation example in the form consumed by the corpus-level COMET metric."""
+
+    source: str  # source sentence (passed to COMET as "src")
+    hyp: str  # model translation / hypothesis (passed to COMET as "mt")
+    ref: list[str]  # gold reference translation(s)
+
+
class Preparator:
pass
@@ -93,6 +100,19 @@ def __str__(self):
return f"{self.__class__.__name__}({', '.join(attr_strs)})"
+class COMETPreparator(Preparator):
+    @staticmethod
+    def prepare(doc: Doc, model_response: ModelResponse, **kwargs) -> "COMETCorpusMetricInput":
+        """Prepares a translation example for COMET scoring.
+
+        Reads the source sentence from doc.specific["source_text"] (set by the translation
+        template), combines it with the model's hypothesis and the gold reference(s).
+
+        The hypothesis is the first generation in ``model_response.final_text``
+        (an empty string when there is none), because the COMET input expects a
+        single hypothesis string per example.
+        """
+        source = (doc.specific or {}).get("source_text", "")
+        golds = as_list(doc.get_golds())
+        # final_text holds the generation(s) for this sample; pick one string
+        # rather than forwarding the whole list as the hypothesis.
+        predictions = as_list(model_response.final_text)
+        hyp = predictions[0] if predictions else ""
+        return COMETCorpusMetricInput(source=source, hyp=hyp, ref=golds)
+
+
class LoglikelihoodPreparator(Preparator):
def __init__(self, is_single_token: bool = False):
"""Init.
diff --git a/src/lighteval/metrics/utils/nltk_resources.py b/src/lighteval/metrics/utils/nltk_resources.py
new file mode 100644
index 000000000..bf0a9a474
--- /dev/null
+++ b/src/lighteval/metrics/utils/nltk_resources.py
@@ -0,0 +1,28 @@
+import nltk
+
+
+def ensure_nltk_resource(resource_path: str, package_name: str) -> None:
+    """Make sure an NLTK resource is available, downloading it if necessary.
+
+    Args:
+        resource_path: Path looked up with ``nltk.data.find``.
+        package_name: Package name handed to ``nltk.download`` when the lookup fails.
+
+    Raises:
+        RuntimeError: If the download raises, reports failure, or the resource
+            still cannot be found afterwards.
+    """
+    try:
+        nltk.data.find(resource_path)
+    except LookupError:
+        pass
+    else:
+        # Already present locally: nothing to do.
+        return
+
+    failure_text = (
+        f"Required NLTK resource '{package_name}' is not available locally. "
+        f"Tried to resolve '{resource_path}' and attempted an automatic download, but it failed. "
+        "Download it in advance on a machine with network access and set NLTK_DATA to the cached directory."
+    )
+
+    try:
+        fetched = nltk.download(package_name, quiet=True)
+    except Exception as download_error:
+        raise RuntimeError(failure_text) from download_error
+
+    if not fetched:
+        raise RuntimeError(failure_text)
+
+    # Confirm the download really made the resource resolvable.
+    try:
+        nltk.data.find(resource_path)
+    except LookupError as lookup_error:
+        raise RuntimeError(failure_text) from lookup_error
diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py
index 2a59bc6d0..736e06a38 100644
--- a/src/lighteval/models/transformers/transformers_model.py
+++ b/src/lighteval/models/transformers/transformers_model.py
@@ -707,7 +707,12 @@ def _padded_greedy_until(
# NOTE: we are assuming all items in a batch behave similarly (same
# stop_tokens and max_tokens genrated) which is not necessarily
# the case! Because of that we only use batch size of 1
- stop_tokens = [self.tokenizer.eos_token] + batch[0].stop_sequences if len(batch[0].stop_sequences) > 0 else [self.tokenizer.eos_token]
+ batch_stop_sequences = list(batch[0].stop_sequences)
+ stop_tokens = (
+ [self.tokenizer.eos_token] + batch_stop_sequences
+ if len(batch_stop_sequences) > 0
+ else [self.tokenizer.eos_token]
+ )
max_new_tokens = batch[0].generation_size
num_samples = batch[0].num_samples
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index 04629379c..3f45315ba 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -21,13 +21,18 @@
# SOFTWARE.
import functools
+import json
import logging
+import os
import random
+import tarfile
+import urllib.request
from dataclasses import asdict, dataclass, field
+from pathlib import Path, PurePosixPath
from typing import Callable
-from datasets import DatasetDict, load_dataset
-from huggingface_hub import TextGenerationInputGrammarType
+from datasets import Dataset, DatasetDict, load_dataset
+from huggingface_hub import TextGenerationInputGrammarType, hf_hub_download, list_repo_files
from inspect_ai.dataset import Sample
from multiprocess import Pool
from pytablewriter import MarkdownTableWriter
@@ -45,6 +50,249 @@
logger = logging.getLogger(__name__)
+def _get_manual_dataset_cache_dir() -> Path:
+    """Return (creating it if needed) the cache directory for manually fetched datasets.
+
+    The directory is ``lighteval_manual_datasets`` under ``$HF_HOME``, or under
+    ``~/.cache/huggingface`` when ``HF_HOME`` is not set.
+    """
+    default_home = Path.home() / ".cache" / "huggingface"
+    root = Path(os.environ.get("HF_HOME", default_home))
+    target = root / "lighteval_manual_datasets"
+    target.mkdir(parents=True, exist_ok=True)
+    return target
+
+
+def _matches_hub_split_file(path: str, split: str, ext: str, config_name: str | None) -> bool:
+    """Tell whether a repository file path is a data file of *split* for *config_name*.
+
+    Two naming schemes are accepted for files ending in *ext*:
+
+    * ``<split><ext>`` or ``<split>-...<ext>``, located in the ``config_name``
+      directory (or at the repository root when no config is given);
+    * ``<config_name>_<split><ext>``, located in ``data/`` or at the repository
+      root (only when a config is given).
+    """
+    location = PurePosixPath(path)
+    basename = location.name
+    folder = location.parent.as_posix()
+
+    if not basename.endswith(ext):
+        return False
+
+    split_named = basename == f"{split}{ext}" or basename.startswith(f"{split}-")
+    if split_named:
+        if config_name:
+            return folder == config_name
+        return folder in (".", "")
+
+    if not config_name:
+        return False
+
+    return basename == f"{config_name}_{split}{ext}" and folder in ("data", ".", "")
+
+
+def _load_hub_raw_dataset_files(
+    dataset_path: str,
+    config_name: str | None,
+    splits: list[str],
+    *,
+    revision: str | None = None,
+    local_files_only: bool = False,
+) -> DatasetDict | None:
+    """Load a Hub dataset directly from its raw parquet/jsonl/json files.
+
+    For every requested split the repository listing is searched for matching
+    files, trying ``.parquet``, then ``.jsonl``, then ``.json``; the first
+    extension with at least one match is downloaded and used.
+
+    Returns:
+        The loaded ``DatasetDict``, or ``None`` when any split has no matching
+        file (or no split was requested).
+
+    Raises:
+        ValueError: If different splits resolve to different file formats.
+    """
+    available = list_repo_files(
+        repo_id=dataset_path,
+        repo_type="dataset",
+        revision=revision,
+    )
+    candidates = ((".parquet", "parquet"), (".jsonl", "json"), (".json", "json"))
+
+    def _fetch(repo_file: str) -> str:
+        # Resolve one repository file to a local path.
+        return hf_hub_download(
+            repo_id=dataset_path,
+            repo_type="dataset",
+            filename=repo_file,
+            revision=revision,
+            local_files_only=local_files_only,
+        )
+
+    files_by_split: dict[str, list[str]] = {}
+    chosen_format: str | None = None
+    for split in splits:
+        local_paths: list[str] = []
+        split_format: str | None = None
+        for ext, fmt in candidates:
+            wanted = [name for name in available if _matches_hub_split_file(name, split, ext, config_name)]
+            if wanted:
+                local_paths = [_fetch(name) for name in wanted]
+                split_format = fmt
+                break
+
+        if split_format is None:
+            # At least one split has no raw file: signal "not loadable this way".
+            return None
+
+        if chosen_format is None:
+            chosen_format = split_format
+        elif chosen_format != split_format:
+            raise ValueError(
+                f"Inconsistent raw file formats for dataset {dataset_path}: {chosen_format} vs {split_format}"
+            )
+
+        files_by_split[split] = sorted(local_paths)
+
+    if chosen_format is None or not files_by_split:
+        return None
+
+    return load_dataset(chosen_format, data_files=files_by_split)
+
+
+def _load_mgsm_dataset(config_name: str) -> DatasetDict:
+    """Build the MGSM ``train``/``test`` splits from the raw files of ``juletxara/mgsm``.
+
+    The few-shot exemplars are read from the repository's ``exemplars.py`` and
+    become the ``train`` split; the rows of ``mgsm_<config_name>.tsv`` become the
+    ``test`` split (question in column 0, numeric answer in column 1).
+
+    Args:
+        config_name: Key into ``MGSM_EXEMPLARS`` and suffix of the TSV file name.
+
+    Returns:
+        A ``DatasetDict`` with ``train`` and ``test`` splits sharing the columns
+        ``question``, ``answer``, ``answer_number`` and ``equation_solution``.
+
+    Raises:
+        ImportError: If ``exemplars.py`` cannot be loaded as a module.
+    """
+    import csv
+    import importlib.util
+
+    from huggingface_hub import hf_hub_download
+
+    exemplars_path = hf_hub_download(
+        repo_id="juletxara/mgsm",
+        repo_type="dataset",
+        filename="exemplars.py",
+    )
+    tsv_path = hf_hub_download(
+        repo_id="juletxara/mgsm",
+        repo_type="dataset",
+        filename=f"mgsm_{config_name}.tsv",
+    )
+
+    # NOTE(review): this executes Python code downloaded from the Hub and no
+    # revision is pinned -- confirm the repository is trusted or pin a revision.
+    spec = importlib.util.spec_from_file_location("mgsm_exemplars", exemplars_path)
+    # Validate before use with a real exception: an assert would be stripped
+    # under -O and came too late to protect module_from_spec().
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Cannot load MGSM exemplars module from {exemplars_path}")
+    exemplars_module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(exemplars_module)
+
+    examples = exemplars_module.MGSM_EXEMPLARS[config_name]
+    number_answers = exemplars_module.EXEMPLAR_NUMBER_ANSWERS
+    equation_solutions = exemplars_module.EXEMPLAR_EQUATION_SOLUTIONS
+    train_rows = []
+    for key, data in examples.items():
+        # Exemplar keys are 1-based; the answer lists are indexed from 0.
+        idx = int(key) - 1
+        train_rows.append(
+            {
+                "question": data["q"],
+                "answer": data["a"],
+                "answer_number": number_answers[idx],
+                "equation_solution": equation_solutions[idx],
+            }
+        )
+
+    test_rows = []
+    # newline="" lets the csv module handle line endings itself.
+    with open(tsv_path, encoding="utf-8", newline="") as csv_file:
+        csv_reader = csv.reader(csv_file, quotechar='"', delimiter="\t")
+        for row in csv_reader:
+            if not row:
+                # csv yields an empty list for blank lines; they carry no example.
+                continue
+            test_rows.append(
+                {
+                    "question": row[0],
+                    "answer": None,
+                    # Answers may use thousands separators such as "1,000".
+                    "answer_number": int(row[1].replace(",", "")),
+                    "equation_solution": None,
+                }
+            )
+
+    return DatasetDict(
+        {
+            "train": Dataset.from_list(train_rows),
+            "test": Dataset.from_list(test_rows),
+        }
+    )
+
+
+def _load_wmt24pp_dataset(config_name: str, local_files_only: bool = False) -> DatasetDict:
+    """Load one ``google/wmt24pp`` config from its ``<config_name>.jsonl`` file.
+
+    Blank lines are ignored; every other line is parsed as one JSON record. All
+    records are returned in a single ``train`` split.
+    """
+    from huggingface_hub import hf_hub_download
+
+    jsonl_path = hf_hub_download(
+        repo_id="google/wmt24pp",
+        repo_type="dataset",
+        filename=f"{config_name}.jsonl",
+        local_files_only=local_files_only,
+    )
+
+    with open(jsonl_path, encoding="utf-8") as f:
+        stripped_lines = (raw.strip() for raw in f)
+        rows = [json.loads(text) for text in stripped_lines if text]
+
+    return DatasetDict({"train": Dataset.from_list(rows)})
+
+
+def _load_flores_dataset(config_name: str, local_files_only: bool = False) -> DatasetDict:
+    """Build the FLORES-200 ``dev``/``devtest`` splits from the official archive.
+
+    The archive is downloaded into the manual dataset cache and extracted once.
+    Both steps go through a temporary path and are moved into place only when
+    complete, so an interrupted run never leaves a partial archive or a partial
+    ``flores200_dataset`` directory that later runs would mistake for a full one.
+
+    Args:
+        config_name: A single FLORES language code, or two codes joined by ``-``.
+            A single code yields a ``sentence`` column; a pair yields one
+            ``sentence_<code>`` column per language.
+        local_files_only: When true, never download; fail if the archive is not cached.
+
+    Returns:
+        A ``DatasetDict`` with ``dev`` and ``devtest`` splits.
+
+    Raises:
+        FileNotFoundError: If *local_files_only* is set and the archive is missing.
+        ValueError: If a metadata file does not start with the expected header.
+    """
+    import shutil
+    import tempfile
+
+    cache_dir = _get_manual_dataset_cache_dir() / "flores"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    archive_path = cache_dir / "flores200_dataset.tar.gz"
+    extract_dir = cache_dir / "flores200_dataset"
+
+    if not extract_dir.exists():
+        if not archive_path.exists():
+            if local_files_only:
+                raise FileNotFoundError(
+                    f"FLORES archive not found in local cache: {archive_path}"
+                )
+            # Download next to the final path and rename when done, so a
+            # truncated download is never picked up as the real archive.
+            partial_path = archive_path.with_name(f"{archive_path.name}.{os.getpid()}.part")
+            try:
+                urllib.request.urlretrieve(
+                    "https://dl.fbaipublicfiles.com/nllb/flores200_dataset.tar.gz",
+                    partial_path,
+                )
+                os.replace(partial_path, archive_path)
+            finally:
+                partial_path.unlink(missing_ok=True)
+
+        # Extract into a private staging directory first; the archive's top-level
+        # "flores200_dataset" folder is moved into place only once complete.
+        staging_dir = Path(tempfile.mkdtemp(prefix="flores200_extract_", dir=cache_dir))
+        try:
+            with tarfile.open(archive_path, "r:gz") as tar:
+                if hasattr(tarfile, "data_filter"):
+                    # Refuse members that would escape the destination directory.
+                    tar.extractall(staging_dir, filter="data")
+                else:
+                    tar.extractall(staging_dir)
+            try:
+                os.replace(staging_dir / "flores200_dataset", extract_dir)
+            except OSError:
+                # Another process may have finished the extraction first.
+                if not extract_dir.exists():
+                    raise
+        finally:
+            shutil.rmtree(staging_dir, ignore_errors=True)
+
+    langs = config_name.split("-", 1) if "-" in config_name else [config_name]
+
+    def _sentence_path(split: str, lang: str) -> Path:
+        return extract_dir / split / f"{lang}.{split}"
+
+    def _metadata_path(split: str) -> Path:
+        return extract_dir / f"metadata_{split}.tsv"
+
+    def _build_split(split: str) -> Dataset:
+        with open(_metadata_path(split), encoding="utf-8") as meta_file:
+            metadata_lines = [line.rstrip("\n") for line in meta_file]
+        header, rows = metadata_lines[0], metadata_lines[1:]
+        if header.split("\t")[:5] != [
+            "URL",
+            "domain",
+            "topic",
+            "has_image",
+            "has_hyperlink",
+        ]:
+            raise ValueError(f"Unexpected FLORES metadata header: {header}")
+
+        sentences_by_lang: dict[str, list[str]] = {}
+        for lang in langs:
+            with open(_sentence_path(split, lang), encoding="utf-8") as sentence_file:
+                sentences_by_lang[lang] = [line.rstrip("\n") for line in sentence_file]
+
+        data = []
+        for idx, metadata in enumerate(rows, start=1):
+            # Only the first five columns are used, matching the header check above.
+            url, domain, topic, has_image, has_hyperlink = metadata.split("\t")[:5]
+            row = {
+                "id": idx,
+                "URL": url,
+                "domain": domain,
+                "topic": topic,
+                "has_image": 1 if has_image == "yes" else 0,
+                "has_hyperlink": 1 if has_hyperlink == "yes" else 0,
+            }
+            if len(langs) == 1:
+                row["sentence"] = sentences_by_lang[langs[0]][idx - 1]
+            else:
+                for lang in langs:
+                    row[f"sentence_{lang}"] = sentences_by_lang[lang][idx - 1]
+            data.append(row)
+
+        return Dataset.from_list(data)
+
+    return DatasetDict(
+        {
+            "dev": _build_split("dev"),
+            "devtest": _build_split("devtest"),
+        }
+    )
+
+
@dataclass
class LightevalTaskConfig:
"""Configuration dataclass for a LightevalTask.
@@ -521,6 +769,42 @@ def download_dataset_worker(
if not (_is_script_err or _is_offline):
raise
+ if _is_script_err and task.dataset_path == "juletxara/mgsm":
+ dataset = _load_mgsm_dataset(task.dataset_config_name)
+ if task.dataset_filter is not None:
+ dataset = dataset.filter(task.dataset_filter)
+ return dataset # type: ignore
+
+ if (_is_script_err or _is_offline) and task.dataset_path == "google/wmt24pp":
+ dataset = _load_wmt24pp_dataset(
+ task.dataset_config_name,
+ local_files_only=_is_offline,
+ )
+ if task.dataset_filter is not None:
+ dataset = dataset.filter(task.dataset_filter)
+ return dataset # type: ignore
+
+ if (_is_script_err or _is_offline) and task.dataset_path == "facebook/flores":
+ dataset = _load_flores_dataset(
+ task.dataset_config_name,
+ local_files_only=_is_offline,
+ )
+ if task.dataset_filter is not None:
+ dataset = dataset.filter(task.dataset_filter)
+ return dataset # type: ignore
+
+ if _is_script_err:
+ dataset = _load_hub_raw_dataset_files(
+ dataset_path=task.dataset_path,
+ config_name=task.dataset_config_name,
+ splits=list(task.config.hf_avail_splits or []),
+ revision=task.dataset_revision,
+ )
+ if dataset is not None:
+ if task.dataset_filter is not None:
+ dataset = dataset.filter(task.dataset_filter)
+ return dataset # type: ignore
+
_splits = list(task.config.hf_avail_splits or [])
if not _splits:
_splits = (
@@ -554,9 +838,24 @@ def download_dataset_worker(
if task.dataset_config_name
else _rev_dir
)
+ def _matches_cached_file(_filename: str, _split: str, _ext: str) -> bool:
+ if not _filename.endswith(_ext):
+ return False
+ if _filename.startswith(f"{_split}-") or _filename == f"{_split}{_ext}":
+ return True
+ return bool(
+ task.dataset_config_name
+ and _filename == f"{task.dataset_config_name}_{_split}{_ext}"
+ )
+
# Use os.listdir (not glob) so symlinks in the hub
# cache are resolved correctly.
- for _search_dir in [_config_dir, _os.path.join(_config_dir, "data")]:
+ for _search_dir in [
+ _config_dir,
+ _os.path.join(_config_dir, "data"),
+ _rev_dir,
+ _os.path.join(_rev_dir, "data"),
+ ]:
if not _os.path.isdir(_search_dir):
continue
try:
@@ -568,11 +867,7 @@ def download_dataset_worker(
_pq = sorted(
_os.path.join(_search_dir, f)
for f in _entries
- if f.endswith(_ext)
- and (
- f.startswith(f"{_split}-")
- or f == f"{_split}{_ext}"
- )
+ if _matches_cached_file(f, _split, _ext)
)
if _pq:
_data_files[_split] = _pq
diff --git a/src/lighteval/tasks/multilingual/tasks/flores200.py b/src/lighteval/tasks/multilingual/tasks/flores200.py
index f6adbef78..3d536bb3b 100644
--- a/src/lighteval/tasks/multilingual/tasks/flores200.py
+++ b/src/lighteval/tasks/multilingual/tasks/flores200.py
@@ -6,7 +6,8 @@
facebook/flores
abstract:
-Flores200 multilingual benchmark.
+Flores200 English-centric translation benchmark.
+Few-shot evaluation uses `dev` for exemplars and `devtest` for scoring.
languages:
arabic, armenian, bengali, cyrillic, devanagari, ethiopic, georgian, greek,
@@ -20,19 +21,16 @@
paper:
"""
-from itertools import permutations
-
-from lighteval.metrics.metrics import Metrics
-from lighteval.tasks.lighteval_task import LightevalTaskConfig
-from lighteval.tasks.templates.translation import get_translation_prompt_function
-from lighteval.tasks.templates.utils.formulation import (
- CFFormulation,
+from lighteval.tasks.multilingual.utils.translation import (
+ TRANSLATION_METRICS,
+ EnglishCentricTranslationConfig,
+ build_english_centric_translation_tasks,
)
from lighteval.utils.language import Language, manage_duplicate_language_codes
flores_200_languages = [
- # "ace_Arab",
+ "ace_Arab",
"ace_Latn",
"acm_Arab",
"acq_Arab",
@@ -43,7 +41,7 @@
"amh_Ethi",
"apc_Arab",
"arb_Arab",
- # "arb_Latn",
+ "arb_Latn",
"ars_Arab",
"ary_Arab",
"arz_Arab",
@@ -60,7 +58,7 @@
"bem_Latn",
"ben_Beng",
"bho_Deva",
- # "bjn_Arab",
+ "bjn_Arab",
"bjn_Latn",
"bod_Tibt",
"bos_Latn",
@@ -115,10 +113,10 @@
"kac_Latn",
"kam_Latn",
"kan_Knda",
- # "kas_Arab",
+ "kas_Arab",
"kas_Deva",
"kat_Geor",
- # "knc_Arab",
+ "knc_Arab",
"knc_Latn",
"kaz_Cyrl",
"kbp_Latn",
@@ -148,7 +146,7 @@
"mai_Deva",
"mal_Mlym",
"mar_Deva",
- # "min_Arab",
+ "min_Arab",
"min_Latn",
"mkd_Cyrl",
"plt_Latn",
@@ -233,11 +231,14 @@
"yor_Latn",
"yue_Hant",
"zho_Hans",
- # "zho_Hant",
+ "zho_Hant",
"zsm_Latn",
"zul_Latn",
]
+_ENGLISH_FLORES_CODE = "eng_Latn"
+_FLORES_METRICS = TRANSLATION_METRICS
+
def flores_adapter(lang1, lang2):
return lambda line: {
@@ -246,25 +247,30 @@ def flores_adapter(lang1, lang2):
}
-TASKS_TABLE = [
- LightevalTaskConfig(
- name=f"flores200:{lang1}-{lang2}",
- prompt_function=get_translation_prompt_function(
- source_language=Language(manage_duplicate_language_codes(lang1.split("_")[0])),
- target_language=Language(manage_duplicate_language_codes(lang2.split("_")[0])),
- adapter=flores_adapter(lang1, lang2),
- formulation=CFFormulation(),
- ),
- hf_repo="facebook/flores",
- hf_subset=f"{lang1}-{lang2}",
- hf_avail_splits=["dev", "devtest"],
- evaluation_splits=["devtest"],
- few_shots_split="dev",
- few_shots_select=None,
- generation_size=300,
- metrics=[Metrics.chrf_plus, Metrics.bleu, Metrics.bleu_1, Metrics.bleu_4],
- stop_sequence=["\n"],
- version=0,
- )
- for (lang1, lang2) in permutations(flores_200_languages, 2)
-]
+def _flores_pair_subset(lang_code: str) -> str:
+    """Return the pair subset name for English and *lang_code*.
+
+    The two FLORES codes are sorted and joined with ``-``.
+    """
+    first, second = sorted([_ENGLISH_FLORES_CODE, lang_code])
+    return f"{first}-{second}"
+
+
+TASKS_TABLE = build_english_centric_translation_tasks(
+ base_name="flores200",
+ hf_repo="facebook/flores",
+ language_configs=[
+ EnglishCentricTranslationConfig(
+ task_id=lang_code,
+ target_language=Language(
+ manage_duplicate_language_codes(lang_code.split("_")[0])
+ ),
+ forward_hf_subset=_flores_pair_subset(lang_code),
+ reverse_hf_subset=_flores_pair_subset(lang_code),
+ forward_adapter=flores_adapter(_ENGLISH_FLORES_CODE, lang_code),
+ reverse_adapter=flores_adapter(lang_code, _ENGLISH_FLORES_CODE),
+ hf_avail_splits=("dev", "devtest"),
+ evaluation_splits=("devtest",),
+ few_shots_split="dev",
+ few_shots_select="random_sampling_from_train",
+ )
+ for lang_code in flores_200_languages
+ if lang_code != _ENGLISH_FLORES_CODE
+ ],
+ metrics=_FLORES_METRICS,
+)
diff --git a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py
index 894f15a3c..6167a5018 100644
--- a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py
+++ b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py
@@ -9,11 +9,19 @@
Translated MMLU using both professional and non-professional translators.
Contains tags for cultural sensitivity.
+Refactored to use unified :cf / :mcf suffixes with 3-part names (base:lang:suffix)
+so group selectors like `global_mmlu:cf|5` expand automatically.
+English is excluded — use mmlu.py for English evaluation.
+
+Per-category metrics (STEM / Humanities / Social / Other) are reported via
+MMLUCategoryGroupingCF / MCF, routed via doc.specific["subject"].
+
languages:
-amharic, arabic, bengali, chinese, czech, dutch, english, french, german,
-hebrew, hindi, indonesian, italian, japanese, korean, malay, norwegian, polish,
-portuguese, romanian, russian, serbian, spanish, swahili, swedish, tamil,
-telugu, thai, turkish, ukrainian, urdu, vietnamese, yoruba, zulu
+amharic, arabic, bengali, chinese, czech, dutch, french, german, greek,
+hausa, hebrew, hindi, igbo, indonesian, italian, japanese, kirghiz, korean,
+lithuanian, malagasy, malay, nepali, nyanja, persian, polish, portuguese,
+romanian, russian, serbian, shona, sinhala, somali, spanish, swahili,
+swedish, tagalog, telugu, turkish, ukrainian, vietnamese, yoruba
tags:
knowledge, multilingual, multiple-choice
@@ -22,158 +30,149 @@
https://huggingface.co/papers/2412.03304
"""
-from functools import partial
from string import ascii_uppercase
from langcodes import standardize_tag
-from lighteval.metrics.dynamic_metrics import (
- LogLikelihoodAccMetric,
-)
-from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm
+from lighteval.metrics.dynamic_metrics import MMLUCategoryGroupingCF, MMLUCategoryGroupingMCF
+from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
-from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation
+from lighteval.tasks.requests import Doc
from lighteval.tasks.templates.multichoice import get_mcq_prompt_function
-from lighteval.tasks.templates.utils.formulation import (
- MCFFormulation,
-)
+from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation
from lighteval.utils.language import Language
-MMLU_SUBSETS = [
- "abstract_algebra",
- "anatomy",
- "astronomy",
- "business_ethics",
- "clinical_knowledge",
- "college_biology",
- "college_chemistry",
- "college_computer_science",
- "college_mathematics",
- "college_medicine",
- "college_physics",
- "computer_security",
- "conceptual_physics",
- "econometrics",
- "electrical_engineering",
- "elementary_mathematics",
- "formal_logic",
- "global_facts",
- "high_school_biology",
- "high_school_chemistry",
- "high_school_computer_science",
- "high_school_european_history",
- "high_school_geography",
- "high_school_government_and_politics",
- "high_school_macroeconomics",
- "high_school_mathematics",
- "high_school_microeconomics",
- "high_school_physics",
- "high_school_psychology",
- "high_school_statistics",
- "high_school_us_history",
- "high_school_world_history",
- "human_aging",
- "human_sexuality",
- "international_law",
- "jurisprudence",
- "logical_fallacies",
- "machine_learning",
- "management",
- "marketing",
- "medical_genetics",
- "miscellaneous",
- "moral_disputes",
- "moral_scenarios",
- "nutrition",
- "philosophy",
- "prehistory",
- "professional_accounting",
- "professional_law",
- "professional_medicine",
- "professional_psychology",
- "public_relations",
- "security_studies",
- "sociology",
- "us_foreign_policy",
- "virology",
- "world_religions",
+# Non-English Global-MMLU configs evaluated here. The dataset names its
+# language configs with BCP-47-style short tags (for example, `fa`, `fil`, `sw`),
+# so task names use Language enum values while hf_subset is derived with
+# standardize_tag(language.value).
+_LANGUAGES = [
+ Language.AMHARIC,
+ Language.ARABIC,
+ Language.BENGALI,
+ Language.CHINESE,
+ Language.CZECH,
+ Language.DUTCH,
+ Language.FRENCH,
+ Language.GERMAN,
+ Language.GREEK,
+ Language.HAUSA,
+ Language.HEBREW,
+ Language.HINDI,
+ Language.IGBO,
+ Language.INDONESIAN,
+ Language.ITALIAN,
+ Language.JAPANESE,
+ Language.KIRGHIZ,
+ Language.KOREAN,
+ Language.LITHUANIAN,
+ Language.MALAGASY,
+ Language.MALAY,
+ Language.NEPALI,
+ Language.NYANJA,
+ Language.PERSIAN,
+ Language.SPANISH,
+ Language.POLISH,
+ Language.PORTUGUESE,
+ Language.ROMANIAN,
+ Language.RUSSIAN,
+ Language.SERBIAN,
+ Language.SHONA,
+ Language.SINHALA,
+ Language.SOMALI,
+ Language.SWEDISH,
+ Language.SWAHILI,
+ Language.TAGALOG,
+ Language.TELUGU,
+ Language.TURKISH,
+ Language.UKRAINIAN,
+ Language.VIETNAMESE,
+ Language.YORUBA,
]
+def _make_global_mmlu_cf_prompt(language: Language):
+    """Build the CFFormulation prompt fn; each Doc gets specific={"subject": <lowercased subject or "">}."""
+    inner = get_mcq_prompt_function(language, _cf_adapter, formulation=CFFormulation())
+
+    def prompt_fn(line, task_name=None):
+        doc = inner(line, task_name)
+        if doc is not None:
+            doc.specific = {"subject": (line.get("subject") or "").lower()}
+        return doc
+
+    return prompt_fn
+
+
+def _make_global_mmlu_mcf_prompt(language: Language):
+    """Build the MCFFormulation prompt fn; each Doc gets specific={"subject": <lowercased subject or "">}."""
+    inner = get_mcq_prompt_function(language, _mcf_adapter, formulation=MCFFormulation())
+
+    def prompt_fn(line, task_name=None):
+        doc = inner(line, task_name)
+        if doc is not None:
+            doc.specific = {"subject": (line.get("subject") or "").lower()}
+        return doc
+
+    return prompt_fn
+
+
+def _cf_adapter(line):
+ choices = [line["option_a"], line["option_b"], line["option_c"], line["option_d"]]
+ if any(c is None or not str(c).strip() for c in choices):
+ return None
+ return {
+ "question": line["question"],
+ "choices": choices,
+ "gold_idx": ascii_uppercase.index(line["answer"]),
+ }
+
+
+def _mcf_adapter(line):
+ choices = [line["option_a"], line["option_b"], line["option_c"], line["option_d"]]
+ if any(c is None or not str(c).strip() for c in choices):
+ return None
+ return {
+ "question": line["question"],
+ "choices": choices,
+ "gold_idx": ascii_uppercase.index(line["answer"]),
+ }
+
+
+_MMLU_CF_METRICS = [MMLUCategoryGroupingCF]
+_MMLU_MCF_METRICS = [MMLUCategoryGroupingMCF]
+
+
TASKS_TABLE = [
LightevalTaskConfig(
- name=f"global_mmlu_{sensitivity_label.lower()}_{language.value}_{formulation.name.lower()}:{subset}",
- prompt_function=get_mcq_prompt_function(
- language,
- lambda line: {
- "question": line["question"],
- "choices": [line["option_a"], line["option_b"], line["option_c"], line["option_d"]],
- "gold_idx": ascii_uppercase.index(line["answer"]),
- },
- formulation=formulation,
- ),
+ name=f"global_mmlu:{language.value}:{suffix}",
+ prompt_function=prompt_fn(language),
hf_repo="CohereForAI/Global-MMLU",
hf_subset=standardize_tag(language.value),
evaluation_splits=("test",),
few_shots_split="dev",
- hf_filter=partial(
- lambda subset, sensitivity_label, x: x["subject"].lower() == subset
- and (
- sensitivity_label == "ALL" or sensitivity_label in x["cultural_sensitivity_label"].replace("-", "UNK")
- )
- and all(x[f"option_{opt}"] is not None and x[f"option_{opt}"].strip() for opt in "abcd"),
- subset,
- sensitivity_label,
- ),
- metrics=get_metrics_for_formulation(
- formulation,
- [
- LogLikelihoodAccMetric(normalization=LogProbTokenNorm()),
- LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
- LogLikelihoodAccMetric(normalization=LogProbPMINorm()),
- ],
- ),
+ metrics=metrics,
)
- for subset in MMLU_SUBSETS
- for language in [
- Language.AMHARIC,
- Language.ARABIC,
- Language.BENGALI,
- Language.CHINESE,
- Language.CZECH,
- Language.GERMAN,
- Language.ENGLISH,
- Language.SPANISH,
- Language.FRENCH,
- Language.HEBREW,
- Language.HINDI,
- Language.INDONESIAN,
- Language.ITALIAN,
- Language.JAPANESE,
- Language.KOREAN,
- Language.MALAY,
- Language.DUTCH,
- Language.NORWEGIAN,
- Language.POLISH,
- Language.PORTUGUESE,
- Language.ROMANIAN,
- Language.RUSSIAN,
- Language.SERBIAN,
- Language.SWEDISH,
- Language.SWAHILI,
- Language.TAMIL,
- Language.TELUGU,
- Language.THAI,
- Language.TURKISH,
- Language.UKRAINIAN,
- Language.URDU,
- Language.VIETNAMESE,
- Language.YORUBA,
- Language.ZULU,
+ for language in _LANGUAGES
+ for suffix, prompt_fn, metrics in [
+ ("cf", _make_global_mmlu_cf_prompt, _MMLU_CF_METRICS),
+ ("mcf", _make_global_mmlu_mcf_prompt, _MMLU_MCF_METRICS),
]
- for formulation in [
- MCFFormulation(),
- ]
- for sensitivity_label in ["ALL"]
+]
+
+# Greedy variant: MCF-style prompt, generate up to 5 tokens, exact match
+TASKS_TABLE += [
+ LightevalTaskConfig(
+ name=f"global_mmlu:{language.value}:mcf_em",
+ prompt_function=_make_global_mmlu_mcf_prompt(language),
+ hf_repo="CohereForAI/Global-MMLU",
+ hf_subset=standardize_tag(language.value),
+ evaluation_splits=("test",),
+ few_shots_split="dev",
+ generation_size=5,
+ stop_sequence=["\n"],
+ metrics=[Metrics.exact_match],
+ )
+ for language in _LANGUAGES
]
diff --git a/src/lighteval/tasks/multilingual/tasks/global_mmlu_lite.py b/src/lighteval/tasks/multilingual/tasks/global_mmlu_lite.py
new file mode 100644
index 000000000..1de77591d
--- /dev/null
+++ b/src/lighteval/tasks/multilingual/tasks/global_mmlu_lite.py
@@ -0,0 +1,131 @@
+"""
+name:
+Global Mmlu Lite
+
+dataset:
+CohereLabs/Global-MMLU-Lite
+
+abstract:
+A lighter, culturally-annotated subset of MMLU covering 18 languages total;
+17 non-English languages evaluated here (400 test samples and 215 dev samples
+per language across 43 subjects). Designed for quick multilingual MMLU-style
+evaluation.
+
+English is excluded — use mmlu.py for English evaluation.
+
+Metrics:
+- :cf — completion formulation; reports acc, acc_norm (char), target_bpb
+- :mcf — multiple-choice formulation; reports acc, acc_norm (char)
+
+languages:
+arabic, bengali, welsh, german, spanish, french, hindi, indonesian, italian,
+japanese, korean, burmese, portuguese, albanian, swahili, yoruba, chinese
+
+tags:
+knowledge, multilingual, multiple-choice
+
+paper:
+https://huggingface.co/datasets/CohereLabs/Global-MMLU-Lite
+"""
+
+from string import ascii_uppercase
+
+from langcodes import standardize_tag
+
+from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric
+from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.normalizations import LogProbCharNorm
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.templates.multichoice import get_mcq_prompt_function
+from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation
+from lighteval.utils.language import Language
+
+
+_CF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+ Metrics.target_bits_per_byte,
+]
+
+_MCF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+]
+
+# The 17 non-English languages of Global-MMLU-Lite (18 in total; English excluded)
+_LANGUAGES = [
+ Language.ARABIC, # ar
+ Language.BENGALI, # bn
+ Language.WELSH, # cy
+ Language.GERMAN, # de
+ Language.SPANISH, # es
+ Language.FRENCH, # fr
+ Language.HINDI, # hi
+ Language.INDONESIAN, # id
+ Language.ITALIAN, # it
+ Language.JAPANESE, # ja
+ Language.KOREAN, # ko
+ Language.BURMESE, # my
+ Language.PORTUGUESE, # pt
+ Language.ALBANIAN, # sq
+ Language.SWAHILI, # sw
+ Language.YORUBA, # yo
+ Language.CHINESE, # zh
+]
+
+
+def _global_mmlu_lite_adapter(line):
+ return {
+ "question": line["question"],
+ "choices": [line["option_a"], line["option_b"], line["option_c"], line["option_d"]],
+ "gold_idx": ascii_uppercase.index(line["answer"]),
+ }
+
+
+def _valid_options_filter(x):
+    """Return True only when option_a..option_d are all non-None and non-blank (used as hf_filter)."""
+    return all(x[f"option_{opt}"] is not None and x[f"option_{opt}"].strip() for opt in "abcd")
+
+
+TASKS_TABLE = [
+ LightevalTaskConfig(
+ name=f"global_mmlu_lite:{language.value}:{suffix}",
+ prompt_function=get_mcq_prompt_function(
+ language,
+ _global_mmlu_lite_adapter,
+ formulation=formulation,
+ ),
+ hf_repo="CohereLabs/Global-MMLU-Lite",
+ hf_subset=standardize_tag(language.value),
+ evaluation_splits=("test",),
+ few_shots_split="dev",
+ hf_filter=_valid_options_filter,
+ metrics=metrics,
+ )
+ for language in _LANGUAGES
+ for suffix, formulation, metrics in [
+ ("cf", CFFormulation(), _CF_METRICS),
+ ("mcf", MCFFormulation(), _MCF_METRICS),
+ ]
+]
+
+# Greedy variant: MCF-style prompt, generate up to 5 tokens, exact match
+TASKS_TABLE += [
+ LightevalTaskConfig(
+ name=f"global_mmlu_lite:{language.value}:mcf_em",
+ prompt_function=get_mcq_prompt_function(
+ language,
+ _global_mmlu_lite_adapter,
+ formulation=MCFFormulation(),
+ ),
+ hf_repo="CohereLabs/Global-MMLU-Lite",
+ hf_subset=standardize_tag(language.value),
+ evaluation_splits=("test",),
+ few_shots_split="dev",
+ hf_filter=_valid_options_filter,
+ generation_size=5,
+ stop_sequence=["\n"],
+ metrics=[Metrics.exact_match],
+ )
+ for language in _LANGUAGES
+]
diff --git a/src/lighteval/tasks/multilingual/tasks/mgsm.py b/src/lighteval/tasks/multilingual/tasks/mgsm.py
index ba5d9a323..1f03a0b26 100644
--- a/src/lighteval/tasks/multilingual/tasks/mgsm.py
+++ b/src/lighteval/tasks/multilingual/tasks/mgsm.py
@@ -6,37 +6,57 @@
juletxara/mgsm
abstract:
-Mgsm multilingual benchmark.
+MGSM (Multilingual Grade School Math) is a multilingual benchmark testing
+mathematical reasoning across languages, derived from GSM8K.
+
+Refactored to use unified :gen suffix consistent with English gsm8k.py.
+English is excluded — use gsm8k.py for English evaluation.
+Reports both expr_gold_metric (math expression parser) and
+MultilingualQuasiExactMatchMetric (language-aware fuzzy match, handles
+non-ASCII digit systems like Japanese/Thai).
languages:
-bengali, chinese, english, french, german, japanese, russian, spanish, swahili,
-telugu, thai
+bengali, french, german, japanese, russian, spanish, swahili, telugu, thai,
+chinese
tags:
math, multilingual, reasoning
paper:
+https://arxiv.org/abs/2210.03057
"""
from langcodes import standardize_tag
-from lighteval.metrics.dynamic_metrics import (
- MultilingualQuasiExactMatchMetric,
-)
+from lighteval.metrics.dynamic_metrics import MultilingualQuasiExactMatchMetric
+from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.templates.qa import get_qa_prompt_function
from lighteval.utils.language import Language
+# Languages covered by juletxara/mgsm (English excluded — use gsm8k.py)
+_LANGUAGES = [
+ Language.SPANISH,
+ Language.FRENCH,
+ Language.GERMAN,
+ Language.RUSSIAN,
+ Language.CHINESE,
+ Language.JAPANESE,
+ Language.THAI,
+ Language.SWAHILI,
+ Language.BENGALI,
+ Language.TELUGU,
+]
+
+
TASKS_TABLE = [
LightevalTaskConfig(
- name=f"mgsm_{language.value}",
+ name=f"mgsm:{language.value}:gen",
prompt_function=get_qa_prompt_function(
language,
lambda line: {
"question": line["question"],
- # The cot is available but we have no use:
- # line["answer"]
"choices": [str(line["answer_number"])],
},
),
@@ -44,23 +64,12 @@
hf_subset=standardize_tag(language.value),
evaluation_splits=("test",),
few_shots_split="train",
- generation_size=25,
+ generation_size=512,
metrics=[
+ Metrics.expr_gold_metric,
MultilingualQuasiExactMatchMetric(language, "full"),
],
- stop_sequence=("\n",),
+ stop_sequence=["\n"],
)
- for language in [
- Language.ENGLISH,
- Language.SPANISH,
- Language.FRENCH,
- Language.GERMAN,
- Language.RUSSIAN,
- Language.CHINESE,
- Language.JAPANESE,
- Language.THAI,
- Language.SWAHILI,
- Language.BENGALI,
- Language.TELUGU,
- ]
+ for language in _LANGUAGES
]
diff --git a/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py b/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py
index f7485124f..200f14f11 100644
--- a/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py
+++ b/src/lighteval/tasks/multilingual/tasks/mlmm_arc_challenge.py
@@ -9,10 +9,11 @@
ARC (AI2 Reasoning Challenge) is a dataset for question answering that requires
reasoning. It consists of multiple-choice science questions from 3rd to 9th
grade exams. The dataset is split into two parts: ARC-Easy and ARC-Challenge.
-ARC-Easy contains questions that can be answered correctly by both humans and
-simple baseline models. ARC-Challenge contains questions that are difficult for
-both humans and current AI systems. Similar to MMLU, ARC tasks uses PMI
-normalization by default but only for the challenge set.
+ARC-Challenge contains questions that are difficult for both humans and current
+AI systems.
+
+Refactored to use unified :cf / :mcf suffixes consistent with English arc.py.
+English is excluded — use arc.py for English evaluation.
languages:
arabic, bengali, catalan, chinese, croatian, danish, dutch, french, german,
@@ -31,33 +32,91 @@
from langcodes import standardize_tag
-from lighteval.metrics.dynamic_metrics import (
- LogLikelihoodAccMetric,
-)
-from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm
+from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric
+from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.normalizations import LogProbCharNorm
from lighteval.tasks.lighteval_task import LightevalTaskConfig
-from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation
from lighteval.tasks.templates.multichoice import get_mcq_prompt_function
-from lighteval.tasks.templates.utils.formulation import (
- CFFormulation,
- HybridFormulation,
- MCFFormulation,
-)
+from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation
from lighteval.utils.language import Language
+_CF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+ Metrics.target_bits_per_byte,
+]
+
+_MCF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+]
+
+# Languages covered by jon-tow/okapi_arc_challenge (English excluded)
+_LANGUAGES = [
+ Language.RUSSIAN,
+ Language.GERMAN,
+ Language.CHINESE,
+ Language.FRENCH,
+ Language.SPANISH,
+ Language.ITALIAN,
+ Language.DUTCH,
+ Language.VIETNAMESE,
+ Language.INDONESIAN,
+ Language.ARABIC,
+ Language.HUNGARIAN,
+ Language.ROMANIAN,
+ Language.DANISH,
+ Language.SLOVAK,
+ Language.UKRAINIAN,
+ Language.CATALAN,
+ Language.SERBIAN,
+ Language.CROATIAN,
+ Language.HINDI,
+ Language.BENGALI,
+ Language.TAMIL,
+ Language.NEPALI,
+ Language.MALAYALAM,
+ Language.MARATHI,
+ Language.TELUGU,
+ Language.KANNADA,
+]
+
+
+def _arc_adapter(line):
+    """Adapt one ARC row to MCQ fields, accepting either of two row schemas.
+
+    Native rows (question / choices.text / answerKey) keep their question text;
+    flat rows (instruction / option_a..option_e / answer) drop empty options.
+    """
+    if "question" in line and "choices" in line:
+        question = line["question"]
+        choices = line["choices"]["text"]
+        answer_key = line["answerKey"]
+    else:
+        question = line["instruction"]
+        choices = [
+            line[key]
+            for key in ("option_a", "option_b", "option_c", "option_d", "option_e")
+            if line.get(key)
+        ]
+        answer_key = line["answer"]
+    # Answer keys are either 1-based digits or uppercase letters.
+    return {
+        "question": question,
+        "choices": choices,
+        "gold_idx": int(answer_key) - 1
+        if answer_key.isdigit()
+        else ascii_uppercase.index(answer_key),
+    }
+
+
TASKS_TABLE = [
LightevalTaskConfig(
- name=f"mlmm_arc_{language.value}_{formulation.name.lower()}:challenge",
+ name=f"mlmm_arc:{language.value}:{suffix}",
prompt_function=get_mcq_prompt_function(
language,
- lambda line: {
- "question": line["question"],
- "choices": line["choices"]["text"],
- "gold_idx": int(line["answerKey"]) - 1
- if line["answerKey"].isdigit()
- else ascii_uppercase.index(line["answerKey"]),
- },
+ _arc_adapter,
formulation=formulation,
),
hf_repo="jon-tow/okapi_arc_challenge",
@@ -65,46 +124,32 @@
hf_revision="823d5d7bfaf8974a3ab52a825b6cf4903b35dbc4",
evaluation_splits=("test",),
few_shots_split="train",
- metrics=get_metrics_for_formulation(
- formulation,
- [
- LogLikelihoodAccMetric(normalization=LogProbTokenNorm()),
- LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
- LogLikelihoodAccMetric(normalization=LogProbPMINorm()),
- ],
- ),
+ metrics=metrics,
)
- for language in [
- Language.RUSSIAN,
- Language.GERMAN,
- Language.CHINESE,
- Language.FRENCH,
- Language.SPANISH,
- Language.ITALIAN,
- Language.DUTCH,
- Language.VIETNAMESE,
- Language.INDONESIAN,
- Language.ARABIC,
- Language.HUNGARIAN,
- Language.ROMANIAN,
- Language.DANISH,
- Language.SLOVAK,
- Language.UKRAINIAN,
- Language.CATALAN,
- Language.SERBIAN,
- Language.CROATIAN,
- Language.HINDI,
- Language.BENGALI,
- Language.TAMIL,
- Language.NEPALI,
- Language.MALAYALAM,
- Language.MARATHI,
- Language.TELUGU,
- Language.KANNADA,
- ]
- for formulation in [
- MCFFormulation(),
- CFFormulation(),
- HybridFormulation(),
+ for language in _LANGUAGES
+ for suffix, formulation, metrics in [
+ ("cf", CFFormulation(), _CF_METRICS),
+ ("mcf", MCFFormulation(), _MCF_METRICS),
]
]
+
+# Greedy variant: MCF-style prompt, generate 1 token, exact match
+TASKS_TABLE += [
+ LightevalTaskConfig(
+ name=f"mlmm_arc:{language.value}:mcf_em",
+ prompt_function=get_mcq_prompt_function(
+ language,
+ _arc_adapter,
+ formulation=MCFFormulation(),
+ ),
+ hf_repo="jon-tow/okapi_arc_challenge",
+ hf_subset=standardize_tag(language.value),
+ hf_revision="823d5d7bfaf8974a3ab52a825b6cf4903b35dbc4",
+ evaluation_splits=("test",),
+ few_shots_split="train",
+ generation_size=1,
+ stop_sequence=["\n"],
+ metrics=[Metrics.exact_match],
+ )
+ for language in _LANGUAGES
+]
diff --git a/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py b/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py
index 3945b8947..bbd631278 100644
--- a/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py
+++ b/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py
@@ -3,18 +3,23 @@
Mlmm Hellaswag
dataset:
-jon-tow/okapi_hellaswag
+alexandrainst/m_hellaswag (most languages)
+jon-tow/okapi_hellaswag (Chinese)
abstract:
Hellaswag is a commonsense reasoning task that requires models to complete a
given scenario with the most plausible ending. It tests the model's ability to
understand and reason about everyday situations and human behavior.
-MLMM-Hellaswag: Multilingual adaptation of Hellaswag
+MLMM-Hellaswag: Multilingual adaptation of Hellaswag.
+
+Refactored to use unified :cf / :mcf / :mcf_em suffixes consistent with English hellaswag.py.
+HellaSwag has 4 candidate continuations; the MCF formulation labels them A–D and scores (or
+generates) the label token. English is excluded — use hellaswag.py for English evaluation.
languages:
arabic, armenian, basque, bengali, catalan, chinese, croatian, danish, dutch,
french, german, gujarati, hindi, hungarian, icelandic, indonesian, italian,
-kannada, malayalam, marathi, nepali, norwegian, portuguese, romanian, russian,
+kannada, malayalam, marathi, nepali, portuguese, romanian, russian,
serbian, slovak, spanish, swedish, tamil, telugu, ukrainian, vietnamese
tags:
@@ -26,82 +31,127 @@
from langcodes import standardize_tag
-from lighteval.metrics.dynamic_metrics import (
- LogLikelihoodAccMetric,
-)
-from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm
+from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric
+from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.normalizations import LogProbCharNorm
from lighteval.tasks.lighteval_task import LightevalTaskConfig
-from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation
from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function
-from lighteval.tasks.templates.utils.formulation import (
- CFFormulation,
- HybridFormulation,
- MCFFormulation,
-)
+from lighteval.tasks.templates.utils.formulation import CFFormulation, MCFFormulation
from lighteval.utils.language import Language
+_CF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+ Metrics.target_bits_per_byte,
+]
+
+_MCF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+]
+
+# Languages covered by alexandrainst/m_hellaswag and jon-tow/okapi_hellaswag.
+# English excluded — use hellaswag.py for English evaluation.
+_LANGUAGES = [
+ Language.ARABIC,
+ Language.BENGALI,
+ Language.CATALAN,
+ Language.DANISH,
+ Language.GERMAN,
+ Language.SPANISH,
+ Language.BASQUE,
+ Language.FRENCH,
+ Language.GUJARATI,
+ Language.HINDI,
+ Language.CROATIAN,
+ Language.HUNGARIAN,
+ Language.ARMENIAN,
+ Language.INDONESIAN,
+ Language.ICELANDIC,
+ Language.ITALIAN,
+ Language.KANNADA,
+ Language.MALAYALAM,
+ Language.MARATHI,
+ Language.NEPALI,
+ Language.DUTCH,
+ Language.PORTUGUESE,
+ Language.ROMANIAN,
+ Language.RUSSIAN,
+ Language.SLOVAK,
+ Language.SERBIAN,
+ Language.SWEDISH,
+ Language.TAMIL,
+ Language.TELUGU,
+ Language.UKRAINIAN,
+ Language.VIETNAMESE,
+ Language.CHINESE,
+]
+
+
+def _hellaswag_adapter(line):
+ return {
+ "ctx_a": line["ctx_a"],
+ "ctx_b": line["ctx_b"],
+ "continuations": line["endings"],
+ "gold_idx": int(line["label"]),
+ }
+
+
TASKS_TABLE = [
LightevalTaskConfig(
- name=f"mlmm_hellaswag_{lang.value}_{formulation.name.lower()}",
+ name=f"mlmm_hellaswag:{lang.value}:cf",
prompt_function=get_hellaswag_prompt_function(
language=lang,
- adapter=lambda line: {
- # We don't use activity_label as they are not available
- "ctx_a": line["ctx_a"],
- "ctx_b": line["ctx_b"],
- "continuations": line["endings"],
- "gold_idx": int(line["label"]),
- },
- formulation=formulation,
+ adapter=_hellaswag_adapter,
+ formulation=CFFormulation(),
),
hf_repo="alexandrainst/m_hellaswag" if lang != Language.CHINESE else "jon-tow/okapi_hellaswag",
hf_subset=standardize_tag(lang.value),
- # hf_revision="96ed8e0dfc6172dad1d3df338d7b8ba6c1ff9d83",
evaluation_splits=["val" if lang != Language.CHINESE else "validation"],
hf_avail_splits=["val" if lang != Language.CHINESE else "validation"],
- metrics=get_metrics_for_formulation(
- formulation,
- [
- LogLikelihoodAccMetric(normalization=LogProbTokenNorm()),
- LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
- ],
+ few_shots_split=None,
+ metrics=_CF_METRICS,
+ )
+ for lang in _LANGUAGES
+]
+
+# MCF variant: labeled options in prompt, score label tokens via logprobs
+TASKS_TABLE += [
+ LightevalTaskConfig(
+ name=f"mlmm_hellaswag:{lang.value}:mcf",
+ prompt_function=get_hellaswag_prompt_function(
+ language=lang,
+ adapter=_hellaswag_adapter,
+ formulation=MCFFormulation(),
),
+ hf_repo="alexandrainst/m_hellaswag" if lang != Language.CHINESE else "jon-tow/okapi_hellaswag",
+ hf_subset=standardize_tag(lang.value),
+ evaluation_splits=["val" if lang != Language.CHINESE else "validation"],
+ hf_avail_splits=["val" if lang != Language.CHINESE else "validation"],
+ few_shots_split=None,
+ metrics=_MCF_METRICS,
+ )
+ for lang in _LANGUAGES
+]
+
+# Greedy variant: MCF-style prompt, generate 1 token, exact match
+TASKS_TABLE += [
+ LightevalTaskConfig(
+ name=f"mlmm_hellaswag:{lang.value}:mcf_em",
+ prompt_function=get_hellaswag_prompt_function(
+ language=lang,
+ adapter=_hellaswag_adapter,
+ formulation=MCFFormulation(),
+ ),
+ hf_repo="alexandrainst/m_hellaswag" if lang != Language.CHINESE else "jon-tow/okapi_hellaswag",
+ hf_subset=standardize_tag(lang.value),
+ evaluation_splits=["val" if lang != Language.CHINESE else "validation"],
+ hf_avail_splits=["val" if lang != Language.CHINESE else "validation"],
+ few_shots_split=None,
+ generation_size=1,
+ stop_sequence=["\n"],
+ metrics=[Metrics.exact_match],
)
- for lang in [
- Language.ARABIC,
- Language.BENGALI,
- Language.CATALAN,
- Language.DANISH,
- Language.GERMAN,
- Language.SPANISH,
- Language.BASQUE,
- Language.FRENCH,
- Language.GUJARATI,
- Language.HINDI,
- Language.CROATIAN,
- Language.HUNGARIAN,
- Language.ARMENIAN,
- Language.INDONESIAN,
- Language.ICELANDIC,
- Language.ITALIAN,
- Language.KANNADA,
- Language.MALAYALAM,
- Language.MARATHI,
- Language.NORWEGIAN,
- Language.NEPALI,
- Language.DUTCH,
- Language.PORTUGUESE,
- Language.ROMANIAN,
- Language.RUSSIAN,
- Language.SLOVAK,
- Language.SERBIAN,
- Language.SWEDISH,
- Language.TAMIL,
- Language.TELUGU,
- Language.UKRAINIAN,
- Language.VIETNAMESE,
- Language.CHINESE,
- ]
- for formulation in [MCFFormulation(), CFFormulation(), HybridFormulation()]
+ for lang in _LANGUAGES
]
diff --git a/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py b/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py
new file mode 100644
index 000000000..a21c3b5f2
--- /dev/null
+++ b/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py
@@ -0,0 +1,170 @@
+"""
+name:
+MMLU-ProX (Multilingual MMLU-Pro)
+
+dataset:
+li-lab/MMLU-ProX
+
+abstract:
+MMLU-ProX is a multilingual extension of MMLU-Pro covering 29 languages. It
+uses up to 10 answer options per question (labels A–J), making it more
+challenging than standard 4-option MMLU.
+
+English is excluded — use mmlu_pro.py for English evaluation.
+
+Metrics:
+- :cf — completion formulation; scores each non-null option text; reports
+ acc, acc_norm (char), target_bpb
+- :mcf — multiple-choice formulation; scores label tokens A–J; reports
+ acc, acc_norm (char)
+
+languages:
+afrikaans, arabic, bengali, czech, german, spanish, french, hindi, hungarian,
+indonesian, italian, japanese, korean, marathi, nepali, portuguese, russian,
+serbian, swahili, telugu, thai, ukrainian, urdu, vietnamese, wolof, yoruba,
+chinese, zulu
+
+tags:
+knowledge, multilingual, multiple-choice
+
+paper:
+https://huggingface.co/datasets/li-lab/MMLU-ProX
+"""
+
+from string import ascii_uppercase
+
+from langcodes import standardize_tag
+
+from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric
+from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.normalizations import LogProbCharNorm
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.requests import Doc
+from lighteval.utils.language import Language
+
+
+_CF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+ Metrics.target_bits_per_byte,
+]
+
+_MCF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+]
+
+# MMLU-ProX languages (English excluded — use mmlu_pro.py for English)
+_LANGUAGES = [
+ Language.AFRIKAANS, # af
+ Language.ARABIC, # ar
+ Language.BENGALI, # bn
+ Language.CZECH, # cs
+ Language.GERMAN, # de
+ Language.SPANISH, # es
+ Language.FRENCH, # fr
+ Language.HINDI, # hi
+ Language.HUNGARIAN, # hu
+ Language.INDONESIAN, # id
+ Language.ITALIAN, # it
+ Language.JAPANESE, # ja
+ Language.KOREAN, # ko
+ Language.MARATHI, # mr
+ Language.NEPALI, # ne
+ Language.PORTUGUESE, # pt
+ Language.RUSSIAN, # ru
+ Language.SERBIAN, # sr
+ Language.SWAHILI, # sw
+ Language.TELUGU, # te
+ Language.THAI, # th
+ Language.UKRAINIAN, # uk
+ Language.URDU, # ur
+ Language.VIETNAMESE, # vi
+ Language.WOLOF, # wo
+ Language.YORUBA, # yo
+ Language.CHINESE, # zh
+ Language.ZULU, # zu
+]
+
+# option_0..option_9 column names
+_OPTION_COLS = [f"option_{i}" for i in range(10)]
+
+
+def _get_options(line):
+    """Return the non-None, non-blank values of option_0..option_9, in column order."""
+    return [line[col] for col in _OPTION_COLS if line.get(col) is not None and str(line[col]).strip()]
+
+
+def mmlu_prox_cf_prompt(line, task_name: str = None):
+    """CF: choices are the kept option texts. NOTE(review): answer_index is not remapped after filtering."""
+    options = _get_options(line)
+    if not options:
+        return None
+    choices = [c if c and c[0].isspace() else " " + c for c in options]
+    return Doc(
+        task_name=task_name,
+        query=f"Question: {line['question'].strip()}\nAnswer:",
+        choices=choices,
+        gold_index=line["answer_index"],
+    )
+
+
+def mmlu_prox_mcf_prompt(line, task_name: str = None):
+    """MCF: label kept options A, B, ...; choices are the labels. NOTE(review): answer_index is not remapped."""
+    options = _get_options(line)
+    if not options:
+        return None
+    labels = list(ascii_uppercase[: len(options)])
+    query = f"Question: {line['question'].strip()}\n"
+    query += "".join([f"{lbl}. {opt}\n" for lbl, opt in zip(labels, options)])
+    query += "Answer:"
+    return Doc(
+        task_name=task_name,
+        query=query,
+        choices=[" " + lbl for lbl in labels],
+        gold_index=line["answer_index"],
+    )
+
+
+def _valid_filter(x):
+    """Return True when at least one option_0..option_9 value is non-None and non-blank."""
+    return any(
+        x.get(col) is not None and str(x[col]).strip()
+        for col in _OPTION_COLS
+    )
+
+
+TASKS_TABLE = [
+ LightevalTaskConfig(
+ name=f"mmlu_prox:{language.value}:{suffix}",
+ prompt_function=prompt_fn,
+ hf_repo="li-lab/MMLU-ProX",
+ hf_subset=standardize_tag(language.value),
+ evaluation_splits=("test",),
+ few_shots_split="validation",
+ hf_filter=_valid_filter,
+ metrics=metrics,
+ )
+ for language in _LANGUAGES
+ for suffix, prompt_fn, metrics in [
+ ("cf", mmlu_prox_cf_prompt, _CF_METRICS),
+ ("mcf", mmlu_prox_mcf_prompt, _MCF_METRICS),
+ ]
+]
+
+# Greedy variant: MCF-style prompt, generate up to 5 tokens, exact match
+TASKS_TABLE += [
+ LightevalTaskConfig(
+ name=f"mmlu_prox:{language.value}:mcf_em",
+ prompt_function=mmlu_prox_mcf_prompt,
+ hf_repo="li-lab/MMLU-ProX",
+ hf_subset=standardize_tag(language.value),
+ evaluation_splits=("test",),
+ few_shots_split="validation",
+ hf_filter=_valid_filter,
+ generation_size=5,
+ stop_sequence=["\n"],
+ metrics=[Metrics.exact_match],
+ )
+ for language in _LANGUAGES
+]
diff --git a/src/lighteval/tasks/multilingual/tasks/wmt24pp.py b/src/lighteval/tasks/multilingual/tasks/wmt24pp.py
new file mode 100644
index 000000000..27c5244bd
--- /dev/null
+++ b/src/lighteval/tasks/multilingual/tasks/wmt24pp.py
@@ -0,0 +1,114 @@
+"""
+name:
+WMT24++
+
+dataset:
+google/wmt24pp
+
+abstract:
+WMT24++ is an extension of the WMT2024 general translation shared task test
+sets, covering 55+ language pairs with English as one side. Data includes
+post-edited professional translations and quality flags.
+
+Tasks are 0-shot generation scored with chrF++, BLEU-4 and, if unbabel-comet is installed, COMET-22.
+The public selector is `wmt24pp:{lp}|0`, which expands to the internal
+directional tasks `wmt24pp:en_to_x:{lp}` and `wmt24pp:x_to_en:{lp}`.
+
+NOTE: The dataset's "train" split is the evaluation data (WMT test sets are
+published on the Hugging Face Hub with split="train").
+Few-shot examples are drawn from this same split, since the dataset does not
+publish a separate few-shot split for WMT24++.
+
+NOTE: `lp` column values use regional codes (e.g., "de_DE", "fr_FR"). Tasks
+are named using these codes directly. Check the dataset for all available lp
+values: https://huggingface.co/datasets/google/wmt24pp
+
+tags:
+multilingual, translation
+
+paper:
+https://arxiv.org/abs/2412.06378
+"""
+
+from functools import partial
+
+from lighteval.tasks.multilingual.utils.translation import (
+ TRANSLATION_METRICS,
+ EnglishCentricTranslationConfig,
+ build_english_centric_translation_tasks,
+)
+from lighteval.utils.language import Language
+
+
+_TRANSLATION_METRICS = TRANSLATION_METRICS
+
+_LANGUAGE_CONFIGS = [
+ ("de_DE", Language.GERMAN),
+ ("fr_FR", Language.FRENCH),
+ ("cs_CZ", Language.CZECH),
+ ("es_MX", Language.SPANISH),
+ ("ru_RU", Language.RUSSIAN),
+ ("zh_CN", Language.CHINESE),
+ ("ja_JP", Language.JAPANESE),
+ ("uk_UA", Language.UKRAINIAN),
+ ("hi_IN", Language.HINDI),
+ ("ar_EG", Language.ARABIC),
+ ("ko_KR", Language.KOREAN),
+ ("pt_BR", Language.PORTUGUESE),
+ ("tr_TR", Language.TURKISH),
+ ("pl_PL", Language.POLISH),
+ ("he_IL", Language.HEBREW),
+ ("nl_NL", Language.DUTCH),
+ ("it_IT", Language.ITALIAN),
+ ("sv_SE", Language.SWEDISH),
+ ("fi_FI", Language.FINNISH),
+ ("vi_VN", Language.VIETNAMESE),
+ ("bn_IN", Language.BENGALI),
+ ("th_TH", Language.THAI),
+ ("id_ID", Language.INDONESIAN),
+ ("hu_HU", Language.HUNGARIAN),
+]
+
+
+def _make_forward_adapter(_lp_code):
+ """Adapter for en→X direction: source=English, target=other language."""
+ return lambda line: {
+ "source_text": line["source"],
+ "target_text": line["target"],
+ }
+
+
+def _make_reverse_adapter(_lp_code):
+ """Adapter for X→en direction: source=other language, target=English."""
+ return lambda line: {
+ "source_text": line["target"],
+ "target_text": line["source"],
+ }
+
+
+def _make_filter(lp_code):
+ """Build a row filter that drops rows flagged `is_bad_source`; `lp_code` is bound but not read by the predicate."""
+ return partial(lambda _lp, x: not x.get("is_bad_source", False), lp_code)
+
+
+TASKS_TABLE = build_english_centric_translation_tasks(
+ base_name="wmt24pp",
+ hf_repo="google/wmt24pp",
+ language_configs=[
+ EnglishCentricTranslationConfig(
+ task_id=lp_code,
+ target_language=target_lang,
+ forward_hf_subset=f"en-{lp_code}",
+ reverse_hf_subset=f"en-{lp_code}",
+ forward_adapter=_make_forward_adapter(lp_code),
+ reverse_adapter=_make_reverse_adapter(lp_code),
+ hf_filter=_make_filter(lp_code),
+ hf_avail_splits=("train",),
+ evaluation_splits=("train",),
+ few_shots_split="train",
+ few_shots_select="random_sampling",
+ )
+ for lp_code, target_lang in _LANGUAGE_CONFIGS
+ ],
+ metrics=_TRANSLATION_METRICS,
+)
diff --git a/src/lighteval/tasks/multilingual/utils/translation.py b/src/lighteval/tasks/multilingual/utils/translation.py
new file mode 100644
index 000000000..be7eaf227
--- /dev/null
+++ b/src/lighteval/tasks/multilingual/utils/translation.py
@@ -0,0 +1,103 @@
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+
+from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.utils.metric_utils import Metric
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.templates.translation import (
+    TranslationInput,
+    get_translation_prompt_function,
+)
+from lighteval.tasks.templates.utils.formulation import CFFormulation
+from lighteval.utils.language import Language
+
+# Canonical metric set for English-centric translation tasks. chrF++ and BLEU are cheap (sacrebleu);
+# COMET-22 is appended only when unbabel-comet imports (pip install lighteval[multilingual]).
+try:
+    import comet  # noqa: F401
+
+    _comet_available = True
+except ImportError:
+    _comet_available = False
+
+TRANSLATION_METRICS: list[Metrics] = [
+    Metrics.chrf_plus,
+    Metrics.bleu,
+    *([Metrics.comet22] if _comet_available else []),
+]
+
+
+@dataclass(frozen=True)
+class EnglishCentricTranslationConfig:
+ task_id: str
+ target_language: Language
+ forward_hf_subset: str
+ reverse_hf_subset: str
+ forward_adapter: Callable[[dict], TranslationInput | None]
+ reverse_adapter: Callable[[dict], TranslationInput | None]
+ hf_filter: Callable[[dict], bool] | None = None
+ hf_avail_splits: Sequence[str] = ("train", "validation", "test")
+ evaluation_splits: Sequence[str] = ("validation",)
+ few_shots_split: str | None = None
+ few_shots_select: str | None = None
+
+
+def build_english_centric_translation_tasks(
+ *,
+ base_name: str,
+ hf_repo: str,
+ language_configs: Sequence[EnglishCentricTranslationConfig],
+ metrics: Sequence[Metric | Metrics],
+ generation_size: int = 300,
+ stop_sequence: Sequence[str] = ("\n",),
+ version: int = 0,
+) -> list[LightevalTaskConfig]:
+ tasks = []
+
+ for config in language_configs:
+ tasks.append(
+ LightevalTaskConfig(
+ name=f"{base_name}:en_to_x:{config.task_id}",
+ prompt_function=get_translation_prompt_function(
+ source_language=Language.ENGLISH,
+ target_language=config.target_language,
+ adapter=config.forward_adapter,
+ formulation=CFFormulation(),
+ ),
+ hf_repo=hf_repo,
+ hf_subset=config.forward_hf_subset,
+ hf_filter=config.hf_filter,
+ hf_avail_splits=config.hf_avail_splits,
+ evaluation_splits=config.evaluation_splits,
+ few_shots_split=config.few_shots_split,
+ few_shots_select=config.few_shots_select,
+ generation_size=generation_size,
+ metrics=metrics,
+ stop_sequence=stop_sequence,
+ version=version,
+ )
+ )
+ tasks.append(
+ LightevalTaskConfig(
+ name=f"{base_name}:x_to_en:{config.task_id}",
+ prompt_function=get_translation_prompt_function(
+ source_language=config.target_language,
+ target_language=Language.ENGLISH,
+ adapter=config.reverse_adapter,
+ formulation=CFFormulation(),
+ ),
+ hf_repo=hf_repo,
+ hf_subset=config.reverse_hf_subset,
+ hf_filter=config.hf_filter,
+ hf_avail_splits=config.hf_avail_splits,
+ evaluation_splits=config.evaluation_splits,
+ few_shots_split=config.few_shots_split,
+ few_shots_select=config.few_shots_select,
+ generation_size=generation_size,
+ metrics=metrics,
+ stop_sequence=stop_sequence,
+ version=version,
+ )
+ )
+
+ return tasks
diff --git a/src/lighteval/tasks/tasks/mmlu_prox.py b/src/lighteval/tasks/tasks/mmlu_prox.py
new file mode 100644
index 000000000..08c04703b
--- /dev/null
+++ b/src/lighteval/tasks/tasks/mmlu_prox.py
@@ -0,0 +1,133 @@
+"""
+name:
+MMLU-ProX (English)
+
+dataset:
+li-lab/MMLU-ProX
+
+abstract:
+MMLU-ProX is a multilingual extension of MMLU-Pro. This file covers the
+English ("en") subset only. It uses up to 10 answer options per question
+(labels A–J), consistent with MMLU-Pro's harder format.
+
+For multilingual evaluation see:
+lighteval/tasks/multilingual/tasks/mmlu_prox.py
+
+languages:
+english
+
+tags:
+general-knowledge, knowledge, multiple-choice
+
+paper:
+https://huggingface.co/datasets/li-lab/MMLU-ProX
+"""
+
+from string import ascii_uppercase
+
+from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric
+from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.normalizations import LogProbCharNorm
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.requests import Doc
+
+
+_CF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+ Metrics.target_bits_per_byte,
+]
+
+_MCF_METRICS = [
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+]
+
+# option_0..option_9 column names
+_OPTION_COLS = [f"option_{i}" for i in range(10)]
+
+
+def _get_options(line):
+ """Return the non-null, non-blank options from option_0..option_9, in column order."""
+ return [line[col] for col in _OPTION_COLS if line.get(col) is not None and str(line[col]).strip()]
+
+
+def mmlu_prox_cf_prompt(line, task_name: str = None):
+ """CF: score each non-null option text directly."""
+ options = _get_options(line)
+ if not options:
+ return None
+ choices = [c if c and c[0].isspace() else " " + c for c in options]
+ return Doc(
+ task_name=task_name,
+ query=f"Question: {line['question'].strip()}\nAnswer:",
+ choices=choices,
+ gold_index=line["answer_index"],
+ )
+
+
+def mmlu_prox_mcf_prompt(line, task_name: str = None):
+ """MCF: show labeled options A–J, score label tokens only."""
+ options = _get_options(line)
+ if not options:
+ return None
+ labels = list(ascii_uppercase[: len(options)])
+ query = f"Question: {line['question'].strip()}\n"
+ query += "".join([f"{lbl}. {opt}\n" for lbl, opt in zip(labels, options)])
+ query += "Answer:"
+ return Doc(
+ task_name=task_name,
+ query=query,
+ choices=[" " + lbl for lbl in labels],
+ gold_index=line["answer_index"],
+ )
+
+
+def _valid_filter(x):
+ """Keep only rows that have at least one non-null, non-blank option."""
+ return any(
+ x.get(col) is not None and str(x[col]).strip()
+ for col in _OPTION_COLS
+ )
+
+
+mmlu_prox_cf = LightevalTaskConfig(
+ name="mmlu_prox_eng:cf",
+ prompt_function=mmlu_prox_cf_prompt,
+ hf_repo="li-lab/MMLU-ProX",
+ hf_subset="en",
+ evaluation_splits=("test",),
+ few_shots_split="validation",
+ hf_filter=_valid_filter,
+ metrics=_CF_METRICS,
+ version=0,
+)
+
+mmlu_prox_mcf = LightevalTaskConfig(
+ name="mmlu_prox_eng:mcf",
+ prompt_function=mmlu_prox_mcf_prompt,
+ hf_repo="li-lab/MMLU-ProX",
+ hf_subset="en",
+ evaluation_splits=("test",),
+ few_shots_split="validation",
+ hf_filter=_valid_filter,
+ metrics=_MCF_METRICS,
+ version=0,
+)
+
+# Greedy variant: MCF-style prompt, generate up to 5 tokens, exact match
+mmlu_prox_mcf_em = LightevalTaskConfig(
+ name="mmlu_prox_eng:mcf_em",
+ prompt_function=mmlu_prox_mcf_prompt,
+ hf_repo="li-lab/MMLU-ProX",
+ hf_subset="en",
+ evaluation_splits=("test",),
+ few_shots_split="validation",
+ hf_filter=_valid_filter,
+ generation_size=5,
+ stop_sequence=["\n"],
+ metrics=[Metrics.exact_match],
+ version=0,
+)
+
+TASKS_TABLE = [mmlu_prox_cf, mmlu_prox_mcf, mmlu_prox_mcf_em]
diff --git a/src/lighteval/tasks/templates/translation.py b/src/lighteval/tasks/templates/translation.py
index 6b4c54a62..174e05a38 100644
--- a/src/lighteval/tasks/templates/translation.py
+++ b/src/lighteval/tasks/templates/translation.py
@@ -145,7 +145,7 @@ def translation_prompt(
for text in as_list(input_data["target_text"])
]
- return continuation_prompt_fn(
+ doc = continuation_prompt_fn(
{
"instruction": input_data.get("instruction", ""),
"context": context,
@@ -154,5 +154,9 @@ def translation_prompt(
},
task_name,
)
+ # Store the cleaned source text so COMET metric can access it without re-parsing the prompt.
+ if doc is not None:
+ doc.specific = (doc.specific or {}) | {"source_text": source_text}
+ return doc
return translation_prompt
diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py
index 96cd8fc2c..99405c45b 100644
--- a/src/lighteval/tasks/templates/utils/translation_literals.py
+++ b/src/lighteval/tasks/templates/utils/translation_literals.py
@@ -1426,3 +1426,27 @@ def __getattribute__(self, name: str) -> str:
Language.YUE_CHINESE: TranslationLiterals(language=Language.YUE_CHINESE),
Language.ZULU: TranslationLiterals(language=Language.ZULU),
}
+
+# Some multilingual benchmark additions rely on the generic MCQ prompt template,
+# which requires at least localized question/answer labels. Where dedicated
+# translations are still missing, fall back to English labels so the task can run.
+for _language in [
+ Language.AMHARIC,
+ Language.HAUSA,
+ Language.HEBREW,
+ Language.IGBO,
+ Language.KIRGHIZ,
+ Language.KOREAN,
+ Language.LITHUANIAN,
+ Language.MALAGASY,
+ Language.MALAY,
+ Language.NEPALI,
+ Language.NYANJA,
+ Language.PERSIAN,
+ Language.SHONA,
+ Language.SINHALA,
+ Language.SOMALI,
+ Language.YORUBA,
+]:
+ TRANSLATION_LITERALS[_language].question_word = "question"
+ TRANSLATION_LITERALS[_language].answer = "answer"
From b6f86713bdc9cc7ddfc9c8eb59f3a26868e4721b Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Sun, 19 Apr 2026 13:22:03 +0200
Subject: [PATCH 2/8] fix mgsm
---
src/lighteval/tasks/lighteval_task.py | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index 3f45315ba..ddbccb1ad 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -134,7 +134,7 @@ def _load_hub_raw_dataset_files(
return load_dataset(data_format, data_files=data_files)
-def _load_mgsm_dataset(config_name: str) -> DatasetDict:
+def _load_mgsm_dataset(config_name: str, local_files_only: bool = False) -> DatasetDict:
import csv
import importlib.util
@@ -144,11 +144,13 @@ def _load_mgsm_dataset(config_name: str) -> DatasetDict:
repo_id="juletxara/mgsm",
repo_type="dataset",
filename="exemplars.py",
+ local_files_only=local_files_only,
)
tsv_path = hf_hub_download(
repo_id="juletxara/mgsm",
repo_type="dataset",
filename=f"mgsm_{config_name}.tsv",
+ local_files_only=local_files_only,
)
spec = importlib.util.spec_from_file_location("mgsm_exemplars", exemplars_path)
@@ -769,8 +771,11 @@ def download_dataset_worker(
if not (_is_script_err or _is_offline):
raise
- if _is_script_err and task.dataset_path == "juletxara/mgsm":
- dataset = _load_mgsm_dataset(task.dataset_config_name)
+ if (_is_script_err or _is_offline) and task.dataset_path == "juletxara/mgsm":
+ dataset = _load_mgsm_dataset(
+ task.dataset_config_name,
+ local_files_only=_is_offline,
+ )
if task.dataset_filter is not None:
dataset = dataset.filter(task.dataset_filter)
return dataset # type: ignore
From 2ae0dbc5d601e4ddf6fe232bd56a383d6cbff7a0 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Wed, 22 Apr 2026 22:36:35 +0200
Subject: [PATCH 3/8] fix bugs
---
.../models/endpoints/litellm_model.py | 46 +++++++------------
.../multilingual/tasks/mlmm_hellaswag.py | 1 +
src/lighteval/tasks/tasks/arc.py | 2 +
src/lighteval/tasks/tasks/hellaswag.py | 13 +++---
src/lighteval/tasks/tasks/jeopardy.py | 2 +-
src/lighteval/tasks/tasks/mmlu.py | 4 +-
src/lighteval/tasks/tasks/winogrande.py | 4 +-
7 files changed, 33 insertions(+), 39 deletions(-)
diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py
index f0036854d..65885d032 100644
--- a/src/lighteval/models/endpoints/litellm_model.py
+++ b/src/lighteval/models/endpoints/litellm_model.py
@@ -638,12 +638,13 @@ def _left_truncate_tokens(self, text: str, max_tokens: int) -> str:
keep_chars = max(1, int(len(text) * max_tokens / total))
return text[-keep_chars:]
- def _score_single(self, prompt: str, ctx_token_count: int) -> tuple[float, bool]:
+ def _score_single(self, prompt: str, choice_len: int) -> tuple[float, bool]:
"""Score one (context + choice) string via a single echo-based completions call.
Args:
prompt: Full text (context + choice) to score.
- ctx_token_count: Number of context tokens; continuation logprobs start here.
+ choice_len: Number of choice tokens; sliced from the right of the echoed logprobs,
+ before the one generated token. This approach is server-BOS-agnostic.
Returns:
(logprob_sum, is_greedy)
@@ -680,17 +681,12 @@ def _score_single(self, prompt: str, ctx_token_count: int) -> tuple[float, bool]
logprobs_obj = response.choices[0].logprobs
token_logprobs = logprobs_obj.token_logprobs or []
- # Infer choice token count from response: len = ctx + choice + 1 (generated token)
- # :-1 would be equivalent but fails when ctx_token_count is off by 1 due to BPE
- # boundary merges (e.g. "Answer:" + " Yes" -> "Answer:▁Yes" in joint tokenization),
- # producing an empty slice and logprob_sum = -inf, which poisons corpus BPB.
- start = ctx_token_count
- n_choice = len(token_logprobs) - 1 - ctx_token_count
- if n_choice <= 0 and ctx_token_count > 0:
- # ctx_token_count overestimates by 1 — shift start back by 1
- start -= 1
- n_choice = len(token_logprobs) - 1 - start
- end = start + max(n_choice, 0)
+ # Slice choice tokens from the right: the response is
+ # [BOS?] [context tokens...] [choice tokens...] [1 generated token]
+ # Anchoring from the right is BOS-agnostic and robust to any server-side
+ # special-token prepending.
+ end = len(token_logprobs) - 1 # exclude the one generated token
+ start = max(0, end - choice_len)
cont_logprobs = [lp for lp in token_logprobs[start:end] if lp is not None]
logprob_sum = sum(cont_logprobs) if cont_logprobs else float("-inf")
@@ -715,14 +711,14 @@ def _score_batch(
all pairs in the batch.
Args:
- batch: list of (doc_idx, choice_idx, full_text, ctx_token_count)
+ batch: list of (doc_idx, choice_idx, full_text, choice_len)
Returns:
list of (doc_idx, choice_idx, (logprob_sum, is_greedy))
"""
if len(batch) == 1:
- di, ci, full_text, ctx_len = batch[0]
- return [(di, ci, self._score_single(full_text, ctx_len))]
+ di, ci, full_text, choice_len = batch[0]
+ return [(di, ci, self._score_single(full_text, choice_len))]
prompts = [full_text for _, _, full_text, _ in batch]
@@ -761,7 +757,7 @@ def _score_batch(
choices_by_idx = {c.index: c for c in response.choices}
results = []
- for seq_idx, (di, ci, _, ctx_token_count) in enumerate(batch):
+ for seq_idx, (di, ci, _, choice_len) in enumerate(batch):
choice = choices_by_idx.get(seq_idx)
if choice is None:
logger.warning(
@@ -773,12 +769,8 @@ def _score_batch(
logprobs_obj = choice.logprobs
token_logprobs = logprobs_obj.token_logprobs or []
- start = ctx_token_count
- n_choice = len(token_logprobs) - 1 - ctx_token_count
- if n_choice <= 0 and ctx_token_count > 0:
- start -= 1
- n_choice = len(token_logprobs) - 1 - start
- end = start + max(n_choice, 0)
+ end = len(token_logprobs) - 1 # exclude the one generated token
+ start = max(0, end - choice_len)
cont_logprobs = [lp for lp in token_logprobs[start:end] if lp is not None]
logprob_sum = sum(cont_logprobs) if cont_logprobs else float("-inf")
@@ -846,11 +838,8 @@ def _score_batch_work(batch_items: list[tuple]) -> list[tuple]:
"""Tokenize a batch then score all pairs in one API call."""
prepared = []
for di, ci, context, choice in batch_items:
- # ctx_len = total - choice avoids the BPE boundary ±1 error when
- # tokenizing context alone (e.g. "Answer:" + " Yes" merges into
- # "Answer:▁Yes", making context_len overestimate by 1).
- total_len = self._count_tokens(context + choice)
choice_len = self._count_tokens(choice)
+ total_len = self._count_tokens(context + choice)
prompt = context + choice
# Left-truncate (OLMES-style): drop early few-shot examples when
# the combined prompt exceeds the model's context window.
@@ -862,8 +851,7 @@ def _score_batch_work(batch_items: list[tuple]) -> list[tuple]:
f"Prompt too long ({total_len} tokens > {max_input} max); left-truncating."
)
prompt = self._left_truncate_tokens(prompt, max_input)
- total_len = max_input
- prepared.append((di, ci, prompt, total_len - choice_len))
+ prepared.append((di, ci, prompt, choice_len))
return self._score_batch(prepared)
# Chunk work into batches — each batch becomes a single API call with a list
diff --git a/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py b/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py
index bbd631278..dba27632c 100644
--- a/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py
+++ b/src/lighteval/tasks/multilingual/tasks/mlmm_hellaswag.py
@@ -91,6 +91,7 @@
def _hellaswag_adapter(line):
return {
+ "activity_label": line.get("activity_label", ""),
"ctx_a": line["ctx_a"],
"ctx_b": line["ctx_b"],
"continuations": line["endings"],
diff --git a/src/lighteval/tasks/tasks/arc.py b/src/lighteval/tasks/tasks/arc.py
index 05ca9ae30..afad00779 100644
--- a/src/lighteval/tasks/tasks/arc.py
+++ b/src/lighteval/tasks/tasks/arc.py
@@ -74,6 +74,7 @@ def arc_mcf_prompt(line, task_name: str = None):
evaluation_splits=["test"],
few_shots_split=None,
few_shots_select="random_sampling_from_train",
+ generation_size=-1,
metrics=_CF_METRICS,
stop_sequence=["\n"],
version=0,
@@ -88,6 +89,7 @@ def arc_mcf_prompt(line, task_name: str = None):
evaluation_splits=["test"],
few_shots_split=None,
few_shots_select="random_sampling_from_train",
+ generation_size=-1,
metrics=_CF_METRICS,
stop_sequence=["\n"],
version=0,
diff --git a/src/lighteval/tasks/tasks/hellaswag.py b/src/lighteval/tasks/tasks/hellaswag.py
index 8a6ea4db0..14d9462fc 100644
--- a/src/lighteval/tasks/tasks/hellaswag.py
+++ b/src/lighteval/tasks/tasks/hellaswag.py
@@ -50,10 +50,12 @@ def harness_preprocess(text):
def hellaswag_mcf_prompt(line, task_name: str = None):
"""MCF variant: labeled options in prompt, score label tokens via logprobs."""
- query = "The following are multiple choice questions (with answers) about common sense.\n\n"
- query += f"Question: {line['activity_label']}: {line['ctx_a']} {line['ctx_b'].capitalize()}\n"
- query += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["endings"])])
- query += "Answer:"
+ ctx = line["ctx_a"] + " " + line["ctx_b"].capitalize()
+ query = harness_preprocess(line["activity_label"] + ": " + ctx)
+ query += "".join(
+ [f"\n{key}. {harness_preprocess(choice)}" for key, choice in zip(ascii_uppercase, line["endings"])]
+ )
+ query += "\nAnswer:"
gold_ix = int(line["label"]) if line["label"] != "" else -1
return Doc(
@@ -61,7 +63,6 @@ def hellaswag_mcf_prompt(line, task_name: str = None):
query=query,
choices=[" " + i for i in ascii_uppercase[: len(line["endings"])]],
gold_index=gold_ix,
- instruction="The following are multiple choice questions (with answers) about common sense.\n\n",
)
def hellaswag_cf_prompt(line, task_name: str = None):
@@ -110,7 +111,6 @@ def hellaswag_cf_prompt(line, task_name: str = None):
stop_sequence=["\n"],
version=0,
)
-
# CF variant: completion-style, logprob on full answer text + BPB on gold choice
hellaswag_cf = LightevalTaskConfig(
name="hellaswag:cf",
@@ -121,6 +121,7 @@ def hellaswag_cf_prompt(line, task_name: str = None):
evaluation_splits=["validation"],
few_shots_split=None,
few_shots_select=None,
+ generation_size=-1,
metrics=_CF_METRICS,
stop_sequence=["\n"],
version=0,
diff --git a/src/lighteval/tasks/tasks/jeopardy.py b/src/lighteval/tasks/tasks/jeopardy.py
index de2dc0d41..67cf40360 100644
--- a/src/lighteval/tasks/tasks/jeopardy.py
+++ b/src/lighteval/tasks/tasks/jeopardy.py
@@ -94,7 +94,7 @@ def jeopardy_mc_mcf_prompt(line, task_name: str = None):
return Doc(
task_name=task_name,
query=f"Category: {category}\nQuestion: {question}\n{options}\nAnswer:",
- choices=labels,
+ choices=[" " + l for l in labels],
gold_index=gold,
)
diff --git a/src/lighteval/tasks/tasks/mmlu.py b/src/lighteval/tasks/tasks/mmlu.py
index a1472daed..dc31502c5 100644
--- a/src/lighteval/tasks/tasks/mmlu.py
+++ b/src/lighteval/tasks/tasks/mmlu.py
@@ -208,9 +208,11 @@ def mmlu_redux_cf_prompt(line, task_name: str = None):
else line["answer"]
)
+ query = f"The following are multiple choice questions about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}\nAnswer:"
+
return Doc(
task_name=task_name,
- query=f"The following are multiple choice questions about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}\nAnswer:",
+ query=query,
choices=[" " + c for c in line["choices"]],
gold_index=gold_ix,
fewshot_sorting_class=line["choices"][gold_ix],
diff --git a/src/lighteval/tasks/tasks/winogrande.py b/src/lighteval/tasks/tasks/winogrande.py
index 79b70fbde..0e5f7b361 100644
--- a/src/lighteval/tasks/tasks/winogrande.py
+++ b/src/lighteval/tasks/tasks/winogrande.py
@@ -46,8 +46,8 @@ def winogrande_cf_prompt(line, task_name: str = None):
end_of_target = end_of_target.strip()
return Doc(
task_name=task_name,
- query=query,
- choices=[f"{line['option1']} {end_of_target}", f"{line['option2']} {end_of_target}"],
+ query=query.rstrip(),
+ choices=[f" {line['option1']} {end_of_target}", f" {line['option2']} {end_of_target}"],
gold_index=int(line["answer"]) - 1 if line["answer"] != "" else -1,
)
From 7a2e52c6003472c7ccf1130d2d4461b54a2e0783 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Fri, 1 May 2026 15:52:24 +0200
Subject: [PATCH 4/8] update ruler
---
src/lighteval/metrics/metrics_sample.py | 58 ++++--
src/lighteval/tasks/tasks/ruler.py | 196 ++++++++++++--------
tests/unit/metrics/test_ruler_metrics.py | 79 ++++++++
tests/unit/tasks/test_ruler.py | 226 +++++++++++++++++++++++
4 files changed, 467 insertions(+), 92 deletions(-)
create mode 100644 tests/unit/metrics/test_ruler_metrics.py
create mode 100644 tests/unit/tasks/test_ruler.py
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 4863d7239..08d932846 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -27,6 +27,7 @@
import inspect
import logging
import os
+import re
from abc import ABC, abstractmethod
from typing import Callable, Literal, Union
@@ -46,6 +47,7 @@
from lighteval.metrics.normalizations import (
LogProbNormalization,
LogProbTokenNorm,
+ helm_normalizer,
normalize_log_probs,
remove_braces,
remove_braces_and_strip,
@@ -1562,14 +1564,11 @@ class RulerStringMatchAll(SampleLevelComputation):
"""
def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
- pred = (
- model_response.final_text[0].strip().lower()
- if model_response.final_text
- else ""
- )
+ pred = model_response.final_text[0] if model_response.final_text else ""
golds = doc.get_golds()
if not golds:
return 0.0
+ pred = pred.strip().lower()
return sum(g.lower() in pred for g in golds) / len(golds)
@@ -1582,15 +1581,38 @@ class RulerStringMatchAny(SampleLevelComputation):
"""
def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
- pred = (
- model_response.final_text[0].strip().lower()
- if model_response.final_text
- else ""
- )
+ pred = model_response.final_text[0] if model_response.final_text else ""
golds = doc.get_golds()
if not golds:
return 0.0
- return float(any(g.lower() in pred for g in golds))
+ return float(self._has_match(pred, golds))
+
+ @staticmethod
+ def _parse_output(output: str, prefix: str = "Answer:") -> str | None:
+ def _lstrip_prefix(text: str, sub: str) -> str:
+ return re.sub(f"^{re.escape(sub)}", "", text, flags=re.IGNORECASE)
+
+ patterns = [
+ re.compile(f"(?:{re.escape(prefix)})(.*)(?:\\n|$)", flags=re.IGNORECASE),
+ re.compile(r"(?:^)(.*)(?:\n|$)"),
+ ]
+ for pattern in patterns:
+ match = pattern.search(output)
+ if match is not None:
+ return _lstrip_prefix(match[1].strip(), prefix).strip()
+ return None
+
+ @staticmethod
+ def _normalized_substring_match(pred: str, gold: str) -> bool:
+ return helm_normalizer(gold) in helm_normalizer(pred)
+
+ @classmethod
+ def _has_match(cls, pred: str, golds: list[str]) -> bool:
+ parsed_pred = cls._parse_output(pred)
+ candidates = [pred]
+ if parsed_pred is not None:
+ candidates.append(parsed_pred)
+ return any(cls._normalized_substring_match(candidate, gold) for candidate in candidates for gold in golds)
class RulerStringMatch(SampleLevelComputation):
@@ -1601,15 +1623,17 @@ class RulerStringMatch(SampleLevelComputation):
"""
def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> float:
- pred = (
- model_response.final_text[0].strip().lower()
- if model_response.final_text
- else ""
- )
+ pred = model_response.final_text[0] if model_response.final_text else ""
+ # Prepend gen_prefix (answer prefix) so scoring matches olmo's behaviour:
+ # olmo does: output = doc["prepend_text"] + output
+ gen_prefix = (doc.specific or {}).get("gen_prefix", "")
+ if gen_prefix:
+ pred = gen_prefix + pred
golds = doc.get_golds()
if not golds:
return 0.0
task_name = doc.task_name or ""
if ":qa_" in task_name:
- return float(any(g.lower() in pred for g in golds))
+ return float(RulerStringMatchAny._has_match(pred, golds))
+ pred = pred.strip().lower()
return sum(g.lower() in pred for g in golds) / len(golds)
diff --git a/src/lighteval/tasks/tasks/ruler.py b/src/lighteval/tasks/tasks/ruler.py
index 8f141ab12..7b6648861 100644
--- a/src/lighteval/tasks/tasks/ruler.py
+++ b/src/lighteval/tasks/tasks/ruler.py
@@ -90,7 +90,8 @@
DEFAULT_LENGTHS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
-NUM_SAMPLES = 500
+# Default sample count; lowered from the previous value of 500.
+NUM_SAMPLES = 100
RANDOM_SEED = 42
# ---------------------------------------------------------------------------
@@ -147,7 +148,13 @@ def _get_cache_dir(tokenizer_path: str) -> Path:
# ---------------------------------------------------------------------------
NIAH_NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
-NIAH_TEMPLATE = (
+NIAH_TEMPLATE_SINGLE = (
+ "A special magic {type_needle_v} is hidden within the following text. "
+ "Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n"
+ "{context}\n"
+ "What is the special magic {type_needle_v} for {query} mentioned in the provided text?"
+)
+NIAH_TEMPLATE_MULTI = (
"Some special magic {type_needle_v} are hidden within the following text. "
"Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n"
"{context}\n"
@@ -321,7 +328,7 @@ def _niah_generate_samples(
type_needle_v: str,
template: str,
num_samples: int = NUM_SAMPLES,
- tokens_to_generate: int = 128,
+ tokens_to_generate: int = 50,
num_needle_v: int = 1,
num_needle_k: int = 1,
num_needle_q: int = 1,
@@ -361,7 +368,7 @@ def _gen_prefix(tnv, nq, nv, query):
random_seed=random_seed,
)
gen_prefix = _gen_prefix(tnv_base, num_needle_q, num_needle_v, query)
- prompt = input_text + " " + gen_prefix
+ prompt = _build_runtime_prompt(input_text, gen_prefix)
total_tokens = len(tokenizer(prompt).input_ids)
if total_tokens + tokens_to_generate > budget:
num_haystack -= incremental
@@ -400,7 +407,7 @@ def _gen_prefix(tnv, nq, nv, query):
random_seed=sample_seed,
)
gen_prefix = _gen_prefix(tnv_base, num_needle_q, num_needle_v, query)
- prompt = input_text + " " + gen_prefix
+ prompt = _build_runtime_prompt(input_text, gen_prefix)
length = len(tokenizer(prompt).input_ids) + tokens_to_generate
assert length <= budget
break
@@ -434,7 +441,7 @@ def _gen_prefix(tnv, nq, nv, query):
# ---------------------------------------------------------------------------
VT_CONFIG = {
- "tokens_to_generate": 30,
+ "tokens_to_generate": 50,
"template": (
"Memorize and track the chain(s) of variable assignment hidden in the "
"following text.\n\n{context}\n"
@@ -446,6 +453,7 @@ def _gen_prefix(tnv, nq, nv, query):
),
}
VT_TEMPLATE = VT_CONFIG["template"] + VT_CONFIG["answer_prefix"]
+VT_ANSWER_PREFIX_BASE = " Answer: According to the chain(s) of variable assignment"
def _vt_generate_chains(
@@ -506,6 +514,47 @@ def _vt_randomize_icl(icl_example: str) -> str:
return icl_example
+def _split_prompt_on_answer_prefix(
+ prompt: str, answer_prefix: str, *, strip_input: bool = False
+) -> tuple[str, str]:
+ prefix_index = prompt.rfind(answer_prefix)
+ if prefix_index == -1:
+ raise ValueError("Answer prefix not found in prompt")
+
+ input_text = prompt[:prefix_index]
+ if strip_input:
+ input_text = input_text.strip()
+ gen_prefix = prompt[prefix_index:].strip()
+ return input_text, gen_prefix
+
+
+def _runtime_prompt_budget_length(
+ tokenizer, input_text: str, gen_prefix: str, tokens_to_generate: int
+) -> int:
+ prompt = _build_runtime_prompt(input_text, gen_prefix)
+ return len(tokenizer(prompt).input_ids) + tokens_to_generate
+
+
+def _vt_build_cached_sample(input_text: str, icl_prompt: str) -> tuple[str, str]:
+ cutoff = input_text.index(VT_CONFIG["template"][:20])
+ full_prompt = input_text[:cutoff] + icl_prompt + "\n\n" + input_text[cutoff:]
+ return _split_prompt_on_answer_prefix(full_prompt, VT_ANSWER_PREFIX_BASE)
+
+
+def _cwe_build_cached_sample(input_example: str, input_text: str) -> tuple[str, str]:
+ input_text_clean, gen_prefix = _split_prompt_on_answer_prefix(
+ input_text, CWE_CONFIG["answer_prefix"]
+ )
+ full_input = (input_example + "\n" + input_text_clean).strip()
+ return full_input, gen_prefix
+
+
+def _fwe_build_cached_sample(input_text: str) -> tuple[str, str]:
+ return _split_prompt_on_answer_prefix(
+ input_text, FWE_CONFIG["answer_prefix"], strip_input=True
+ )
+
+
def _vt_generate_samples(
tokenizer,
max_seq_length: int,
@@ -513,7 +562,7 @@ def _vt_generate_samples(
incremental: int = 10,
num_chains: int = 1,
num_hops: int = 4,
- tokens_to_generate: int = 30,
+ tokens_to_generate: int = VT_CONFIG["tokens_to_generate"],
) -> list[dict]:
budget = max_seq_length
@@ -542,15 +591,16 @@ def _vt_generate_samples(
icl_str = (
icl_text_raw + " " + icl_out
) # full ICL text: question + gen_prefix + answers
- example_tokens = len(tokenizer(icl_str + "\n\n").input_ids)
-
# --- Step 2: Find num_noises for main examples ---
num_noises = incremental
total_tokens = 0
- while total_tokens + tokens_to_generate + example_tokens < budget:
+ while total_tokens + tokens_to_generate < budget:
input_text, _ = _vt_generate_input_output(num_noises, num_chains, num_hops)
- total_tokens = len(tokenizer(input_text).input_ids)
- if total_tokens + tokens_to_generate + example_tokens > budget:
+ cached_input, cached_gen_prefix = _vt_build_cached_sample(input_text, icl_str)
+ total_tokens = len(
+ tokenizer(_build_runtime_prompt(cached_input, cached_gen_prefix)).input_ids
+ )
+ if total_tokens + tokens_to_generate > budget:
num_noises -= incremental
break
num_noises += incremental
@@ -571,10 +621,11 @@ def _vt_generate_samples(
input_text, answer = _vt_generate_input_output(
used_noises, num_chains, num_hops
)
- length = (
- len(tokenizer(input_text).input_ids)
- + tokens_to_generate
- + example_tokens
+ cached_input, cached_gen_prefix = _vt_build_cached_sample(
+ input_text, _vt_randomize_icl(icl_str)
+ )
+ length = _runtime_prompt_budget_length(
+ tokenizer, cached_input, cached_gen_prefix, tokens_to_generate
)
assert length <= budget
break
@@ -587,30 +638,14 @@ def _vt_generate_samples(
if answer is None:
continue
- # Insert ICL example between any model template prefix and the task template
- cutoff = input_text.index(VT_CONFIG["template"][:20])
- input_text = (
- input_text[:cutoff]
- + _vt_randomize_icl(icl_str)
- + "\n\n"
- + input_text[cutoff:]
- )
-
- # Split off gen_prefix (the answer_prefix at the end of VT_TEMPLATE)
- gen_prefix_index = input_text.rfind(
- " Answer: According to the chain(s) of variable assignment"
- )
- gen_prefix = input_text[gen_prefix_index:].strip()
- input_text = input_text[:gen_prefix_index]
-
write_jsons.append(
{
"index": index,
- "input": input_text,
+ "input": cached_input,
"outputs": answer,
"length": length,
"max_length": max_seq_length,
- "gen_prefix": gen_prefix,
+ "gen_prefix": cached_gen_prefix,
}
)
@@ -624,7 +659,7 @@ def _vt_generate_samples(
# ---------------------------------------------------------------------------
CWE_CONFIG = {
- "tokens_to_generate": 120,
+ "tokens_to_generate": 100,
"template": (
"Below is a numbered list of words. In these words, some appear more often than others. "
"Memorize the ones that appear most often.\n{context}\n"
@@ -692,7 +727,7 @@ def _cwe_generate_samples(
max_seq_length: int,
num_samples: int = NUM_SAMPLES,
incremental: int = 10,
- tokens_to_generate: int = 120,
+ tokens_to_generate: int = CWE_CONFIG["tokens_to_generate"],
) -> list[dict]:
words = _get_cwe_words()
budget = max_seq_length
@@ -703,14 +738,9 @@ def _cwe_generate_samples(
input_example, input_text, answer = _cwe_generate_input_output(
num_words, max_seq_length, words
)
+ full_input, gen_prefix = _cwe_build_cached_sample(input_example, input_text)
total_tokens = len(
- tokenizer(
- input_example
- + "\n"
- + input_text
- + " "
- + " ".join(f"{i+1}. {w}" for i, w in enumerate(answer))
- ).input_ids
+ tokenizer(_build_runtime_prompt(full_input, gen_prefix)).input_ids
)
if total_tokens + tokens_to_generate > budget:
num_words -= incremental
@@ -737,7 +767,12 @@ def _cwe_generate_samples(
input_example, input_text, answer = _cwe_generate_input_output(
used_words, max_seq_length, words
)
- length = len(tokenizer(input_text).input_ids) + tokens_to_generate
+ full_input, gen_prefix = _cwe_build_cached_sample(
+ input_example, input_text
+ )
+ length = _runtime_prompt_budget_length(
+ tokenizer, full_input, gen_prefix, tokens_to_generate
+ )
assert length <= budget
break
except Exception:
@@ -749,13 +784,10 @@ def _cwe_generate_samples(
if answer is None:
continue
- gen_prefix_idx = input_text.rfind(CWE_CONFIG["answer_prefix"])
- gen_prefix = input_text[gen_prefix_idx:].strip()
- input_text = input_text[:gen_prefix_idx]
write_jsons.append(
{
"index": index,
- "input": input_text.strip(),
+ "input": full_input,
"outputs": answer,
"length": length,
"max_length": max_seq_length,
@@ -776,7 +808,7 @@ def _cwe_generate_samples(
"tokens_to_generate": 50,
"template": (
"Read the following coded text and track the frequency of each coded word. "
- "Find the three most frequently appeared coded words. {context}\n"
+ "Find the three most frequently appeared coded words.\n{context}\n"
"Question: Do not provide any explanation. Please ignore the dots '....'. "
"What are the three most frequently appeared words in the above coded text?"
),
@@ -875,14 +907,12 @@ def _fwe_generate_samples(
alpha=alpha,
sample_seed=sample_seed,
)
- length = len(tokenizer(input_text).input_ids) + tokens_to_generate
+ input_text_clean, gen_prefix = _fwe_build_cached_sample(input_text)
+ length = _runtime_prompt_budget_length(
+ tokenizer, input_text_clean, gen_prefix, tokens_to_generate
+ )
assert length <= budget
- # Strip answer prefix from input
- ans_prefix_idx = input_text.rfind(FWE_CONFIG["answer_prefix"])
- gen_prefix = input_text[ans_prefix_idx:].strip()
- input_text_clean = input_text[:ans_prefix_idx].strip()
-
write_jsons.append(
{
"index": index,
@@ -904,7 +934,7 @@ def _fwe_generate_samples(
# ---------------------------------------------------------------------------
QA_CONFIG = {
- "tokens_to_generate": 32,
+ "tokens_to_generate": 50,
"template": (
"Answer the question based on the given documents. "
"Only give me the answer and do not output any other words.\n\n"
@@ -1033,7 +1063,7 @@ def _qa_generate_samples(
qas: list[dict],
max_seq_length: int,
num_samples: int = NUM_SAMPLES,
- tokens_to_generate: int = 32,
+ tokens_to_generate: int = QA_CONFIG["tokens_to_generate"],
incremental: int = 10,
) -> list[dict]:
budget = max_seq_length
@@ -1043,7 +1073,7 @@ def _qa_generate_samples(
total_tokens = 0
while total_tokens + tokens_to_generate < budget:
input_text, _ = _qa_generate_input_output(0, num_docs, qas=qas, docs=docs)
- prompt = input_text + " " + gen_prefix
+ prompt = _build_runtime_prompt(input_text, gen_prefix)
total_tokens = len(tokenizer(prompt).input_ids)
if total_tokens + tokens_to_generate > budget:
num_docs -= incremental
@@ -1069,7 +1099,7 @@ def _qa_generate_samples(
input_text, answer = _qa_generate_input_output(
index, used_docs, qas=qas, docs=docs
)
- prompt = input_text + " " + gen_prefix
+ prompt = _build_runtime_prompt(input_text, gen_prefix)
length = len(tokenizer(prompt).input_ids) + tokens_to_generate
assert length <= budget
break
@@ -1110,7 +1140,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_haystack="repeat",
type_needle_k="words",
type_needle_v="numbers",
- template=NIAH_TEMPLATE,
+ template=NIAH_TEMPLATE_SINGLE,
)
elif subset == "niah_single_2":
_ensure_nltk()
@@ -1121,7 +1151,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_haystack="essay",
type_needle_k="words",
type_needle_v="numbers",
- template=NIAH_TEMPLATE,
+ template=NIAH_TEMPLATE_SINGLE,
)
elif subset == "niah_single_3":
_ensure_nltk()
@@ -1132,7 +1162,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_haystack="essay",
type_needle_k="words",
type_needle_v="uuids",
- template=NIAH_TEMPLATE,
+ template=NIAH_TEMPLATE_SINGLE,
)
elif subset == "niah_multikey_1":
_ensure_nltk()
@@ -1144,7 +1174,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_needle_k="words",
type_needle_v="numbers",
num_needle_k=4,
- template=NIAH_TEMPLATE,
+ template=NIAH_TEMPLATE_SINGLE,
)
elif subset == "niah_multikey_2":
return _niah_generate_samples(
@@ -1155,7 +1185,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_needle_k="words",
type_needle_v="numbers",
num_needle_k=4,
- template=NIAH_TEMPLATE,
+ template=NIAH_TEMPLATE_SINGLE,
)
elif subset == "niah_multikey_3":
return _niah_generate_samples(
@@ -1166,7 +1196,8 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_needle_k="uuids",
type_needle_v="uuids",
num_needle_k=4,
- template=NIAH_TEMPLATE,
+ template=NIAH_TEMPLATE_SINGLE,
+ tokens_to_generate=100,
)
elif subset == "niah_multiquery":
_ensure_nltk()
@@ -1178,7 +1209,8 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_needle_k="words",
type_needle_v="numbers",
num_needle_q=4,
- template=NIAH_TEMPLATE,
+ template=NIAH_TEMPLATE_MULTI,
+ tokens_to_generate=100,
)
elif subset == "niah_multivalue":
_ensure_nltk()
@@ -1190,7 +1222,8 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_needle_k="words",
type_needle_v="numbers",
num_needle_v=4,
- template=NIAH_TEMPLATE,
+ template=NIAH_TEMPLATE_MULTI,
+ tokens_to_generate=300 if length == 4096 else 250,
)
elif subset == "vt":
return _vt_generate_samples(tokenizer, max_seq_length=length)
@@ -1272,19 +1305,27 @@ def ensure_ruler_cache(
# ---------------------------------------------------------------------------
+def _build_runtime_prompt(input_text: str, gen_prefix: str = "") -> str:
+ """Build the exact prompt text used at evaluation time.
+
+ RULER caches `input` and `gen_prefix` separately, then `ruler_prompt`
+ joins them with a newline before sending the prompt to the model.
+ """
+ if gen_prefix:
+ return input_text + "\n" + gen_prefix
+ return input_text
+
+
def ruler_prompt(line: dict, task_name: str = None) -> Doc:
outputs = line["outputs"]
- # Append gen_prefix so the model receives the completion-eliciting prefix
- # (matches lm-eval-harness YAML: doc_to_text="{{input}}" + gen_prefix="{{gen_prefix}}")
- query = line["input"]
- gp = line.get("gen_prefix", "")
- if gp:
- query = query + " " + gp
+ gen_prefix = line.get("gen_prefix", "")
+ query = _build_runtime_prompt(line["input"], gen_prefix)
return Doc(
query=query,
choices=outputs,
gold_index=list(range(len(outputs))),
task_name=task_name,
+ specific={"gen_prefix": gen_prefix} if gen_prefix else None,
)
@@ -1315,11 +1356,16 @@ def get_ruler_tasks(
# subset never triggers generation of all subsets.
cache_path = cache_base / str(length) / subset
metric = [Metrics.ruler_match]
+ _niah_gen_sizes = {
+ "niah_multikey_3": 100,
+ "niah_multiquery": 100,
+ "niah_multivalue": 300 if length == 4096 else 250,
+ }
gen_size = (
- 128
+ _niah_gen_sizes.get(subset, 50)
if "niah" in subset
else (
- 30 if subset == "vt" else 120 if subset == "cwe" else 50
+ 50 if subset == "vt" else 100 if subset == "cwe" else 50
) # fwe and qa
)
tasks.append(
diff --git a/tests/unit/metrics/test_ruler_metrics.py b/tests/unit/metrics/test_ruler_metrics.py
new file mode 100644
index 000000000..0817057dd
--- /dev/null
+++ b/tests/unit/metrics/test_ruler_metrics.py
@@ -0,0 +1,79 @@
+# MIT License
+#
+# Copyright (c) 2024 The HuggingFace Team
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import pytest
+
+from lighteval.metrics.metrics_sample import RulerStringMatch, RulerStringMatchAny
+from lighteval.models.model_output import ModelResponse
+from lighteval.tasks.requests import Doc
+
+
+def make_doc(golds: list[str], task_name: str) -> Doc:
+ return Doc(query="", choices=golds, gold_index=list(range(len(golds))), task_name=task_name)
+
+
+@pytest.mark.parametrize(
+ ("prediction", "golds"),
+ [
+ ("eiffel tower", ["The Eiffel Tower"]),
+ ("Answer: eiffel tower", ["The Eiffel Tower"]),
+ ("ANSWER: the eiffel tower", ["eiffel tower"]),
+ ("the answer is london", ["Paris", "London"]),
+ ],
+)
+def test_ruler_qa_match_uses_normalized_substring_semantics(prediction, golds):
+ metric = RulerStringMatch()
+ doc = make_doc(golds, "ruler_512:qa_2|0")
+
+ assert metric.compute(doc, ModelResponse(text=[prediction])) == 1.0
+
+
+def test_ruler_qa_parses_answer_prefix():
+ assert RulerStringMatchAny._parse_output("Answer: Paris\nExtra") == "Paris"
+
+
+def test_ruler_qa_returns_zero_when_no_gold_matches():
+ metric = RulerStringMatch()
+ doc = make_doc(["Paris"], "ruler_512:qa_1|0")
+
+ assert metric.compute(doc, ModelResponse(text=["London"])) == 0.0
+
+
+def test_ruler_non_qa_full_recall_scores_one():
+ metric = RulerStringMatch()
+ doc = make_doc(["alpha", "beta"], "ruler_512:cwe|0")
+
+ assert metric.compute(doc, ModelResponse(text=["alpha beta gamma"])) == 1.0
+
+
+def test_ruler_non_qa_partial_recall_scores_fraction():
+ metric = RulerStringMatch()
+ doc = make_doc(["alpha", "beta", "gamma"], "ruler_512:vt|0")
+
+ assert metric.compute(doc, ModelResponse(text=["alpha gamma"])) == pytest.approx(2 / 3)
+
+
+def test_ruler_non_qa_empty_prediction_scores_zero():
+ metric = RulerStringMatch()
+ doc = make_doc(["alpha", "beta"], "ruler_512:fwe|0")
+
+ assert metric.compute(doc, ModelResponse(text=[""])) == 0.0
diff --git a/tests/unit/tasks/test_ruler.py b/tests/unit/tasks/test_ruler.py
new file mode 100644
index 000000000..8f65bad4c
--- /dev/null
+++ b/tests/unit/tasks/test_ruler.py
@@ -0,0 +1,226 @@
+# MIT License
+#
+# Copyright (c) 2024 The HuggingFace Team
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from lighteval.tasks.tasks import ruler
+
+
+def test_ruler_prompt_appends_generation_prefix():
+ doc = ruler.ruler_prompt(
+ {
+ "input": "Question text",
+ "gen_prefix": "Answer:",
+ "outputs": ["Paris"],
+ },
+ "ruler_512:qa_1|0",
+ )
+
+ assert doc.query == "Question text\nAnswer:"
+ assert doc.choices == ["Paris"]
+ assert doc.gold_index == [0]
+ assert doc.task_name == "ruler_512:qa_1|0"
+
+
+def test_runtime_prompt_helper_matches_ruler_prompt():
+ line = {
+ "input": "Question text",
+ "gen_prefix": "Answer:",
+ "outputs": ["Paris"],
+ }
+ doc = ruler.ruler_prompt(line, "ruler_512:qa_1|0")
+
+ assert ruler._build_runtime_prompt(line["input"], line["gen_prefix"]) == doc.query
+
+
+def test_ruler_generation_sizes_match_reference_defaults():
+ tasks = ruler.get_ruler_tasks(
+ "dummy-tokenizer",
+ lengths=[512],
+ subsets=["niah_single_1", "vt", "cwe", "fwe", "qa_2"],
+ )
+ task_map = {task.name: task.generation_size for task in tasks}
+
+ assert task_map["ruler_512:niah_single_1"] == 50
+ assert task_map["ruler_512:vt"] == 50
+ assert task_map["ruler_512:cwe"] == 100
+ assert task_map["ruler_512:fwe"] == 50
+ assert task_map["ruler_512:qa_2"] == 50
+
+
+def test_ruler_qa_prompt_matches_reference_template():
+ prompt, answers = ruler._qa_generate_input_output(
+ 0,
+ 1,
+ qas=[{"query": "Where is the Eiffel Tower?", "outputs": ["Paris"], "context": [0]}],
+ docs=["Paris is the capital of France."],
+ )
+
+ assert prompt == (
+ "Answer the question based on the given documents. "
+ "Only give me the answer and do not output any other words.\n\n"
+ "The following are given documents.\n\n"
+ "Document 1:\nParis is the capital of France.\n\n"
+ "Answer the question based on the given documents. "
+ "Only give me the answer and do not output any other words.\n\n"
+ "Question: Where is the Eiffel Tower?"
+ )
+ assert answers == ["Paris"]
+
+
+def test_ruler_uses_singular_niah_template_for_single_tasks(monkeypatch):
+ captured = {}
+
+ def fake_niah_generate_samples(*args, **kwargs):
+ captured["template"] = kwargs["template"]
+ return []
+
+ monkeypatch.setattr(ruler, "_niah_generate_samples", fake_niah_generate_samples)
+
+ ruler._generate_subset("niah_single_1", 128, tokenizer=None)
+
+ assert captured["template"] == ruler.NIAH_TEMPLATE_SINGLE
+
+
+def test_ruler_uses_plural_niah_template_for_multi_tasks(monkeypatch):
+ captured = {}
+
+ def fake_niah_generate_samples(*args, **kwargs):
+ captured["template"] = kwargs["template"]
+ return []
+
+ monkeypatch.setattr(ruler, "_niah_generate_samples", fake_niah_generate_samples)
+ monkeypatch.setattr(ruler, "_ensure_nltk", lambda: None)
+ monkeypatch.setattr(ruler, "_get_essay_haystack", lambda: ["one", "two"])
+
+ ruler._generate_subset("niah_multivalue", 128, tokenizer=None)
+
+ assert captured["template"] == ruler.NIAH_TEMPLATE_MULTI
+
+
+def test_ruler_budget_helper_uses_newline_before_generation_prefix():
+ prompt = ruler._build_runtime_prompt("Question text", "Answer:")
+
+ assert prompt == "Question text\nAnswer:"
+
+
+class _FakeTokenizer:
+ def __call__(self, text):
+ return type("Tokens", (), {"input_ids": list(range(len(text)))})()
+
+
+def test_qa_length_matches_runtime_prompt_shape():
+ input_text = "Question text"
+ gen_prefix = "Answer:"
+ tokens_to_generate = ruler.QA_CONFIG["tokens_to_generate"]
+
+ expected = len(_FakeTokenizer()(ruler._build_runtime_prompt(input_text, gen_prefix)).input_ids) + tokens_to_generate
+
+ assert expected == len("Question text\nAnswer:") + tokens_to_generate
+
+
+def test_cwe_full_prompt_length_includes_icl_example():
+ full_input = "Example block\nTask text"
+ gen_prefix = "Answer: The top 10 words that appear most often in the list are:"
+ tokens_to_generate = ruler.CWE_CONFIG["tokens_to_generate"]
+
+ expected = len(_FakeTokenizer()(ruler._build_runtime_prompt(full_input, gen_prefix)).input_ids) + tokens_to_generate
+
+ assert expected == len(full_input + "\n" + gen_prefix) + tokens_to_generate
+
+
+def test_vt_cached_sample_uses_final_runtime_prompt_shape():
+ task_prompt = ruler.VT_TEMPLATE.format(
+ context="VAR A = 12345",
+ query="12345",
+ num_v=5,
+ )
+ cached_input, gen_prefix = ruler._vt_build_cached_sample(task_prompt, "ICL EXAMPLE")
+ length = ruler._runtime_prompt_budget_length(
+ _FakeTokenizer(),
+ cached_input,
+ gen_prefix,
+ ruler.VT_CONFIG["tokens_to_generate"],
+ )
+
+ assert cached_input == (
+ "ICL EXAMPLE\n\n"
+ "Memorize and track the chain(s) of variable assignment hidden in the "
+ "following text.\n\nVAR A = 12345\n"
+ "Question: Find all variables that are assigned the value 12345 in the text above."
+ )
+ assert gen_prefix == (
+ "Answer: According to the chain(s) of variable assignment in the text above, "
+ "5 variables are assigned the value 12345, they are:"
+ )
+ assert length == len(cached_input + "\n" + gen_prefix) + ruler.VT_CONFIG["tokens_to_generate"]
+
+
+def test_fwe_length_matches_runtime_prompt_shape():
+ full_prompt = (
+ "Read the following coded text and track the frequency of each coded word. "
+ "Find the three most frequently appeared coded words.\nalpha beta gamma\n"
+ "Question: Do not provide any explanation. Please ignore the dots '....'. "
+ "What are the three most frequently appeared words in the above coded text?"
+ " Answer: According to the coded text above, the three most frequently appeared words are:"
+ )
+ input_text, gen_prefix = ruler._fwe_build_cached_sample(full_prompt)
+ length = ruler._runtime_prompt_budget_length(
+ _FakeTokenizer(),
+ input_text,
+ gen_prefix,
+ ruler.FWE_CONFIG["tokens_to_generate"],
+ )
+
+ assert length == len(input_text + "\n" + gen_prefix) + ruler.FWE_CONFIG["tokens_to_generate"]
+
+
+def test_non_qa_cached_lengths_match_retokenized_runtime_prompts():
+ samples = [
+ {
+ "input": "ICL EXAMPLE\n\nTask block",
+ "gen_prefix": "Answer: According to the chain(s) of variable assignment in the text above, 5 variables are assigned the value 12345, they are:",
+ "tokens_to_generate": ruler.VT_CONFIG["tokens_to_generate"],
+ "max_length": 500,
+ },
+ {
+ "input": "Example block\nTask text",
+ "gen_prefix": "Answer: The top 10 words that appear most often in the list are:",
+ "tokens_to_generate": ruler.CWE_CONFIG["tokens_to_generate"],
+ "max_length": 500,
+ },
+ {
+ "input": "Task text",
+ "gen_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:",
+ "tokens_to_generate": ruler.FWE_CONFIG["tokens_to_generate"],
+ "max_length": 500,
+ },
+ ]
+
+ for sample in samples:
+ length = ruler._runtime_prompt_budget_length(
+ _FakeTokenizer(),
+ sample["input"],
+ sample["gen_prefix"],
+ sample["tokens_to_generate"],
+ )
+
+ assert length == len(sample["input"] + "\n" + sample["gen_prefix"]) + sample["tokens_to_generate"]
+ assert length <= sample["max_length"]
From c627beb22a49833b5d23abe0e0b21399bcd6393b Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Mon, 11 May 2026 15:30:35 +0200
Subject: [PATCH 5/8] Fix task prompt formats, multi-turn CoQA docs, litellm token counting
---
.../models/endpoints/litellm_model.py | 24 ++--
src/lighteval/tasks/lighteval_task.py | 22 ++--
src/lighteval/tasks/tasks/arc.py | 2 +-
src/lighteval/tasks/tasks/coqa.py | 64 ++++++-----
src/lighteval/tasks/tasks/drop_qa.py | 4 +-
src/lighteval/tasks/tasks/hellaswag.py | 9 +-
src/lighteval/tasks/tasks/humaneval.py | 16 ++-
src/lighteval/tasks/tasks/math.py | 3 +-
src/lighteval/tasks/tasks/mbpp.py | 20 ++--
src/lighteval/tasks/tasks/med.py | 10 +-
src/lighteval/tasks/tasks/mmlu.py | 2 +-
src/lighteval/tasks/tasks/mt_mbpp.py | 13 ++-
src/lighteval/tasks/tasks/winogrande.py | 108 +++++++++++++-----
13 files changed, 196 insertions(+), 101 deletions(-)
diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py
index 65885d032..3cf03e31f 100644
--- a/src/lighteval/models/endpoints/litellm_model.py
+++ b/src/lighteval/models/endpoints/litellm_model.py
@@ -827,19 +827,29 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
[(0.0, False)] * len(doc.choices) for doc in doc_list
]
- # Flat work list: (doc_idx, choice_idx, context, choice)
+ # Pre-compute token lengths once per context (instead of re-tokenizing it for every choice)
+ # to avoid redundant BPE tokenization in the thread pool; totals are summed as ctx_len + ch_len.
+ context_lens: list[int] = [self._count_tokens(ctx) for ctx in context_list]
+
+ # Pre-compute choice lengths once per unique choice text.
+ _unique_choices: dict[str, int] = {}
+ for doc in doc_list:
+ for choice in doc.choices:
+ if choice not in _unique_choices:
+ _unique_choices[choice] = self._count_tokens(choice)
+
+ # Flat work list: (doc_idx, choice_idx, context, choice, context_len, choice_len)
work = [
- (di, ci, context_list[di], choice)
+ (di, ci, context_list[di], choice, context_lens[di], _unique_choices[choice])
for di, doc in enumerate(doc_list)
for ci, choice in enumerate(doc.choices)
]
def _score_batch_work(batch_items: list[tuple]) -> list[tuple]:
- """Tokenize a batch then score all pairs in one API call."""
+ """Score a batch of (context+choice) pairs in a single batched echo call."""
prepared = []
- for di, ci, context, choice in batch_items:
- choice_len = self._count_tokens(choice)
- total_len = self._count_tokens(context + choice)
+ for di, ci, context, choice, ctx_len, ch_len in batch_items:
+ total_len = ctx_len + ch_len
prompt = context + choice
# Left-truncate (OLMES-style): drop early few-shot examples when
# the combined prompt exceeds the model's context window.
@@ -851,7 +861,7 @@ def _score_batch_work(batch_items: list[tuple]) -> list[tuple]:
f"Prompt too long ({total_len} tokens > {max_input} max); left-truncating."
)
prompt = self._left_truncate_tokens(prompt, max_input)
- prepared.append((di, ci, prompt, choice_len))
+ prepared.append((di, ci, prompt, ch_len))
return self._score_batch(prepared)
# Chunk work into batches — each batch becomes a single API call with a list
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index ddbccb1ad..370e3a249 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -577,14 +577,20 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
if doc is None or doc == []:
continue
- doc.id = str(ix)
-
- # Transfer task-level generation parameters to the document
- doc.generation_grammar = self.generation_grammar
- doc.generation_size = self.generation_size
- doc.stop_sequences = self.stop_sequence
-
- docs.append(doc)
+ # Support multi-doc returns (list of Docs per row, e.g. multi-turn tasks)
+ if isinstance(doc, list):
+ for sub_ix, sub_doc in enumerate(doc):
+ sub_doc.id = f"{ix}_{sub_ix}"
+ sub_doc.generation_grammar = self.generation_grammar
+ sub_doc.generation_size = self.generation_size
+ sub_doc.stop_sequences = self.stop_sequence
+ docs.append(sub_doc)
+ else:
+ doc.id = str(ix)
+ doc.generation_grammar = self.generation_grammar
+ doc.generation_size = self.generation_size
+ doc.stop_sequences = self.stop_sequence
+ docs.append(doc)
return docs
diff --git a/src/lighteval/tasks/tasks/arc.py b/src/lighteval/tasks/tasks/arc.py
index afad00779..13bd1ae1a 100644
--- a/src/lighteval/tasks/tasks/arc.py
+++ b/src/lighteval/tasks/tasks/arc.py
@@ -52,7 +52,7 @@ def arc_prompt(line, task_name: str = None):
def arc_mcf_prompt(line, task_name: str = None):
query = f"Question: {line['question']}\n"
- query += "".join([f"{key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["choices"]["text"])])
+ query += "".join([f" {key}. {choice}\n" for key, choice in zip(ascii_uppercase, line["choices"]["text"])])
query += "Answer:"
gold_ix = line["choices"]["label"].index(line["answerKey"])
diff --git a/src/lighteval/tasks/tasks/coqa.py b/src/lighteval/tasks/tasks/coqa.py
index da59985ed..f1bd3bad6 100644
--- a/src/lighteval/tasks/tasks/coqa.py
+++ b/src/lighteval/tasks/tasks/coqa.py
@@ -43,45 +43,55 @@ def _build_query(story: str, questions: list, answers: list, turn_idx: int) -> s
def coqa_gen_prompt(line, task_name: str = None):
- """GenQA: evaluate all turns, build full preceding-QA context per turn."""
+ """GenQA: evaluate ALL turns, build full preceding-QA context per turn.
+
+ Returns a list of Docs (one per turn), matching OLMO's CoQA._process_doc_to_multi.
+ The lighteval_task.py extension (list-of-Docs support) handles multi-doc rows.
+ """
questions = line["questions"]["input_text"]
answers = line["answers"]["input_text"]
if not questions or not answers:
return None
- # Evaluate all turns; return last turn with full context (lighteval calls
- # prompt_function once per row — we evaluate first turn as primary instance).
- # For full multi-turn eval, the dataset would need to be pre-flattened.
- turn_idx = 0
- gold_text = answers[turn_idx]
- if not gold_text:
- return None
is_few_shots = line.get("__few_shots", False)
- return Doc(
- task_name=task_name,
- query=_build_query(line["story"], questions, answers, turn_idx),
- choices=[f"{' ' if is_few_shots else ''}{gold_text}"],
- gold_index=0,
- )
+ docs = []
+ for turn_idx in range(len(questions)):
+ gold_text = answers[turn_idx]
+ if not gold_text:
+ continue
+ docs.append(Doc(
+ task_name=task_name,
+ query=_build_query(line["story"], questions, answers, turn_idx),
+ choices=[f"{' ' if is_few_shots else ''}{gold_text}"],
+ gold_index=0,
+ ))
+ return docs if docs else None
def coqa_bpb_prompt(line, task_name: str = None):
- """BPB/CF: first turn with passage context."""
+ """BPB: evaluate ALL turns of a story, returning one Doc per turn.
+
+ Matches OLMO which creates one evaluation instance per (story, turn) pair.
+ The full preceding-QA context is built for each turn.
+ """
questions = line["questions"]["input_text"]
answers = line["answers"]["input_text"]
if not questions or not answers:
return None
- turn_idx = 0
- gold_text = answers[turn_idx]
- if not gold_text:
- return None
- if not gold_text[0].isspace():
- gold_text = " " + gold_text
- return Doc(
- task_name=task_name,
- query=_build_query(line["story"], questions, answers, turn_idx),
- choices=[gold_text],
- gold_index=0,
- )
+
+ docs = []
+ for turn_idx in range(len(questions)):
+ gold_text = answers[turn_idx]
+ if not gold_text:
+ continue
+ if not gold_text[0].isspace():
+ gold_text = " " + gold_text
+ docs.append(Doc(
+ task_name=task_name,
+ query=_build_query(line["story"], questions, answers, turn_idx),
+ choices=[gold_text],
+ gold_index=0,
+ ))
+ return docs if docs else None
TASKS_TABLE = [
diff --git a/src/lighteval/tasks/tasks/drop_qa.py b/src/lighteval/tasks/tasks/drop_qa.py
index f62793987..838ae0952 100644
--- a/src/lighteval/tasks/tasks/drop_qa.py
+++ b/src/lighteval/tasks/tasks/drop_qa.py
@@ -76,7 +76,7 @@ def drop_bpb_prompt(line, task_name: str = None):
few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
- stop_sequence=["\n"],
+ stop_sequence=["\n\n", "Passage:", "Question:"],
metrics=[Metrics.target_bits_per_byte],
version=0,
)
@@ -90,7 +90,7 @@ def drop_bpb_prompt(line, task_name: str = None):
few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=100,
- stop_sequence=["\n"],
+ stop_sequence=["\n\n", "Passage:", "Question:"],
metrics=[Metrics.drop],
version=0,
)
diff --git a/src/lighteval/tasks/tasks/hellaswag.py b/src/lighteval/tasks/tasks/hellaswag.py
index 14d9462fc..940ef8d61 100644
--- a/src/lighteval/tasks/tasks/hellaswag.py
+++ b/src/lighteval/tasks/tasks/hellaswag.py
@@ -52,8 +52,9 @@ def hellaswag_mcf_prompt(line, task_name: str = None):
"""MCF variant: labeled options in prompt, score label tokens via logprobs."""
ctx = line["ctx_a"] + " " + line["ctx_b"].capitalize()
query = harness_preprocess(line["activity_label"] + ": " + ctx)
+ query += "\nChoose the best continuation:"
query += "".join(
- [f"\n{key}. {harness_preprocess(choice)}" for key, choice in zip(ascii_uppercase, line["endings"])]
+ [f"\n {key}. {harness_preprocess(choice)}" for key, choice in zip(ascii_uppercase, line["endings"])]
)
query += "\nAnswer:"
@@ -88,7 +89,7 @@ def hellaswag_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
+ few_shots_split="test",
few_shots_select=None,
generation_size=-1,
metrics=_MCF_METRICS,
@@ -104,7 +105,7 @@ def hellaswag_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
+ few_shots_split="test",
few_shots_select=None,
generation_size=1,
metrics=[Metrics.exact_match],
@@ -119,7 +120,7 @@ def hellaswag_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
+ few_shots_split="test",
few_shots_select=None,
generation_size=-1,
metrics=_CF_METRICS,
diff --git a/src/lighteval/tasks/tasks/humaneval.py b/src/lighteval/tasks/tasks/humaneval.py
index 9d8a353fe..d61c5763a 100644
--- a/src/lighteval/tasks/tasks/humaneval.py
+++ b/src/lighteval/tasks/tasks/humaneval.py
@@ -13,14 +13,20 @@
from lighteval.tasks.requests import Doc
+_ANSWER_PREFIX = "Here is the completed function:\n\n```python\n"
+
+
def humaneval_bpb_prompt(line, task_name=None):
- """Prompt = function signature + docstring; gold = canonical solution body."""
+ """Prompt = function signature + answer_prefix; gold = canonical solution.
+
+ Matches OLMO: query = prompt + answer_prefix so the gold continuation is
+ scored in the context of the opening code fence, not the raw docstring end.
+ Leading space is added to the gold to align with OLMO's perplexity_leading_space=True.
+ """
return Doc(
task_name=task_name,
- # `prompt` already ends at the opening of the function body (after the
- # closing triple-quote of the docstring), ready for the solution to follow.
- query=line["prompt"],
- choices=[line["canonical_solution"]],
+ query=line["prompt"] + _ANSWER_PREFIX,
+ choices=[" " + line["canonical_solution"]],
gold_index=0,
)
diff --git a/src/lighteval/tasks/tasks/math.py b/src/lighteval/tasks/tasks/math.py
index b335d8986..7fac4dc78 100644
--- a/src/lighteval/tasks/tasks/math.py
+++ b/src/lighteval/tasks/tasks/math.py
@@ -33,7 +33,8 @@
""".strip()
# OLMo easy-suite style: minimal prompt, 4-shot from training data
-MATH_BPB_PROMPT_TEMPLATE = "Problem: {prompt}\nSolution:"
+# Format matches OLMo's minerva cot_style: "Problem:\n{problem}\n\nSolution:"
+MATH_BPB_PROMPT_TEMPLATE = "Problem:\n{prompt}\n\nSolution:"
_MATH_SUBSETS = [
'algebra',
diff --git a/src/lighteval/tasks/tasks/mbpp.py b/src/lighteval/tasks/tasks/mbpp.py
index 32b041759..a23ebcaaf 100644
--- a/src/lighteval/tasks/tasks/mbpp.py
+++ b/src/lighteval/tasks/tasks/mbpp.py
@@ -18,19 +18,17 @@
def mbpp_bpb_prompt(line, task_name=None):
- """Build a docstring prompt; gold = reference solution code."""
- tests = "\n".join(line.get("test_list", []))
- # Format as a Python docstring block so the model sees:
- # """
- #
- #
- # """
- # def function_name(...): ← start of gold continuation
- prompt = f'"""\n{line["prompt"]}\n{tests}\n"""\n'
+ """Prompt = task description + function header; gold = full code with leading space.
+
+    Matches OLMo's default MBPP format: query ends at the first colon of the reference
+    code (normally the function header); the gold is the full code, header included.
+ Leading space matches OLMO's perplexity_leading_space=True.
+ """
+ query = line["prompt"] + line["code"].split(":")[0] + ":"
return Doc(
task_name=task_name,
- query=prompt,
- choices=[line["code"]],
+ query=query,
+ choices=[" " + line["code"]],
gold_index=0,
)
diff --git a/src/lighteval/tasks/tasks/med.py b/src/lighteval/tasks/tasks/med.py
index 90b730bcd..2493fd3bc 100644
--- a/src/lighteval/tasks/tasks/med.py
+++ b/src/lighteval/tasks/tasks/med.py
@@ -39,11 +39,14 @@
def med_mcqa_mcf_prompt(line, task_name: str = None):
- """MCF variant: labeled A/B/C/D options, score label tokens via logprobs."""
- query = f"Give a letter answer among A, B, C or D.\nQuestion: {line['question']}\n"
+ """MCF variant: labeled A/B/C/D options, score label tokens via logprobs.
+
+ Matches OLMO's MedMCQAMC: plain Question + options, no instruction prefix.
+ """
+ query = f"Question: {line['question']}\n"
query += "".join(
[
- f"{key}. {choice}\n"
+ f" {key}. {choice}\n"
for key, choice in zip(ascii_uppercase, [line["opa"], line["opb"], line["opc"], line["opd"]])
]
)
@@ -53,7 +56,6 @@ def med_mcqa_mcf_prompt(line, task_name: str = None):
query=query,
choices=[" " + c for c in list(ascii_uppercase)[:4]],
gold_index=line["cop"] - 1,
- instruction="Give a letter answer among A, B, C or D.\n",
)
diff --git a/src/lighteval/tasks/tasks/mmlu.py b/src/lighteval/tasks/tasks/mmlu.py
index dc31502c5..c4b04c8f0 100644
--- a/src/lighteval/tasks/tasks/mmlu.py
+++ b/src/lighteval/tasks/tasks/mmlu.py
@@ -124,7 +124,7 @@ def mmlu_mcf_prompt(line, task_name: str = None):
subject = line["subject"]
query = f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\nQuestion: {line['question']}"
query += "".join(
- [f"\n{key}. {choice}" for key, choice in zip(ascii_uppercase, line["choices"])]
+ [f"\n {key}. {choice}" for key, choice in zip(ascii_uppercase, line["choices"])]
)
query += "\nAnswer:"
diff --git a/src/lighteval/tasks/tasks/mt_mbpp.py b/src/lighteval/tasks/tasks/mt_mbpp.py
index 32abdfcdd..695907746 100644
--- a/src/lighteval/tasks/tasks/mt_mbpp.py
+++ b/src/lighteval/tasks/tasks/mt_mbpp.py
@@ -49,11 +49,18 @@
def mt_mbpp_bpb_prompt(line, task_name=None):
- """Prompt = problem description (text); gold = code solution."""
+ """Prompt = problem description + opening code fence; gold = code + closing fence.
+
+ Matches OLMO: the query ends with ```{language}\n so the model scores the
+ code body and closing fence as the gold continuation. perplexity_leading_space
+ is implicitly False here (gold has no leading space).
+ """
+ lang = line["language"]
+ code = line["code"].strip()
return Doc(
task_name=task_name,
- query=line["text"],
- choices=[line["code"]],
+ query=line["text"].strip() + f"\n```{lang}\n",
+ choices=[code + "\n```"],
gold_index=0,
)
diff --git a/src/lighteval/tasks/tasks/winogrande.py b/src/lighteval/tasks/tasks/winogrande.py
index 0e5f7b361..be13d4210 100644
--- a/src/lighteval/tasks/tasks/winogrande.py
+++ b/src/lighteval/tasks/tasks/winogrande.py
@@ -22,18 +22,12 @@
https://arxiv.org/abs/1907.10641
"""
-from lighteval.metrics.metrics import Metrics
from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric
+from lighteval.metrics.metrics import Metrics
from lighteval.metrics.normalizations import LogProbCharNorm
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
-_CF_METRICS = [
- LogLikelihoodAccMetric(),
- LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
- Metrics.target_bits_per_byte,
-]
-
_MCF_METRICS = [
LogLikelihoodAccMetric(),
LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
@@ -41,25 +35,38 @@
def winogrande_cf_prompt(line, task_name: str = None):
- """CF variant: completion-style, score full answer texts via logprobs."""
- query, end_of_target = line["sentence"].split("_")
- end_of_target = end_of_target.strip()
+ """Partial evaluation (Trinh & Le 2018), matching OLMO's implementation.
+
+ Context = sentence with the gold option substituted in (prefix + option).
+ Gold = the sentence suffix after the blank, with a leading space.
+    BPB = -log2 P(suffix | prefix + gold_option) / bytes(suffix).
+
+ Accuracy is not meaningful in this single-choice format (always 1.0), so
+    only BPB is reported. Use winogrande:rc or winogrande:mcf for accuracy scoring.
+ """
+ sentence = line["sentence"]
+ blank_pos = sentence.index("_")
+ prefix = sentence[:blank_pos].rstrip()
+ suffix = sentence[blank_pos + 1:].strip()
+
+ gold_ix = int(line["answer"]) - 1 if line["answer"].strip() else -1
+ gold_option = [line["option1"], line["option2"]][gold_ix]
+
return Doc(
task_name=task_name,
- query=query.rstrip(),
- choices=[f" {line['option1']} {end_of_target}", f" {line['option2']} {end_of_target}"],
- gold_index=int(line["answer"]) - 1 if line["answer"] != "" else -1,
+ query=prefix + " " + gold_option,
+ choices=[" " + suffix],
+ gold_index=0,
)
def winogrande_mcf_prompt(line, task_name: str = None):
- """MCF variant: labeled A/B options, score label tokens via logprobs."""
- query, end_of_target = line["sentence"].split("_")
- end_of_target = end_of_target.strip()
- query = query.strip()
- opt1 = f"{line['option1']} {end_of_target}"
- opt2 = f"{line['option2']} {end_of_target}"
- query = f"{query}\nA. {opt1}\nB. {opt2}\nAnswer:"
+ """MCF variant: Fill-in-the-blank format matching OLMO's WinograndeMC.
+
+ Replaces _ with ___ in the sentence, lists options as A/B without the suffix.
+ """
+ sentence_with_blank = line["sentence"].replace("_", "___")
+ query = f"Fill in the blank: {sentence_with_blank}\n A. {line['option1']}\n B. {line['option2']}\nAnswer:"
gold_ix = int(line["answer"]) - 1 if line["answer"] != "" else -1
return Doc(
task_name=task_name,
@@ -69,7 +76,32 @@ def winogrande_mcf_prompt(line, task_name: str = None):
)
-# CF variant: completion-style, logprob on full answer text + BPB on gold choice
+def winogrande_rc_prompt(line, task_name: str = None):
+ """RC (rank choice) variant: standard CF approach for accuracy.
+
+ Context = sentence prefix before blank.
+ Choices = full sentence completions with each option substituted in.
+ Approximates OLMO's RC_none (partial eval) within lighteval's Doc framework:
+ computes P(option+suffix | prefix) rather than P(suffix | prefix+option).
+ """
+ sentence = line["sentence"]
+ blank_pos = sentence.index("_")
+ prefix = sentence[:blank_pos].rstrip()
+ suffix = " " + sentence[blank_pos + 1:].strip()
+ gold_ix = int(line["answer"]) - 1 if line["answer"].strip() else -1
+ return Doc(
+ task_name=task_name,
+ query=prefix,
+ choices=[
+ " " + line["option1"] + suffix,
+ " " + line["option2"] + suffix,
+ ],
+ gold_index=gold_ix,
+ )
+
+
+# CF variant: partial evaluation (OLMO-style), BPB only
+# Accuracy is not reported because partial eval uses one fixed gold choice.
winogrande_cf = LightevalTaskConfig(
name="winogrande:cf",
prompt_function=winogrande_cf_prompt,
@@ -80,12 +112,12 @@ def winogrande_mcf_prompt(line, task_name: str = None):
few_shots_split=None,
few_shots_select="random_sampling",
generation_size=-1,
- metrics=_CF_METRICS,
+ metrics=[Metrics.target_bits_per_byte],
stop_sequence=["\n"],
version=0,
)
-# MCF variant: labeled options, score label tokens via logprobs (TRUE MCF)
+# MCF variant: "Fill in the blank:" format matching OLMO's WinograndeMC
winogrande_mcf = LightevalTaskConfig(
name="winogrande:mcf",
prompt_function=winogrande_mcf_prompt,
@@ -93,8 +125,8 @@ def winogrande_mcf_prompt(line, task_name: str = None):
hf_subset="winogrande_xl",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
- few_shots_select="random_sampling",
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_MCF_METRICS,
stop_sequence=["\n"],
@@ -109,16 +141,38 @@ def winogrande_mcf_prompt(line, task_name: str = None):
hf_subset="winogrande_xl",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
- few_shots_select="random_sampling",
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
generation_size=1,
metrics=[Metrics.exact_match],
stop_sequence=["\n"],
version=0,
)
+# RC variant: standard cloze-form for accuracy (approximates Table 46 RC_none).
+# Scores P(option+suffix | prefix) rather than OLMO's P(suffix | prefix+option),
+# but gives meaningful accuracy using lighteval's standard Doc framework.
+winogrande_rc = LightevalTaskConfig(
+ name="winogrande:rc",
+ prompt_function=winogrande_rc_prompt,
+ hf_repo="allenai/winogrande",
+ hf_subset="winogrande_xl",
+ hf_avail_splits=["train", "test", "validation"],
+ evaluation_splits=["validation"],
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
+ generation_size=-1,
+ metrics=[
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+ ],
+ stop_sequence=["\n"],
+ version=0,
+)
+
TASKS_TABLE = [
winogrande_cf,
winogrande_mcf,
winogrande_mcf_em,
+ winogrande_rc,
]
From 9b1dc98ab81b43102c242b290e97967ada95b942 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Mon, 11 May 2026 16:04:49 +0200
Subject: [PATCH 6/8] fix pr#2 comments
---
src/lighteval/metrics/sample_preparator.py | 5 +-
.../tasks/multilingual/tasks/flores200.py | 2 +-
.../tasks/multilingual/tasks/global_mmlu.py | 2 +-
.../tasks/multilingual/tasks/mgsm.py | 59 +++++++++++++++----
.../tasks/multilingual/tasks/mmlu_prox.py | 2 +-
src/lighteval/tasks/tasks/mmlu_pro.py | 9 +--
src/lighteval/tasks/tasks/mmlu_prox.py | 2 +-
src/lighteval/tasks/templates/hellaswag.py | 2 +-
.../templates/utils/translation_literals.py | 10 +++-
9 files changed, 67 insertions(+), 26 deletions(-)
diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py
index a80a1e573..d84f28c7e 100644
--- a/src/lighteval/metrics/sample_preparator.py
+++ b/src/lighteval/metrics/sample_preparator.py
@@ -110,7 +110,10 @@ def prepare(doc: Doc, model_response: ModelResponse, **kwargs) -> "COMETCorpusMe
"""
source = (doc.specific or {}).get("source_text", "")
golds = as_list(doc.get_golds())
- return COMETCorpusMetricInput(source=source, hyp=model_response.final_text, ref=golds)
+ preds = model_response.final_text
+ if len(preds) > 1:
+ logger.warning("Multiple predictions present, keeping only the first prediction (for COMET).")
+ return COMETCorpusMetricInput(source=source, hyp=preds[0], ref=golds)
class LoglikelihoodPreparator(Preparator):
diff --git a/src/lighteval/tasks/multilingual/tasks/flores200.py b/src/lighteval/tasks/multilingual/tasks/flores200.py
index 3d536bb3b..c3e8d7e70 100644
--- a/src/lighteval/tasks/multilingual/tasks/flores200.py
+++ b/src/lighteval/tasks/multilingual/tasks/flores200.py
@@ -267,7 +267,7 @@ def _flores_pair_subset(lang_code: str) -> str:
hf_avail_splits=("dev", "devtest"),
evaluation_splits=("devtest",),
few_shots_split="dev",
- few_shots_select="random_sampling_from_train",
+ few_shots_select="random_sampling",
)
for lang_code in flores_200_languages
if lang_code != _ENGLISH_FLORES_CODE
diff --git a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py
index 6167a5018..dc594632c 100644
--- a/src/lighteval/tasks/multilingual/tasks/global_mmlu.py
+++ b/src/lighteval/tasks/multilingual/tasks/global_mmlu.py
@@ -140,7 +140,7 @@ def _mcf_adapter(line):
}
-_MMLU_CF_METRICS = [MMLUCategoryGroupingCF]
+_MMLU_CF_METRICS = [MMLUCategoryGroupingCF, Metrics.target_bits_per_byte]
_MMLU_MCF_METRICS = [MMLUCategoryGroupingMCF]
diff --git a/src/lighteval/tasks/multilingual/tasks/mgsm.py b/src/lighteval/tasks/multilingual/tasks/mgsm.py
index 1f03a0b26..3189314b0 100644
--- a/src/lighteval/tasks/multilingual/tasks/mgsm.py
+++ b/src/lighteval/tasks/multilingual/tasks/mgsm.py
@@ -3,7 +3,7 @@
Mgsm
dataset:
-juletxara/mgsm
+CohereLabs/global-mgsm
abstract:
MGSM (Multilingual Grade School Math) is a multilingual benchmark testing
@@ -16,8 +16,11 @@
non-ASCII digit systems like Japanese/Thai).
languages:
-bengali, french, german, japanese, russian, spanish, swahili, telugu, thai,
-chinese
+amharic, arabic, bengali, catalan, czech, welsh, german, greek, spanish,
+basque, french, galician, gujarati, hausa, hungarian, japanese, khmer,
+kannada, korean, kyrgyz, ganda, burmese, nepali, russian, sinhala, shona,
+serbian, southern sotho, swahili, tamil, telugu, thai, urdu, uzbek,
+vietnamese, wolof, xhosa, yoruba, chinese, zulu
tags:
math, multilingual, reasoning
@@ -35,18 +38,48 @@
from lighteval.utils.language import Language
-# Languages covered by juletxara/mgsm (English excluded — use gsm8k.py)
+# Languages covered by CohereLabs/global-mgsm (English excluded — use gsm8k.py)
_LANGUAGES = [
+ Language.AMHARIC,
+ Language.ARABIC,
+ Language.BENGALI,
+ Language.CATALAN,
+ Language.CZECH,
+ Language.WELSH,
+ Language.GERMAN,
+ Language.GREEK,
Language.SPANISH,
+ Language.BASQUE,
Language.FRENCH,
- Language.GERMAN,
- Language.RUSSIAN,
- Language.CHINESE,
+ Language.GALICIAN,
+ Language.GUJARATI,
+ Language.HAUSA,
+ Language.HUNGARIAN,
Language.JAPANESE,
- Language.THAI,
+ Language.KHMER,
+ Language.KANNADA,
+ Language.KOREAN,
+ Language.KIRGHIZ,
+ Language.GANDA,
+ Language.BURMESE,
+ Language.NEPALI,
+ Language.RUSSIAN,
+ Language.SINHALA,
+ Language.SHONA,
+ Language.SERBIAN,
+ Language.SOUTHERN_SOTHO,
Language.SWAHILI,
- Language.BENGALI,
+ Language.TAMIL,
Language.TELUGU,
+ Language.THAI,
+ Language.URDU,
+ Language.UZBEK,
+ Language.VIETNAMESE,
+ Language.WOLOF,
+ Language.XHOSA,
+ Language.YORUBA,
+ Language.CHINESE,
+ Language.ZULU,
]
@@ -57,19 +90,19 @@
language,
lambda line: {
"question": line["question"],
- "choices": [str(line["answer_number"])],
+ "choices": [line["answer"]],
},
),
- hf_repo="juletxara/mgsm",
+ hf_repo="CohereLabs/global-mgsm",
hf_subset=standardize_tag(language.value),
evaluation_splits=("test",),
- few_shots_split="train",
+ few_shots_split=None,
generation_size=512,
metrics=[
Metrics.expr_gold_metric,
MultilingualQuasiExactMatchMetric(language, "full"),
],
- stop_sequence=["\n"],
+ stop_sequence=["Question:", "Answer:"],
)
for language in _LANGUAGES
]
diff --git a/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py b/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py
index a21c3b5f2..c9f574e8d 100644
--- a/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py
+++ b/src/lighteval/tasks/multilingual/tasks/mmlu_prox.py
@@ -116,7 +116,7 @@ def mmlu_prox_mcf_prompt(line, task_name: str = None):
return None
labels = list(ascii_uppercase[: len(options)])
query = f"Question: {line['question'].strip()}\n"
- query += "".join([f"{lbl}. {opt}\n" for lbl, opt in zip(labels, options)])
+ query += "".join([f" {lbl}. {opt}\n" for lbl, opt in zip(labels, options)])
query += "Answer:"
return Doc(
task_name=task_name,
diff --git a/src/lighteval/tasks/tasks/mmlu_pro.py b/src/lighteval/tasks/tasks/mmlu_pro.py
index b443d577b..2ede01f50 100644
--- a/src/lighteval/tasks/tasks/mmlu_pro.py
+++ b/src/lighteval/tasks/tasks/mmlu_pro.py
@@ -49,13 +49,13 @@
def mmlu_pro_mcf_prompt_function(line, task_name: str = None):
query = f"Answer the following multiple choice question.\n\nQuestion: {line['question'].strip()}"
- query += "".join([f"\n{key}. {choice}" for key, choice in zip(ascii_uppercase, line["options"])])
- query += "\nAnswer: "
+ query += "".join([f"\n {key}. {choice}" for key, choice in zip(ascii_uppercase, line["options"])])
+ query += "\nAnswer:"
return Doc(
task_name=task_name,
query=query,
- choices=ascii_uppercase[: len(line["options"])],
+ choices=[" " + c for c in ascii_uppercase[:len(line["options"])]],
gold_index=line["answer_index"],
)
@@ -78,7 +78,7 @@ def mmlu_pro_prompt_function(line, task_name: str = None):
return Doc(
task_name=task_name,
query=query,
- choices=ascii_uppercase[: len(choices)],
+ choices=list(ascii_uppercase[: len(line["options"])]),
gold_index=line["answer_index"],
instruction=query,
)
@@ -135,6 +135,7 @@ def record_to_sample(record):
metrics=[
LogLikelihoodAccMetric(),
LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
+ Metrics.target_bits_per_byte,
],
)
diff --git a/src/lighteval/tasks/tasks/mmlu_prox.py b/src/lighteval/tasks/tasks/mmlu_prox.py
index 08c04703b..318b5fc7f 100644
--- a/src/lighteval/tasks/tasks/mmlu_prox.py
+++ b/src/lighteval/tasks/tasks/mmlu_prox.py
@@ -73,7 +73,7 @@ def mmlu_prox_mcf_prompt(line, task_name: str = None):
return None
labels = list(ascii_uppercase[: len(options)])
query = f"Question: {line['question'].strip()}\n"
- query += "".join([f"{lbl}. {opt}\n" for lbl, opt in zip(labels, options)])
+ query += "".join([f" {lbl}. {opt}\n" for lbl, opt in zip(labels, options)])
query += "Answer:"
return Doc(
task_name=task_name,
diff --git a/src/lighteval/tasks/templates/hellaswag.py b/src/lighteval/tasks/templates/hellaswag.py
index 2beb09574..c1aa25135 100644
--- a/src/lighteval/tasks/templates/hellaswag.py
+++ b/src/lighteval/tasks/templates/hellaswag.py
@@ -56,7 +56,7 @@ def hellaswag_preprocess(
text = re.sub("\\[.*?\\]", "", text)
text = text.replace(" ", " ")
if truncate_dots:
- text = text.replace(r"\.+", r"\.")
+ text = re.sub(r"\.+", ".", text)
if strip_text:
text = text.strip()
return text
diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py
index 99405c45b..b0098bb0b 100644
--- a/src/lighteval/tasks/templates/utils/translation_literals.py
+++ b/src/lighteval/tasks/templates/utils/translation_literals.py
@@ -760,6 +760,8 @@ def __getattribute__(self, name: str) -> str:
Language.KIRGHIZ: TranslationLiterals(language=Language.KIRGHIZ),
Language.KOREAN: TranslationLiterals(
language=Language.KOREAN,
+ question_word="질문",
+ answer="답",
confirmation_word="맞죠",
yes="예",
no="아니오",
@@ -773,7 +775,11 @@ def __getattribute__(self, name: str) -> str:
Language.LIGURIAN: TranslationLiterals(language=Language.LIGURIAN),
Language.LIMBURGISH: TranslationLiterals(language=Language.LIMBURGISH),
Language.LINGALA: TranslationLiterals(language=Language.LINGALA),
- Language.LITHUANIAN: TranslationLiterals(language=Language.LITHUANIAN),
+ Language.LITHUANIAN: TranslationLiterals(
+ language=Language.LITHUANIAN,
+ question_word="klausimas",
+ answer="atsakymas",
+ ),
Language.LOMBARD: TranslationLiterals(language=Language.LOMBARD),
Language.LUBA_KASAI: TranslationLiterals(language=Language.LUBA_KASAI),
Language.LUO: TranslationLiterals(language=Language.LUO),
@@ -1436,8 +1442,6 @@ def __getattribute__(self, name: str) -> str:
Language.HEBREW,
Language.IGBO,
Language.KIRGHIZ,
- Language.KOREAN,
- Language.LITHUANIAN,
Language.MALAGASY,
Language.MALAY,
Language.NEPALI,
From d444a2c0015152734f84ca2c204660322f3d8161 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Mon, 11 May 2026 16:26:40 +0200
Subject: [PATCH 7/8] improve mgsm
---
.../tasks/multilingual/tasks/mgsm.py | 316 +++++++-
.../templates/utils/translation_literals.py | 753 +++++++++++++++++-
2 files changed, 1019 insertions(+), 50 deletions(-)
diff --git a/src/lighteval/tasks/multilingual/tasks/mgsm.py b/src/lighteval/tasks/multilingual/tasks/mgsm.py
index 3189314b0..8dac452f1 100644
--- a/src/lighteval/tasks/multilingual/tasks/mgsm.py
+++ b/src/lighteval/tasks/multilingual/tasks/mgsm.py
@@ -9,8 +9,12 @@
MGSM (Multilingual Grade School Math) is a multilingual benchmark testing
mathematical reasoning across languages, derived from GSM8K.
-Refactored to use unified :gen suffix consistent with English gsm8k.py.
+Uses CoT-style prompts matching English gsm8k.py ("Think step by step /
+Problem: / Solution:"), with language-specific translations for all non-English
+languages in CohereLabs/global-mgsm.
+
English is excluded — use gsm8k.py for English evaluation.
+
Reports both expr_gold_metric (math expression parser) and
MultilingualQuasiExactMatchMetric (language-aware fuzzy match, handles
non-ASCII digit systems like Japanese/Thai).
@@ -34,11 +38,12 @@
from lighteval.metrics.dynamic_metrics import MultilingualQuasiExactMatchMetric
from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
-from lighteval.tasks.templates.qa import get_qa_prompt_function
+from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
-# Languages covered by CohereLabs/global-mgsm (English excluded — use gsm8k.py)
+# Languages covered by CohereLabs/global-mgsm.
+# English is intentionally excluded — use gsm8k.py for English evaluation.
_LANGUAGES = [
Language.AMHARIC,
Language.ARABIC,
@@ -83,26 +88,311 @@
]
+# Each entry is:
+# instruction, problem_label, solution_label
+#
+# The instruction is intentionally close in meaning across languages:
+# "Solve the following math problem. Think step by step before giving
+# the final answer."
+#
+# NOTE:
+# Some low-resource translations should ideally receive native-speaker review
+# before being treated as final benchmark wording.
+_MGSM_PROMPT_PARTS: dict[Language, tuple[str, str, str]] = {
+ Language.AMHARIC: (
+ "የሚከተለውን የሂሳብ ችግር ፍታ። የመጨረሻውን መልስ ከመስጠትህ በፊት ደረጃ በደረጃ አስብ።",
+ "ችግር:",
+ "መፍትሔ:",
+ ),
+ Language.ARABIC: (
+ "حل المسألة الرياضية التالية. فكّر خطوة بخطوة قبل تقديم الإجابة النهائية.",
+ "المسألة:",
+ "الحل:",
+ ),
+ Language.BENGALI: (
+ "নিম্নলিখিত গণিত সমস্যাটি সমাধান করো। চূড়ান্ত উত্তর দেওয়ার আগে ধাপে ধাপে চিন্তা করো।",
+ "সমস্যা:",
+ "সমাধান:",
+ ),
+ Language.CATALAN: (
+ "Resol el següent problema de matemàtiques. Pensa pas a pas abans de donar la resposta final.",
+ "Problema:",
+ "Solució:",
+ ),
+ Language.CZECH: (
+ "Vyřeš následující matematickou úlohu. Než uvedeš konečnou odpověď, přemýšlej krok za krokem.",
+ "Úloha:",
+ "Řešení:",
+ ),
+ Language.WELSH: (
+ "Datrysa'r broblem fathemateg ganlynol. Meddylia gam wrth gam cyn rhoi'r ateb terfynol.",
+ "Problem:",
+ "Datrysiad:",
+ ),
+ Language.GERMAN: (
+ "Löse die folgende Mathematikaufgabe. Denke Schritt für Schritt nach, bevor du die endgültige Antwort gibst.",
+ "Aufgabe:",
+ "Lösung:",
+ ),
+ Language.GREEK: (
+ "Λύσε το παρακάτω μαθηματικό πρόβλημα. Σκέψου βήμα προς βήμα πριν δώσεις την τελική απάντηση.",
+ "Πρόβλημα:",
+ "Λύση:",
+ ),
+ Language.SPANISH: (
+ "Resuelve el siguiente problema de matemáticas. Piensa paso a paso antes de dar la respuesta final.",
+ "Problema:",
+ "Solución:",
+ ),
+ Language.BASQUE: (
+ "Ebatzi honako matematikako problema hau. Pentsatu pausoz pauso azken erantzuna eman aurretik.",
+ "Problema:",
+ "Ebazpena:",
+ ),
+ Language.FRENCH: (
+ "Résous le problème de mathématiques suivant. Réfléchis étape par étape avant de donner la réponse finale.",
+ "Problème :",
+ "Solution:",
+ ),
+ Language.GALICIAN: (
+ "Resolve o seguinte problema matemático. Pensa paso a paso antes de dar a resposta final.",
+ "Problema:",
+ "Solución:",
+ ),
+ Language.GUJARATI: (
+ "નીચે આપેલ ગણિતનો પ્રશ્ન ઉકેલો. અંતિમ જવાબ આપતા પહેલાં પગલું દર પગલું વિચારો.",
+ "પ્રશ્ન:",
+ "ઉકેલ:",
+ ),
+ Language.HAUSA: (
+ "Warware matsalar lissafi mai zuwa. Yi tunani mataki-mataki kafin ka bayar da amsar ƙarshe.",
+ "Matsala:",
+ "Magani:",
+ ),
+ Language.HUNGARIAN: (
+ "Oldd meg a következő matematikai feladatot. Gondolkodj lépésről lépésre, mielőtt megadnád a végső választ.",
+ "Feladat:",
+ "Megoldás:",
+ ),
+ Language.JAPANESE: (
+ "次の数学の問題を解きなさい。最終的な答えを出す前に、順を追って考えなさい。",
+ "問題:",
+ "解答:",
+ ),
+ Language.KHMER: (
+ "ដោះស្រាយលំហាត់គណិតវិទ្យាខាងក្រោម។ សូមគិតជាជំហានៗ មុននឹងផ្តល់ចម្លើយចុងក្រោយ។",
+ "លំហាត់:",
+ "ដំណោះស្រាយ:",
+ ),
+ Language.KANNADA: (
+ "ಕೆಳಗಿನ ಗಣಿತ ಸಮಸ್ಯೆಯನ್ನು ಪರಿಹರಿಸಿ. ಅಂತಿಮ ಉತ್ತರವನ್ನು ನೀಡುವ ಮೊದಲು ಹಂತ ಹಂತವಾಗಿ ಯೋಚಿಸಿ.",
+ "ಸಮಸ್ಯೆ:",
+ "ಪರಿಹಾರ:",
+ ),
+ Language.KOREAN: (
+ "다음 수학 문제를 풀어라. 최종 답을 제시하기 전에 단계별로 생각하라.",
+ "문제:",
+ "풀이:",
+ ),
+ Language.KIRGHIZ: (
+ "Төмөнкү математикалык маселени чыгар. Акыркы жоопту берүүдөн мурун кадам сайын ойлон.",
+ "Маселе:",
+ "Чечим:",
+ ),
+ Language.GANDA: (
+ "Gonjoola ekibuuzo kya okubala kino wammanga. Lowooza mutendera ku mutendera nga tonnawa ky’okuddamu ekisembayo.",
+ "Ekibuuzo:",
+ "Okugonjoola:",
+ ),
+ Language.BURMESE: (
+ "အောက်ပါ သင်္ချာပုစ္ဆာကို ဖြေရှင်းပါ။ နောက်ဆုံးအဖြေ မပေးမီ အဆင့်လိုက် စဉ်းစားပါ။",
+ "ပုစ္ဆာ:",
+ "ဖြေရှင်းချက်:",
+ ),
+ Language.NEPALI: (
+ "तलको गणित समस्या हल गर। अन्तिम उत्तर दिनुअघि चरणबद्ध रूपमा सोच।",
+ "समस्या:",
+ "समाधान:",
+ ),
+ Language.RUSSIAN: (
+ "Реши следующую математическую задачу. Рассуждай пошагово, прежде чем дать окончательный ответ.",
+ "Задача:",
+ "Решение:",
+ ),
+ Language.SINHALA: (
+ "පහත ගණිත ගැටලුව විසඳන්න. අවසාන පිළිතුර ලබාදීමට පෙර පියවරෙන් පියවර සිතන්න.",
+ "ගැටලුව:",
+ "විසඳුම:",
+ ),
+ Language.SHONA: (
+ "Gadzirisa dambudziko remasvomhu rinotevera. Funga nhanho nenhanho usati wapa mhinduro yekupedzisira.",
+ "Dambudziko:",
+ "Mhinduro:",
+ ),
+ Language.SERBIAN: (
+ "Реши следећи математички задатак. Размишљај корак по корак пре него што даш коначан одговор.",
+ "Задатак:",
+ "Решење:",
+ ),
+ Language.SOUTHERN_SOTHO: (
+ "Rarolla bothata bona ba dipalo bo latelang. Nahana mohato ka mohato pele o fana ka karabo ya ho qetela.",
+ "Bothata:",
+ "Tharollo:",
+ ),
+ Language.SWAHILI: (
+ "Tatua tatizo lifuatalo la hisabati. Fikiri hatua kwa hatua kabla ya kutoa jibu la mwisho.",
+ "Tatizo:",
+ "Suluhisho:",
+ ),
+ Language.TAMIL: (
+ "பின்வரும் கணிதப் பிரச்சினையைத் தீர்க்கவும். இறுதி விடையை அளிப்பதற்கு முன் படிப்படியாக சிந்திக்கவும்.",
+ "பிரச்சினை:",
+ "தீர்வு:",
+ ),
+ Language.TELUGU: (
+ "క్రింది గణిత సమస్యను పరిష్కరించండి. తుది సమాధానం చెప్పే ముందు దశలవారీగా ఆలోచించండి.",
+ "సమస్య:",
+ "పరిష్కారం:",
+ ),
+ Language.THAI: (
+ "จงแก้โจทย์คณิตศาสตร์ต่อไปนี้ คิดทีละขั้นตอนก่อนให้คำตอบสุดท้าย",
+ "โจทย์:",
+ "วิธีทำ:",
+ ),
+ Language.URDU: (
+ "درج ذیل ریاضی کا مسئلہ حل کریں۔ حتمی جواب دینے سے پہلے مرحلہ وار سوچیں۔",
+ "مسئلہ:",
+ "حل:",
+ ),
+ Language.UZBEK: (
+ "Quyidagi matematika masalasini yeching. Yakuniy javobni berishdan oldin bosqichma-bosqich o‘ylang.",
+ "Masala:",
+ "Yechim:",
+ ),
+ Language.VIETNAMESE: (
+ "Giải bài toán sau đây. Hãy suy nghĩ từng bước trước khi đưa ra câu trả lời cuối cùng.",
+ "Bài toán:",
+ "Lời giải:",
+ ),
+ Language.WOLOF: (
+ "Saafara jafe-jafe xayma bii ci suuf. Xalaatal ndànk-ndànk, jéego bu nekk, bala nga joxe tontu bu mujj bi.",
+ "Jafe-jafe:",
+ "Saafara:",
+ ),
+ Language.XHOSA: (
+ "Sombulula le ngxaki yezibalo ilandelayo. Cinga inyathelo nenyathelo phambi kokunika impendulo yokugqibela.",
+ "Ingxaki:",
+ "Isisombululo:",
+ ),
+ Language.YORUBA: (
+ "Yanju iṣoro iṣiro atẹle yii. Ronu ni igbesẹ-nipasẹ-igbesẹ ṣaaju ki o to fun idahun ikẹhin.",
+ "Iṣoro:",
+ "Ojútùú:",
+ ),
+ Language.CHINESE: (
+ "解答下面的数学题。在给出最终答案前,请一步一步思考。",
+ "问题:",
+ "解答:",
+ ),
+ Language.ZULU: (
+ "Xazulula inkinga yezibalo elandelayo. Cabanga isinyathelo ngesinyathelo ngaphambi kokunikeza impendulo yokugcina.",
+ "Inkinga:",
+ "Isixazululo:",
+ ),
+}
+
+
+# Extra problem-label variants that models may generate.
+# Keep this small: too many stop strings increase accidental truncation risk.
+_MGSM_PROBLEM_LABEL_STOP_ALIASES: dict[Language, list[str]] = {
+ Language.FRENCH: ["Problème:"], # no-space variant
+ Language.CHINESE: ["问题:"], # ASCII-colon variant
+ Language.JAPANESE: ["問題:"], # ASCII-colon variant
+}
+
+
+def validate_mgsm_prompt_parts() -> None:
+ """Fail fast if a language is missing localized prompt parts."""
+ missing_languages = [
+ language for language in _LANGUAGES if language not in _MGSM_PROMPT_PARTS
+ ]
+
+ if missing_languages:
+ missing = ", ".join(language.value for language in missing_languages)
+ raise ValueError(f"Missing MGSM prompt parts for: {missing}")
+
+
+def mgsm_cot_template(language: Language) -> str:
+ """Return the localized MGSM CoT template for a language."""
+ try:
+ instruction, problem_label, solution_label = _MGSM_PROMPT_PARTS[language]
+ except KeyError as exc:
+ raise ValueError(
+ f"Missing MGSM prompt parts for language: {language.value}"
+ ) from exc
+
+ return f"{instruction}\n\n{problem_label}\n{{prompt}}\n\n{solution_label}"
+
+
+def mgsm_stop_sequences(language: Language) -> list[str]:
+ """Return stop sequences for a language.
+
+ Stop when the model appears to start a new problem.
+
+ We keep English stops because models sometimes switch back to English even
+ when the prompt is localized. We also add only the current language's
+ problem label and a few high-value aliases to avoid over-stopping.
+ """
+ try:
+ _, problem_label, _ = _MGSM_PROMPT_PARTS[language]
+ except KeyError as exc:
+ raise ValueError(
+ f"Missing MGSM prompt parts for language: {language.value}"
+ ) from exc
+
+ stop_sequences = [
+ "Question:",
+ "Problem:",
+ problem_label,
+ *_MGSM_PROBLEM_LABEL_STOP_ALIASES.get(language, []),
+ ]
+
+ # Preserve order while removing duplicates.
+ return list(dict.fromkeys(stop_sequences))
+
+
+def mgsm_cot_prompt(language: Language):
+ """Build a CoT prompt function for MGSM in the given language."""
+ template = mgsm_cot_template(language)
+
+ def prompt_fn(line, task_name: str = None):
+ return Doc(
+ task_name=task_name,
+ query=template.format(prompt=line["question"]),
+ choices=[line["answer"]],
+ gold_index=0,
+ )
+
+ return prompt_fn
+
+
+validate_mgsm_prompt_parts()
+
+
TASKS_TABLE = [
LightevalTaskConfig(
name=f"mgsm:{language.value}:gen",
- prompt_function=get_qa_prompt_function(
- language,
- lambda line: {
- "question": line["question"],
- "choices": [line["answer"]],
- },
- ),
+ prompt_function=mgsm_cot_prompt(language),
hf_repo="CohereLabs/global-mgsm",
hf_subset=standardize_tag(language.value),
evaluation_splits=("test",),
few_shots_split=None,
- generation_size=512,
+ generation_size=1024,
metrics=[
Metrics.expr_gold_metric,
MultilingualQuasiExactMatchMetric(language, "full"),
],
- stop_sequence=["Question:", "Answer:"],
+ stop_sequence=mgsm_stop_sequences(language),
)
for language in _LANGUAGES
-]
+]
\ No newline at end of file
diff --git a/src/lighteval/tasks/templates/utils/translation_literals.py b/src/lighteval/tasks/templates/utils/translation_literals.py
index b0098bb0b..f42cc4765 100644
--- a/src/lighteval/tasks/templates/utils/translation_literals.py
+++ b/src/lighteval/tasks/templates/utils/translation_literals.py
@@ -76,9 +76,55 @@ def __getattribute__(self, name: str) -> str:
TRANSLATION_LITERALS: dict[Language, TranslationLiterals] = {
Language.ACEHNESE: TranslationLiterals(language=Language.ACEHNESE),
- Language.AFRIKAANS: TranslationLiterals(language=Language.AFRIKAANS),
+ Language.AFRIKAANS: TranslationLiterals(
+ language=Language.AFRIKAANS,
+ question_word="vraag",
+ answer="antwoord",
+ confirmation_word="nè",
+ yes="ja",
+ no="nee",
+ also="ook",
+ cause_word="omdat",
+ effect_word="daarom",
+ or_word="of",
+ and_word="en",
+ true="waar",
+ false="vals",
+ neither="nie een van die twee nie",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.AKAN: TranslationLiterals(language=Language.AKAN),
- Language.ALBANIAN: TranslationLiterals(language=Language.ALBANIAN),
+ Language.ALBANIAN: TranslationLiterals(
+ language=Language.ALBANIAN,
+ question_word="pyetje",
+ answer="përgjigje",
+ confirmation_word="apo jo",
+ yes="po",
+ no="jo",
+ also="gjithashtu",
+ cause_word="sepse",
+ effect_word="prandaj",
+ or_word="ose",
+ and_word="dhe",
+ true="e vërtetë",
+ false="e rreme",
+ neither="asnjëra",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.AMHARIC: TranslationLiterals(language=Language.AMHARIC),
Language.ARABIC: TranslationLiterals(
language=Language.ARABIC,
@@ -109,7 +155,30 @@ def __getattribute__(self, name: str) -> str:
Language.ASTURIAN: TranslationLiterals(language=Language.ASTURIAN),
Language.AWADHI: TranslationLiterals(language=Language.AWADHI),
Language.AYACUCHO_QUECHUA: TranslationLiterals(language=Language.AYACUCHO_QUECHUA),
- Language.AZERBAIJANI: TranslationLiterals(language=Language.AZERBAIJANI),
+ Language.AZERBAIJANI: TranslationLiterals(
+ language=Language.AZERBAIJANI,
+ question_word="sual",
+ answer="cavab",
+ confirmation_word="elə deyilmi",
+ yes="bəli",
+ no="xeyr",
+ also="həmçinin",
+ cause_word="çünki",
+ effect_word="buna görə",
+ or_word="və ya",
+ and_word="və",
+ true="doğru",
+ false="yanlış",
+ neither="heç biri",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.BALINESE: TranslationLiterals(language=Language.BALINESE),
Language.BAMBARA: TranslationLiterals(language=Language.BAMBARA),
Language.BANJAR: TranslationLiterals(language=Language.BANJAR),
@@ -209,7 +278,30 @@ def __getattribute__(self, name: str) -> str:
),
Language.BHOJPURI: TranslationLiterals(language=Language.BHOJPURI),
Language.BIHARI: TranslationLiterals(language=Language.BIHARI), # Deprecated
- Language.BOSNIAN: TranslationLiterals(language=Language.BOSNIAN),
+ Language.BOSNIAN: TranslationLiterals(
+ language=Language.BOSNIAN,
+ question_word="pitanje",
+ answer="odgovor",
+ confirmation_word="zar ne",
+ yes="da",
+ no="ne",
+ also="također",
+ cause_word="jer",
+ effect_word="zato",
+ or_word="ili",
+ and_word="i",
+ true="tačno",
+ false="netačno",
+ neither="nijedno",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.BRETON: TranslationLiterals(language=Language.BRETON),
Language.BUGINESE: TranslationLiterals(language=Language.BUGINESE),
Language.BULGARIAN: TranslationLiterals(
@@ -519,7 +611,30 @@ def __getattribute__(self, name: str) -> str:
semicolon=";",
),
Language.GANDA: TranslationLiterals(language=Language.GANDA),
- Language.GEORGIAN: TranslationLiterals(language=Language.GEORGIAN),
+ Language.GEORGIAN: TranslationLiterals(
+ language=Language.GEORGIAN,
+ question_word="კითხვა",
+ answer="პასუხი",
+ confirmation_word="მართალია",
+ yes="დიახ",
+ no="არა",
+ also="ასევე",
+ cause_word="რადგან",
+ effect_word="ამიტომ",
+ or_word="ან",
+ and_word="და",
+ true="ჭეშმარიტი",
+ false="მცდარი",
+ neither="არცერთი",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.GERMAN: TranslationLiterals(
language=Language.GERMAN,
question_word="frage",
@@ -597,10 +712,56 @@ def __getattribute__(self, name: str) -> str:
cause_word="poukisa",
effect_word="donk sa",
),
- Language.HAITIAN_CREOLE: TranslationLiterals(language=Language.HAITIAN_CREOLE),
+ Language.HAITIAN_CREOLE: TranslationLiterals(
+ language=Language.HAITIAN_CREOLE,
+ question_word="kesyon",
+ answer="repons",
+ confirmation_word="se vre",
+ yes="wi",
+ no="non",
+ also="tou",
+ cause_word="paske",
+ effect_word="kidonk",
+ or_word="oswa",
+ and_word="ak",
+ true="vre",
+ false="fo",
+ neither="okenn nan yo",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.HALH_MONGOLIAN: TranslationLiterals(language=Language.HALH_MONGOLIAN),
Language.HAUSA: TranslationLiterals(language=Language.HAUSA),
- Language.HEBREW: TranslationLiterals(language=Language.HEBREW),
+ Language.HEBREW: TranslationLiterals(
+ language=Language.HEBREW,
+ question_word="שאלה",
+ answer="תשובה",
+ confirmation_word="נכון",
+ yes="כן",
+ no="לא",
+ also="גם",
+ cause_word="כי",
+ effect_word="לכן",
+ or_word="או",
+ and_word="ו",
+ true="אמת",
+ false="שקר",
+ neither="אף אחד מהם",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.HINDI: TranslationLiterals(
language=Language.HINDI,
question_word="सवाल",
@@ -696,7 +857,30 @@ def __getattribute__(self, name: str) -> str:
colon=":",
semicolon=";",
),
- Language.IRISH: TranslationLiterals(language=Language.IRISH),
+ Language.IRISH: TranslationLiterals(
+ language=Language.IRISH,
+ question_word="ceist",
+ answer="freagra",
+ confirmation_word="nach ea",
+ yes="sea",
+ no="ní hea",
+ also="freisin",
+ cause_word="mar",
+ effect_word="dá bhrí sin",
+ or_word="nó",
+ and_word="agus",
+ true="fíor",
+ false="bréagach",
+ neither="ceachtar acu",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.ITALIAN: TranslationLiterals(
language=Language.ITALIAN,
question_word="domanda",
@@ -728,9 +912,10 @@ def __getattribute__(self, name: str) -> str:
yes="はい",
no="いいえ",
also="また",
- cause_word="なので",
- effect_word="なぜなら",
+ cause_word="なぜなら",
+ effect_word="だから",
or_word="または",
+ and_word="そして",
true="正解",
false="不正解",
neither="どちらでもない",
@@ -749,29 +934,163 @@ def __getattribute__(self, name: str) -> str:
Language.KABUVERDIANU: TranslationLiterals(language=Language.KABUVERDIANU),
Language.KABYLE: TranslationLiterals(language=Language.KABYLE),
Language.KAMBA: TranslationLiterals(language=Language.KAMBA),
- Language.KANNADA: TranslationLiterals(language=Language.KANNADA),
+ Language.KANNADA: TranslationLiterals(
+ language=Language.KANNADA,
+ question_word="ಪ್ರಶ್ನೆ",
+ answer="ಉತ್ತರ",
+ confirmation_word="ಅಲ್ಲವೆ",
+ yes="ಹೌದು",
+ no="ಇಲ್ಲ",
+ also="ಹಾಗೂ",
+ cause_word="ಏಕೆಂದರೆ",
+ effect_word="ಆದ್ದರಿಂದ",
+ or_word="ಅಥವಾ",
+ and_word="ಮತ್ತು",
+ true="ಸತ್ಯ",
+ false="ಸುಳ್ಳು",
+ neither="ಎರಡೂ ಅಲ್ಲ",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ indices=["ಅ", "ಆ", "ಇ", "ಈ", "ಉ", "ಊ"],
+ ),
Language.KASHMIRI: TranslationLiterals(language=Language.KASHMIRI),
- Language.KAZAKH: TranslationLiterals(language=Language.KAZAKH),
+ Language.KAZAKH: TranslationLiterals(
+ language=Language.KAZAKH,
+ question_word="сұрақ",
+ answer="жауап",
+ confirmation_word="солай ма",
+ yes="иә",
+ no="жоқ",
+ also="сондай-ақ",
+ cause_word="өйткені",
+ effect_word="сондықтан",
+ or_word="немесе",
+ and_word="және",
+ true="дұрыс",
+ false="жалған",
+ neither="ешқайсысы",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ indices=["А", "Ә", "Б", "В", "Г", "Ғ"],
+ ),
Language.KHMER: TranslationLiterals(language=Language.KHMER),
Language.KIKONGO: TranslationLiterals(language=Language.KIKONGO),
Language.KIKUYU: TranslationLiterals(language=Language.KIKUYU),
Language.KIMBUNDU: TranslationLiterals(language=Language.KIMBUNDU),
Language.KINYARWANDA: TranslationLiterals(language=Language.KINYARWANDA),
- Language.KIRGHIZ: TranslationLiterals(language=Language.KIRGHIZ),
+ Language.KIRGHIZ: TranslationLiterals(
+ language=Language.KIRGHIZ,
+ question_word="суроо",
+ answer="жооп",
+ confirmation_word="туурабы",
+ yes="ооба",
+ no="жок",
+ also="ошондой эле",
+ cause_word="анткени",
+ effect_word="ошондуктан",
+ or_word="же",
+ and_word="жана",
+ true="чын",
+ false="жалган",
+ neither="эч бири",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ indices=["А", "Б", "В", "Г", "Д", "Е"],
+ ),
Language.KOREAN: TranslationLiterals(
language=Language.KOREAN,
question_word="질문",
- answer="답",
+ answer="답변",
confirmation_word="맞죠",
yes="예",
no="아니오",
+ also="또한",
+ cause_word="왜냐하면",
+ effect_word="그러므로",
+ or_word="또는",
+ and_word="그리고",
+ true="참",
+ false="거짓",
+ neither="둘 다 아님",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
),
Language.KURDISH: TranslationLiterals(language=Language.KURDISH),
- Language.KYRGYZ: TranslationLiterals(language=Language.KYRGYZ),
+ Language.KYRGYZ: TranslationLiterals(
+ language=Language.KYRGYZ,
+ question_word="суроо",
+ answer="жооп",
+ confirmation_word="туурабы",
+ yes="ооба",
+ no="жок",
+ also="ошондой эле",
+ cause_word="анткени",
+ effect_word="ошондуктан",
+ or_word="же",
+ and_word="жана",
+ true="чын",
+ false="жалган",
+ neither="эч бири",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ indices=["А", "Б", "В", "Г", "Д", "Е"],
+ ),
Language.LAO: TranslationLiterals(language=Language.LAO),
Language.LATGALIAN: TranslationLiterals(language=Language.LATGALIAN),
Language.LATIN: TranslationLiterals(language=Language.LATIN),
- Language.LATVIAN: TranslationLiterals(language=Language.LATVIAN),
+ Language.LATVIAN: TranslationLiterals(
+ language=Language.LATVIAN,
+ question_word="jautājums",
+ answer="atbilde",
+ confirmation_word="vai ne",
+ yes="jā",
+ no="nē",
+ also="arī",
+ cause_word="jo",
+ effect_word="tāpēc",
+ or_word="vai",
+ and_word="un",
+ true="patiesi",
+ false="aplami",
+ neither="neviens",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.LIGURIAN: TranslationLiterals(language=Language.LIGURIAN),
Language.LIMBURGISH: TranslationLiterals(language=Language.LIMBURGISH),
Language.LINGALA: TranslationLiterals(language=Language.LINGALA),
@@ -779,6 +1098,25 @@ def __getattribute__(self, name: str) -> str:
language=Language.LITHUANIAN,
question_word="klausimas",
answer="atsakymas",
+ confirmation_word="ar ne",
+ yes="taip",
+ no="ne",
+ also="taip pat",
+ cause_word="nes",
+ effect_word="todėl",
+ or_word="arba",
+ and_word="ir",
+ true="teisinga",
+ false="neteisinga",
+ neither="nei vienas",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
),
Language.LOMBARD: TranslationLiterals(language=Language.LOMBARD),
Language.LUBA_KASAI: TranslationLiterals(language=Language.LUBA_KASAI),
@@ -810,20 +1148,158 @@ def __getattribute__(self, name: str) -> str:
Language.MAGAHI: TranslationLiterals(language=Language.MAGAHI),
Language.MAITHILI: TranslationLiterals(language=Language.MAITHILI),
Language.MALAGASY: TranslationLiterals(language=Language.MALAGASY),
- Language.MALAY: TranslationLiterals(language=Language.MALAY),
- Language.MALAYALAM: TranslationLiterals(language=Language.MALAYALAM),
- Language.MALTESE: TranslationLiterals(language=Language.MALTESE),
+ Language.MALAY: TranslationLiterals(
+ language=Language.MALAY,
+ question_word="soalan",
+ answer="jawapan",
+ confirmation_word="bukan",
+ yes="ya",
+ no="tidak",
+ also="juga",
+ cause_word="kerana",
+ effect_word="oleh itu",
+ or_word="atau",
+ and_word="dan",
+ true="benar",
+ false="salah",
+ neither="kedua-duanya tidak",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
+ Language.MALAYALAM: TranslationLiterals(
+ language=Language.MALAYALAM,
+ question_word="ചോദ്യം",
+ answer="ഉത്തരം",
+ confirmation_word="അല്ലേ",
+ yes="അതെ",
+ no="അല്ല",
+ also="കൂടാതെ",
+ cause_word="കാരണം",
+ effect_word="അതുകൊണ്ട്",
+ or_word="അല്ലെങ്കിൽ",
+ and_word="കൂടാതെ",
+ true="ശരി",
+ false="തെറ്റ്",
+ neither="ഇരുവരും അല്ല",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ indices=["അ", "ആ", "ഇ", "ഈ", "ഉ", "ഊ"],
+ ),
+ Language.MALTESE: TranslationLiterals(
+ language=Language.MALTESE,
+ question_word="mistoqsija",
+ answer="tweġiba",
+ confirmation_word="hux hekk",
+ yes="iva",
+ no="le",
+ also="ukoll",
+ cause_word="għax",
+ effect_word="għalhekk",
+ or_word="jew",
+ and_word="u",
+ true="veru",
+ false="falz",
+ neither="l-ebda waħda",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.MAORI: TranslationLiterals(language=Language.MAORI),
- Language.MARATHI: TranslationLiterals(language=Language.MARATHI),
+ Language.MARATHI: TranslationLiterals(
+ language=Language.MARATHI,
+ question_word="प्रश्न",
+ answer="उत्तर",
+ confirmation_word="बरोबर ना",
+ yes="हो",
+ no="नाही",
+ also="तसेच",
+ cause_word="कारण",
+ effect_word="म्हणून",
+ or_word="किंवा",
+ and_word="आणि",
+ true="सत्य",
+ false="असत्य",
+ neither="दोन्ही नाही",
+ full_stop="।",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ indices=["क", "ख", "ग", "घ", "ङ", "च"],
+ ),
Language.MEITEI: TranslationLiterals(language=Language.MEITEI),
Language.MESOPOTAMIAN_ARABIC: TranslationLiterals(language=Language.MESOPOTAMIAN_ARABIC),
Language.MINANGKABAU: TranslationLiterals(language=Language.MINANGKABAU),
Language.MIZO: TranslationLiterals(language=Language.MIZO),
- Language.MODERN_STANDARD_ARABIC: TranslationLiterals(language=Language.MODERN_STANDARD_ARABIC),
+ Language.MODERN_STANDARD_ARABIC: TranslationLiterals(
+ language=Language.MODERN_STANDARD_ARABIC,
+ question_word="سؤال",
+ answer="إجابة",
+ confirmation_word="صحيح",
+ yes="نعم",
+ no="لا",
+ also="كذلك",
+ cause_word="لأن",
+ effect_word="لذلك",
+ true="صحيح",
+ false="خاطئ",
+ neither="لا هذا ولا ذاك",
+ or_word="أو",
+ and_word="و",
+ full_stop=".",
+ comma="،",
+ question_mark="؟",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ indices=["أ", "ب", "ج", "د", "هـ", "و", "ز", "ح"],
+ ),
Language.MOROCCAN_ARABIC: TranslationLiterals(language=Language.MOROCCAN_ARABIC),
Language.MOSSI: TranslationLiterals(language=Language.MOSSI),
Language.NAJDI_ARABIC: TranslationLiterals(language=Language.NAJDI_ARABIC),
- Language.NEPALI: TranslationLiterals(language=Language.NEPALI),
+ Language.NEPALI: TranslationLiterals(
+ language=Language.NEPALI,
+ question_word="प्रश्न",
+ answer="उत्तर",
+ confirmation_word="होइन र",
+ yes="हो",
+ no="होइन",
+ also="पनि",
+ cause_word="किनभने",
+ effect_word="त्यसैले",
+ or_word="वा",
+ and_word="र",
+ true="सत्य",
+ false="असत्य",
+ neither="दुवै होइन",
+ full_stop="।",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ indices=["क", "ख", "ग", "घ", "ङ", "च"],
+ ),
Language.NIGERIAN_FULFULDE: TranslationLiterals(language=Language.NIGERIAN_FULFULDE),
Language.NORTHERN_KURDISH: TranslationLiterals(language=Language.NORTHERN_KURDISH),
Language.NORTHERN_SOTHO: TranslationLiterals(language=Language.NORTHERN_SOTHO),
@@ -853,17 +1329,109 @@ def __getattribute__(self, name: str) -> str:
colon=":",
semicolon=";",
),
- Language.NORWEGIAN_BOKMAL: TranslationLiterals(language=Language.NORWEGIAN_BOKMAL),
- Language.NORWEGIAN_NYNORSK: TranslationLiterals(language=Language.NORWEGIAN_NYNORSK),
+ Language.NORWEGIAN_BOKMAL: TranslationLiterals(
+ language=Language.NORWEGIAN_BOKMAL,
+ question_word="spørsmål",
+ answer="svar",
+ confirmation_word="ikke sant",
+ yes="ja",
+ no="nei",
+ also="også",
+ cause_word="fordi",
+ effect_word="derfor",
+ or_word="eller",
+ and_word="og",
+ true="sant",
+ false="usant",
+ neither="ingen av delene",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
+ Language.NORWEGIAN_NYNORSK: TranslationLiterals(
+ language=Language.NORWEGIAN_NYNORSK,
+ question_word="spørsmål",
+ answer="svar",
+ confirmation_word="ikkje sant",
+ yes="ja",
+ no="nei",
+ also="òg",
+ cause_word="fordi",
+ effect_word="difor",
+ or_word="eller",
+ and_word="og",
+ true="sant",
+ false="usant",
+ neither="ingen av delane",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.NUER: TranslationLiterals(language=Language.NUER),
Language.NYANJA: TranslationLiterals(language=Language.NYANJA),
- Language.OCCITAN: TranslationLiterals(language=Language.OCCITAN),
+ Language.OCCITAN: TranslationLiterals(
+ language=Language.OCCITAN,
+ question_word="question",
+ answer="responsa",
+ confirmation_word="vertat",
+ yes="òc",
+ no="non",
+ also="tanben",
+ cause_word="perque",
+ effect_word="donc",
+ or_word="o",
+ and_word="e",
+ true="vertat",
+ false="fals",
+ neither="cap dels dos",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.ODIA: TranslationLiterals(language=Language.ODIA),
Language.ORIYA: TranslationLiterals(language=Language.ORIYA),
Language.PANGASINAN: TranslationLiterals(language=Language.PANGASINAN),
Language.PAPIAMENTO: TranslationLiterals(language=Language.PAPIAMENTO),
Language.PASHTO: TranslationLiterals(language=Language.PASHTO),
- Language.PERSIAN: TranslationLiterals(language=Language.PERSIAN),
+ Language.PERSIAN: TranslationLiterals(
+ language=Language.PERSIAN,
+ question_word="پرسش",
+ answer="پاسخ",
+ confirmation_word="درست است",
+ yes="بله",
+ no="نه",
+ also="همچنین",
+ cause_word="زیرا",
+ effect_word="بنابراین",
+ or_word="یا",
+ and_word="و",
+ true="درست",
+ false="نادرست",
+ neither="هیچکدام",
+ full_stop=".",
+ comma="،",
+ question_mark="؟",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon="؛",
+ ),
Language.PLATEAU_MALAGASY: TranslationLiterals(language=Language.PLATEAU_MALAGASY),
Language.POLISH: TranslationLiterals(
language=Language.POLISH,
@@ -1076,7 +1644,30 @@ def __getattribute__(self, name: str) -> str:
colon=":",
semicolon=";",
),
- Language.SLOVENIAN: TranslationLiterals(language=Language.SLOVENIAN),
+ Language.SLOVENIAN: TranslationLiterals(
+ language=Language.SLOVENIAN,
+ question_word="vprašanje",
+ answer="odgovor",
+ confirmation_word="kajne",
+ yes="da",
+ no="ne",
+ also="tudi",
+ cause_word="ker",
+ effect_word="zato",
+ or_word="ali",
+ and_word="in",
+ true="resnično",
+ false="napačno",
+ neither="nobeno",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.SOMALI: TranslationLiterals(language=Language.SOMALI),
Language.SORANI: TranslationLiterals(language=Language.SORANI),
Language.SOUTHERN_PASHTO: TranslationLiterals(language=Language.SOUTHERN_PASHTO),
@@ -1107,8 +1698,54 @@ def __getattribute__(self, name: str) -> str:
colon=":",
semicolon=";",
),
- Language.STANDARD_LATVIAN: TranslationLiterals(language=Language.STANDARD_LATVIAN),
- Language.STANDARD_MALAY: TranslationLiterals(language=Language.STANDARD_MALAY),
+ Language.STANDARD_LATVIAN: TranslationLiterals(
+ language=Language.STANDARD_LATVIAN,
+ question_word="jautājums",
+ answer="atbilde",
+ confirmation_word="vai ne",
+ yes="jā",
+ no="nē",
+ also="arī",
+ cause_word="jo",
+ effect_word="tāpēc",
+ or_word="vai",
+ and_word="un",
+ true="patiesi",
+ false="aplami",
+ neither="neviens",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
+ Language.STANDARD_MALAY: TranslationLiterals(
+ language=Language.STANDARD_MALAY,
+ question_word="soalan",
+ answer="jawapan",
+ confirmation_word="bukan",
+ yes="ya",
+ no="tidak",
+ also="juga",
+ cause_word="kerana",
+ effect_word="oleh itu",
+ or_word="atau",
+ and_word="dan",
+ true="benar",
+ false="salah",
+ neither="kedua-duanya tidak",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.STANDARD_TIBETAN: TranslationLiterals(language=Language.STANDARD_TIBETAN),
Language.SUNDANESE: TranslationLiterals(language=Language.SUNDANESE),
Language.SWAHILI: TranslationLiterals(
@@ -1143,8 +1780,9 @@ def __getattribute__(self, name: str) -> str:
no="nej",
also="också",
cause_word="eftersom",
- effect_word="därför att",
+ effect_word="därför",
or_word="eller",
+ and_word="och",
true="sant",
false="falskt",
neither="ingendera",
@@ -1421,9 +2059,55 @@ def __getattribute__(self, name: str) -> str:
),
Language.WAR: TranslationLiterals(language=Language.WAR),
Language.WARAY: TranslationLiterals(language=Language.WARAY),
- Language.WELSH: TranslationLiterals(language=Language.WELSH),
+ Language.WELSH: TranslationLiterals(
+ language=Language.WELSH,
+ question_word="cwestiwn",
+ answer="ateb",
+ confirmation_word="on'd yw e",
+ yes="ydy",
+ no="nac ydy",
+ also="hefyd",
+ cause_word="oherwydd",
+ effect_word="felly",
+ or_word="neu",
+ and_word="a",
+ true="gwir",
+ false="gau",
+ neither="dim un o'r ddau",
+ full_stop=".",
+ comma=",",
+ question_mark="?",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon=";",
+ ),
Language.WESTERN_FRISIAN: TranslationLiterals(language=Language.WESTERN_FRISIAN),
- Language.WESTERN_PERSIAN: TranslationLiterals(language=Language.WESTERN_PERSIAN),
+ Language.WESTERN_PERSIAN: TranslationLiterals(
+ language=Language.WESTERN_PERSIAN,
+ question_word="پرسش",
+ answer="پاسخ",
+ confirmation_word="درست است",
+ yes="بله",
+ no="نه",
+ also="همچنین",
+ cause_word="زیرا",
+ effect_word="بنابراین",
+ or_word="یا",
+ and_word="و",
+ true="درست",
+ false="نادرست",
+ neither="هیچکدام",
+ full_stop=".",
+ comma="،",
+ question_mark="؟",
+ exclamation_mark="!",
+ word_space=" ",
+ sentence_space=" ",
+ colon=":",
+ semicolon="؛",
+ ),
Language.WEST_CENTRAL_OROMO: TranslationLiterals(language=Language.WEST_CENTRAL_OROMO),
Language.WOLOF: TranslationLiterals(language=Language.WOLOF),
Language.XHOSA: TranslationLiterals(language=Language.XHOSA),
@@ -1439,14 +2123,9 @@ def __getattribute__(self, name: str) -> str:
for _language in [
Language.AMHARIC,
Language.HAUSA,
- Language.HEBREW,
Language.IGBO,
- Language.KIRGHIZ,
Language.MALAGASY,
- Language.MALAY,
- Language.NEPALI,
Language.NYANJA,
- Language.PERSIAN,
Language.SHONA,
Language.SINHALA,
Language.SOMALI,
From f30649b07bde502439ad7a6447f1aef5be5fe262 Mon Sep 17 00:00:00 2001
From: rakkit <26144573+rakkit@users.noreply.github.com>
Date: Tue, 19 May 2026 16:09:22 +0200
Subject: [PATCH 8/8] fix bpb and cf for swarm
---
.../metrics/harness_compatibility/drop.py | 2 +-
src/lighteval/metrics/metrics.py | 2 +-
.../metrics/utils/linguistic_tokenizers.py | 4 +-
.../models/endpoints/litellm_model.py | 19 +-
.../tasks/multilingual/tasks/mgsm.py | 2 +-
src/lighteval/tasks/tasks/arc.py | 12 +-
src/lighteval/tasks/tasks/basic_skills.py | 135 ++++++++-
src/lighteval/tasks/tasks/commonsenseqa.py | 8 +-
src/lighteval/tasks/tasks/gsm8k.py | 2 +-
src/lighteval/tasks/tasks/hellaswag.py | 12 +-
src/lighteval/tasks/tasks/lambada.py | 64 +++-
src/lighteval/tasks/tasks/med.py | 8 +-
.../tasks/tasks/natural_questions.py | 5 +-
src/lighteval/tasks/tasks/ruler.py | 279 +++++++++++++-----
src/lighteval/tasks/tasks/squad.py | 6 +-
src/lighteval/tasks/tasks/winogrande.py | 21 +-
16 files changed, 436 insertions(+), 145 deletions(-)
diff --git a/src/lighteval/metrics/harness_compatibility/drop.py b/src/lighteval/metrics/harness_compatibility/drop.py
index 93ff0028b..4a40f52a4 100644
--- a/src/lighteval/metrics/harness_compatibility/drop.py
+++ b/src/lighteval/metrics/harness_compatibility/drop.py
@@ -89,7 +89,7 @@ def _get_metrics(self, predicted: List[str], gold: List[str]):
pred_normalized_spans, pred_bags = self._answer_to_bags(predicted)
gold_normalized_spans, gold_bags = self._answer_to_bags(gold)
- if set(pred_normalized_spans) == set(gold_normalized_spans) and len(gold_normalized_spans) == len(
+ if set(pred_normalized_spans) == set(gold_normalized_spans) and len(pred_normalized_spans) == len(
gold_normalized_spans
):
exact_match = 1.0
diff --git a/src/lighteval/metrics/metrics.py b/src/lighteval/metrics/metrics.py
index 41053de70..b707b4fdd 100644
--- a/src/lighteval/metrics/metrics.py
+++ b/src/lighteval/metrics/metrics.py
@@ -329,7 +329,7 @@ class Metrics(Enum):
metric_name=["em", "f1"],
sample_level_fn=DropMetrics(),
category=SamplingMethod.GENERATIVE,
- corpus_level_fn={"em": max, "f1": max},
+ corpus_level_fn={"em": np.mean, "f1": np.mean},
higher_is_better={"em": True, "f1": True},
)
exact_match = SampleLevelMetric(
diff --git a/src/lighteval/metrics/utils/linguistic_tokenizers.py b/src/lighteval/metrics/utils/linguistic_tokenizers.py
index a5670a268..120596cec 100644
--- a/src/lighteval/metrics/utils/linguistic_tokenizers.py
+++ b/src/lighteval/metrics/utils/linguistic_tokenizers.py
@@ -175,10 +175,10 @@ def span_tokenize(self, text: str) -> list[tuple[int, int]]:
# If you know a better tokenizer or better proxy language, please submit a PR
TOKENIZER_FACTORY: dict[Language, Callable[[], WordTokenizer]] = {
Language.ENGLISH: lambda: SpaCyTokenizer("en"),
- Language.KOREAN: lambda: SpaCyTokenizer("ko"),
+ Language.KOREAN: lambda: StanzaTokenizer("ko"),
Language.GERMAN: lambda: SpaCyTokenizer("de"),
Language.FRENCH: lambda: SpaCyTokenizer("fr"),
- Language.CZECH: lambda: SpaCyTokenizer("cz"),
+ Language.CZECH: lambda: SpaCyTokenizer("cs"),
Language.DANISH: lambda: SpaCyTokenizer("da"),
Language.DUTCH: lambda: SpaCyTokenizer("nl"),
Language.ESTONIAN: lambda: SpaCyTokenizer("et"),
diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py
index 3cf03e31f..e7def5a55 100644
--- a/src/lighteval/models/endpoints/litellm_model.py
+++ b/src/lighteval/models/endpoints/litellm_model.py
@@ -67,13 +67,16 @@ def _env_flag(name: str, default: bool) -> bool:
# SLURM jobs run concurrently.
import tempfile as _tempfile
- _cache_dir = os.environ.get("LITELLM_CACHE_DIR")
- if _cache_dir:
- _cache_dir = os.path.expandvars(os.path.expanduser(_cache_dir))
- os.makedirs(_cache_dir, exist_ok=True)
+ if os.environ.get("LITELLM_DISABLE_CACHE", "0").strip().lower() in ("1", "true", "yes"):
+ litellm.disable_cache()
else:
- _cache_dir = _tempfile.mkdtemp(prefix="litellm_cache_")
- litellm.cache = Cache(type=LiteLLMCacheType.DISK, disk_cache_dir=_cache_dir)
+ _cache_dir = os.environ.get("LITELLM_CACHE_DIR")
+ if _cache_dir:
+ _cache_dir = os.path.expandvars(os.path.expanduser(_cache_dir))
+ os.makedirs(_cache_dir, exist_ok=True)
+ else:
+ _cache_dir = _tempfile.mkdtemp(prefix="litellm_cache_")
+ litellm.cache = Cache(type=LiteLLMCacheType.DISK, disk_cache_dir=_cache_dir)
else:
from unittest.mock import Mock
@@ -882,9 +885,11 @@ def _score_batch_work(batch_items: list[tuple]) -> list[tuple]:
for doc, context, pairs in zip(doc_list, context_list, scored):
logprobs = [lp for lp, _ in pairs]
argmax = [g for _, g in pairs]
+ output_tokens = [[0] * _unique_choices[choice] for choice in doc.choices]
results.append(
ModelResponse(
- input=context, logprobs=logprobs, argmax_logits_eq_gold=argmax
+ input=context, logprobs=logprobs, argmax_logits_eq_gold=argmax,
+ output_tokens=output_tokens,
)
)
diff --git a/src/lighteval/tasks/multilingual/tasks/mgsm.py b/src/lighteval/tasks/multilingual/tasks/mgsm.py
index 8dac452f1..cec3c6108 100644
--- a/src/lighteval/tasks/multilingual/tasks/mgsm.py
+++ b/src/lighteval/tasks/multilingual/tasks/mgsm.py
@@ -390,7 +390,7 @@ def prompt_fn(line, task_name: str = None):
generation_size=1024,
metrics=[
Metrics.expr_gold_metric,
- MultilingualQuasiExactMatchMetric(language, "full"),
+            # TODO(review): MultilingualQuasiExactMatchMetric(language, "full") is disabled; document why or remove.
],
stop_sequence=mgsm_stop_sequences(language),
)
diff --git a/src/lighteval/tasks/tasks/arc.py b/src/lighteval/tasks/tasks/arc.py
index 13bd1ae1a..4ec4981ba 100644
--- a/src/lighteval/tasks/tasks/arc.py
+++ b/src/lighteval/tasks/tasks/arc.py
@@ -72,7 +72,7 @@ def arc_mcf_prompt(line, task_name: str = None):
hf_subset="ARC-Challenge",
hf_avail_splits=["train", "validation", "test"],
evaluation_splits=["test"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_CF_METRICS,
@@ -87,7 +87,7 @@ def arc_mcf_prompt(line, task_name: str = None):
hf_subset="ARC-Easy",
hf_avail_splits=["train", "validation", "test"],
evaluation_splits=["test"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_CF_METRICS,
@@ -103,7 +103,7 @@ def arc_mcf_prompt(line, task_name: str = None):
hf_subset="ARC-Challenge",
hf_avail_splits=["train", "validation", "test"],
evaluation_splits=["test"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_MCF_METRICS,
@@ -118,7 +118,7 @@ def arc_mcf_prompt(line, task_name: str = None):
hf_subset="ARC-Easy",
hf_avail_splits=["train", "validation", "test"],
evaluation_splits=["test"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_MCF_METRICS,
@@ -134,7 +134,7 @@ def arc_mcf_prompt(line, task_name: str = None):
hf_subset="ARC-Challenge",
hf_avail_splits=["train", "validation", "test"],
evaluation_splits=["test"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=1,
metrics=[Metrics.exact_match],
@@ -149,7 +149,7 @@ def arc_mcf_prompt(line, task_name: str = None):
hf_subset="ARC-Easy",
hf_avail_splits=["train", "validation", "test"],
evaluation_splits=["test"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=1,
metrics=[Metrics.exact_match],
diff --git a/src/lighteval/tasks/tasks/basic_skills.py b/src/lighteval/tasks/tasks/basic_skills.py
index 7dff5241f..c413d342a 100644
--- a/src/lighteval/tasks/tasks/basic_skills.py
+++ b/src/lighteval/tasks/tasks/basic_skills.py
@@ -19,13 +19,24 @@
paper:
"""
+import hashlib
+import random
+import threading
+from string import ascii_uppercase
+
+from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric
from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.normalizations import LogProbCharNorm, LogProbTokenNorm
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
-# Assumed dataset fields: "question" (str), "answer" (str)
-# Adjust if the actual allenai/basic-skills schema differs.
+# Dataset fields: "question" (str), "answer" (str) — no built-in choices.
+# Distractors for :cf and :mcf are built incrementally from the answers that
+# flow through the prompt function (including few-shot examples, which are
+# processed first). No separate dataset load is needed.
+# NOTE(review): the pool is keyed by subset only and nothing in this module
+# resets it, so the distractors offered to an example depend on which answers
+# were seen before it — confirm this order dependence is acceptable.
+_BASIC_SKILLS_ACCUM: dict = {} # {subset: set of answers}
+# Held by _get_distractor_pool while it reads and updates _BASIC_SKILLS_ACCUM.
+_ACCUM_LOCK = threading.Lock()
# HF config names in allenai/basic-skills (only validation split available)
_BASIC_SKILLS_SUBSETS = [
@@ -38,18 +49,21 @@
]
-def basic_skills_prompt(line, task_name: str = None):
- """Greedy variant: generate answer text, compare with exact match."""
- return Doc(
- task_name=task_name,
- query=f"Question: {line['question']}\nAnswer:",
- choices=[" " + line["answer"]],
- gold_index=0,
- )
+def _get_distractor_pool(subset: str, gold: str) -> list:
+    """Return available distractors (excluding *gold*), then accumulate *gold*.
+
+    The pool is guarded by a lock and grows as examples pass through the prompt
+    function. Early examples in a subset may have fewer than 3 distractors until
+    the pool accumulates enough unique answers from earlier examples.
+
+    The result is sorted on purpose: iteration order of a set of strings changes
+    between interpreter runs (hash randomization), and callers draw from this
+    list with a seeded RNG, so an unsorted list would defeat their determinism.
+
+    Args:
+        subset: Key of the per-subset answer pool.
+        gold: Gold answer of the current example; it is excluded from the
+            result and recorded in the pool for later calls.
+
+    Returns:
+        Sorted list of the answers previously recorded for *subset*, without *gold*.
+    """
+    with _ACCUM_LOCK:
+        seen = _BASIC_SKILLS_ACCUM.setdefault(subset, set())
+        # Snapshot before recording gold so an example never sees itself.
+        distractors = sorted(seen - {gold})
+        seen.add(gold)
+    return distractors
def basic_skills_bpb_prompt(line, task_name: str = None):
- """BPB variant: CF-style prompt with gold answer as single choice."""
+ """BPB variant: single-choice, score gold continuation only (no acc ranking)."""
gold_text = line["answer"]
if not gold_text:
return None
@@ -63,13 +77,66 @@ def basic_skills_bpb_prompt(line, task_name: str = None):
)
+def basic_skills_cf_prompt(line, task_name: str = None):
+    """CF variant: rank the gold answer against up to 3 sampled distractors.
+
+    The subset is the second ``:``-separated field of *task_name*
+    (e.g. 'basic_skills:arithmetic:cf' → 'arithmetic'); without a task name the
+    first entry of ``_BASIC_SKILLS_SUBSETS`` is used. Distractors come from the
+    module-level pool that accumulates answers as examples are processed.
+    Sampling and shuffling are seeded from an MD5 of question + answer.
+
+    Returns None when the row has no non-blank answer.
+    """
+    answer = line.get("answer")
+    if not answer or not answer.strip():
+        return None
+    if task_name:
+        subset = task_name.split(":")[1]
+    else:
+        subset = _BASIC_SKILLS_SUBSETS[0]
+    question = line["question"]
+    digest = hashlib.md5((question + answer).encode()).hexdigest()
+    rng = random.Random(int(digest, 16) % (2**32))
+    pool = _get_distractor_pool(subset, answer)
+    # Same RNG call order as before: sample first, then shuffle.
+    options = rng.sample(pool, min(3, len(pool)))
+    options.append(answer)
+    rng.shuffle(options)
+    return Doc(
+        task_name=task_name,
+        query=f"Question: {question}\nAnswer:",
+        choices=[" " + option for option in options],
+        gold_index=options.index(answer),
+    )
+
+
+def basic_skills_mcf_prompt(line, task_name: str = None):
+    """MCF variant: lettered options (A, B, ...) holding the gold answer and distractors.
+
+    Builds the option list the same way as the :cf prompt, but seeds the RNG
+    with the question + answer hash offset by one, so sampling and shuffling
+    differ from :cf. The model is scored on the option letter, not the text.
+
+    Returns None when the row has no non-blank answer.
+    """
+    answer = line.get("answer")
+    if not answer or not answer.strip():
+        return None
+    if task_name:
+        subset = task_name.split(":")[1]
+    else:
+        subset = _BASIC_SKILLS_SUBSETS[0]
+    question = line["question"]
+    digest = hashlib.md5((question + answer).encode()).hexdigest()
+    rng = random.Random(int(digest, 16) % (2**32) + 1)
+    pool = _get_distractor_pool(subset, answer)
+    # Same RNG call order as before: sample first, then shuffle.
+    options = rng.sample(pool, min(3, len(pool)))
+    options.append(answer)
+    rng.shuffle(options)
+    labels = ascii_uppercase[: len(options)]
+    option_lines = [f" {label}. {text}\n" for label, text in zip(labels, options)]
+    return Doc(
+        task_name=task_name,
+        query="".join([f"Question: {question}\n", *option_lines, "Answer:"]),
+        choices=[" " + label for label in labels],
+        gold_index=options.index(answer),
+    )
+
+
TASKS_TABLE = []
for _subset in _BASIC_SKILLS_SUBSETS:
- # CF variant: RC per-token normalization + BPB on gold (merged)
+ # BPB variant: single-choice gold continuation, BPB only
TASKS_TABLE.append(
LightevalTaskConfig(
- name=f"basic_skills:{_subset}:cf",
+ name=f"basic_skills:{_subset}:bpb",
prompt_function=basic_skills_bpb_prompt,
hf_repo="allenai/basic-skills",
hf_subset=_subset,
@@ -78,8 +145,48 @@ def basic_skills_bpb_prompt(line, task_name: str = None):
few_shots_split="validation",
few_shots_select="random_sampling",
generation_size=-1,
+ metrics=[Metrics.target_bits_per_byte],
+ stop_sequence=["\n"],
+ version=0,
+ )
+ )
+
+    # CF variant: multi-choice acc, raw and per-token normalized (gold vs up to 3 distractors); BPB is reported by the :bpb variant
+ TASKS_TABLE.append(
+ LightevalTaskConfig(
+ name=f"basic_skills:{_subset}:cf",
+ prompt_function=basic_skills_cf_prompt,
+ hf_repo="allenai/basic-skills",
+ hf_subset=_subset,
+ hf_avail_splits=["validation"],
+ evaluation_splits=["validation"],
+ few_shots_split="validation",
+ few_shots_select="random_sampling",
+ generation_size=-1,
+ metrics=[
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbTokenNorm()),
+ ],
+ stop_sequence=["\n"],
+ version=0,
+ )
+ )
+
+ # MCF variant: labeled A/B/C/D, MC acc (gold vs up to 3 distractors)
+ TASKS_TABLE.append(
+ LightevalTaskConfig(
+ name=f"basic_skills:{_subset}:mcf",
+ prompt_function=basic_skills_mcf_prompt,
+ hf_repo="allenai/basic-skills",
+ hf_subset=_subset,
+ hf_avail_splits=["validation"],
+ evaluation_splits=["validation"],
+ few_shots_split="validation",
+ few_shots_select="random_sampling",
+ generation_size=-1,
metrics=[
- Metrics.target_bits_per_byte,
+ LogLikelihoodAccMetric(),
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
],
stop_sequence=["\n"],
version=0,
diff --git a/src/lighteval/tasks/tasks/commonsenseqa.py b/src/lighteval/tasks/tasks/commonsenseqa.py
index 0aa872082..03f6d1ff5 100644
--- a/src/lighteval/tasks/tasks/commonsenseqa.py
+++ b/src/lighteval/tasks/tasks/commonsenseqa.py
@@ -77,8 +77,8 @@ def commonsenseqa_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
- few_shots_select=None,
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
generation_size=1,
metrics=[Metrics.exact_match],
stop_sequence=["\n"],
@@ -93,7 +93,7 @@ def commonsenseqa_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_MCF_METRICS,
@@ -109,7 +109,7 @@ def commonsenseqa_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_CF_METRICS,
diff --git a/src/lighteval/tasks/tasks/gsm8k.py b/src/lighteval/tasks/tasks/gsm8k.py
index 83e5390df..b0c124154 100644
--- a/src/lighteval/tasks/tasks/gsm8k.py
+++ b/src/lighteval/tasks/tasks/gsm8k.py
@@ -86,7 +86,7 @@ def gsm8k_prompt(line, task_name: str = None):
hf_subset="main",
hf_avail_splits=["train", "test"],
evaluation_splits=["test"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=512,
metrics=[
diff --git a/src/lighteval/tasks/tasks/hellaswag.py b/src/lighteval/tasks/tasks/hellaswag.py
index 940ef8d61..be560e73c 100644
--- a/src/lighteval/tasks/tasks/hellaswag.py
+++ b/src/lighteval/tasks/tasks/hellaswag.py
@@ -89,8 +89,8 @@ def hellaswag_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split="test",
- few_shots_select=None,
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_MCF_METRICS,
stop_sequence=["\n"],
@@ -105,8 +105,8 @@ def hellaswag_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split="test",
- few_shots_select=None,
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
generation_size=1,
metrics=[Metrics.exact_match],
stop_sequence=["\n"],
@@ -120,8 +120,8 @@ def hellaswag_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split="test",
- few_shots_select=None,
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_CF_METRICS,
stop_sequence=["\n"],
diff --git a/src/lighteval/tasks/tasks/lambada.py b/src/lighteval/tasks/tasks/lambada.py
index b4e7aa2f3..778b87202 100644
--- a/src/lighteval/tasks/tasks/lambada.py
+++ b/src/lighteval/tasks/tasks/lambada.py
@@ -21,12 +21,29 @@
https://arxiv.org/abs/1606.06031
"""
+import hashlib
+import random
+
from lighteval.metrics.dynamic_metrics import LogLikelihoodAccMetric
from lighteval.metrics.metrics import Metrics
from lighteval.metrics.normalizations import LogProbCharNorm
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
+# Distractor pool for lambada:cf — loaded lazily on first prompt call.
+# Contains the distinct last words from the test split (in dataset order); used
+# to sample 3 distractors per example via a deterministic hash of the passage text.
+_LAMBADA_LAST_WORDS = None
+
+
+def _get_lambada_last_words():
+    """Return the cached list of distinct final words of the LAMBADA test passages.
+
+    The dataset is loaded on the first call only and the result is kept in the
+    module-level ``_LAMBADA_LAST_WORDS``. Passages that contain no space (and
+    therefore have no separable last word) are skipped instead of raising.
+    Words are de-duplicated in first-seen order so that sampling from the list
+    cannot return the same distractor twice.
+    """
+    global _LAMBADA_LAST_WORDS
+    if _LAMBADA_LAST_WORDS is None:
+        import datasets as _hf_datasets
+
+        ds = _hf_datasets.load_dataset("cimec/lambada", "plain_text", split="test")
+        # rpartition never raises; an empty separator means the text had no space.
+        parts = (text.rpartition(" ") for text in ds["text"])
+        # dict.fromkeys removes duplicates while keeping dataset order stable.
+        _LAMBADA_LAST_WORDS = list(
+            dict.fromkeys(word for _, sep, word in parts if sep and word)
+        )
+    return _LAMBADA_LAST_WORDS
+
def lambada_prompt(line, task_name: str = None):
"""Standard LAMBADA prompt: context as query, last word as gold continuation."""
@@ -39,6 +56,29 @@ def lambada_prompt(line, task_name: str = None):
)
+def lambada_multichoice_prompt(line, task_name: str = None):
+    """RC_per-char prompt: gold last word vs up to 3 distractors from the test set.
+
+    Distractors are chosen deterministically by hashing the passage text, so
+    results are reproducible without storing a precomputed mapping file.
+
+    Args:
+        line: Dataset row; ``line["text"]`` is the full passage.
+        task_name: Name forwarded to the returned Doc.
+
+    Returns:
+        A Doc whose query is the passage without its last word and whose choices
+        are the shuffled gold word and distractors, or None when the passage has
+        no separable last word (no space, or nothing after the final space).
+    """
+    query, sep, gold = line["text"].rpartition(" ")
+    if not sep or not gold:
+        # Single-token or space-terminated passage: nothing to predict.
+        return None
+    seed = int(hashlib.md5(line["text"].encode()).hexdigest(), 16) % (2**32)
+    rng = random.Random(seed)
+    pool = [w for w in _get_lambada_last_words() if w != gold]
+    # Clamp so an unexpectedly small pool degrades gracefully instead of raising.
+    distractors = rng.sample(pool, min(3, len(pool)))
+    all_choices = distractors + [gold]
+    rng.shuffle(all_choices)
+    gold_ix = all_choices.index(gold)
+    return Doc(
+        task_name=task_name,
+        query=query,
+        choices=[" " + c for c in all_choices],
+        gold_index=gold_ix,
+    )
+
+
def lambada_cloze_prompt(line, task_name: str = None):
"""Cloze-style LAMBADA prompt with fill-in-the-blank indicator."""
query, choice = line["text"].rsplit(" ", 1)
@@ -50,10 +90,26 @@ def lambada_cloze_prompt(line, task_name: str = None):
)
-# CF variant: RC per-char (rank choice with per-character normalization) + BPB on gold
+# BPB variant: single-choice, score gold continuation only (no acc ranking)
+# Reuses lambada_prompt (per its docstring: context as query, last word as the
+# gold continuation) and registers target_bits_per_byte as the only metric.
+lambada_bpb = LightevalTaskConfig(
+ name="lambada:bpb",
+ prompt_function=lambada_prompt,
+ hf_repo="cimec/lambada",
+ hf_subset="plain_text",
+ hf_avail_splits=["train", "test", "validation"],
+ evaluation_splits=["test"],
+ # No few-shot split or selection strategy is configured for this task.
+ few_shots_split=None,
+ few_shots_select=None,
+ generation_size=-1,
+ metrics=[Metrics.target_bits_per_byte],
+ stop_sequence=["\n"],
+ version=0,
+)
+
+# CF variant: multi-choice RC per-char acc (gold last word vs 3 sampled distractors); BPB is reported by lambada:bpb
lambada_cf = LightevalTaskConfig(
name="lambada:cf",
- prompt_function=lambada_prompt,
+ prompt_function=lambada_multichoice_prompt,
hf_repo="cimec/lambada",
hf_subset="plain_text",
hf_avail_splits=["train", "test", "validation"],
@@ -62,8 +118,7 @@ def lambada_cloze_prompt(line, task_name: str = None):
few_shots_select=None,
generation_size=-1,
metrics=[
- LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
- Metrics.target_bits_per_byte,
+ LogLikelihoodAccMetric(normalization=LogProbCharNorm())
],
stop_sequence=["\n"],
version=0,
@@ -103,6 +158,7 @@ def lambada_cloze_prompt(line, task_name: str = None):
)
TASKS_TABLE = [
+ lambada_bpb,
lambada_cf,
lambada_standard_cloze,
lambada_openai_cloze,
diff --git a/src/lighteval/tasks/tasks/med.py b/src/lighteval/tasks/tasks/med.py
index 2493fd3bc..fa4f2f19f 100644
--- a/src/lighteval/tasks/tasks/med.py
+++ b/src/lighteval/tasks/tasks/med.py
@@ -115,8 +115,8 @@ def med_mcqa_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
- few_shots_select=None,
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
generation_size=1,
metrics=[Metrics.exact_match],
stop_sequence=["\n"],
@@ -131,7 +131,7 @@ def med_mcqa_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_MCF_METRICS,
@@ -179,7 +179,7 @@ def med_mcqa_cf_prompt(line, task_name: str = None):
hf_subset="default",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
+ few_shots_split="train",
few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=_CF_METRICS,
diff --git a/src/lighteval/tasks/tasks/natural_questions.py b/src/lighteval/tasks/tasks/natural_questions.py
index 6ded9c40d..5a39fa8ff 100644
--- a/src/lighteval/tasks/tasks/natural_questions.py
+++ b/src/lighteval/tasks/tasks/natural_questions.py
@@ -30,11 +30,12 @@ def nq_gen_prompt(line, task_name: str = None):
if not answers:
return None
is_few_shots = line.get("__few_shots", False)
+ prefix = " " if is_few_shots else ""
return Doc(
task_name=task_name,
query=f"Question: {line['question']}\nAnswer:",
- choices=[f"{' ' if is_few_shots else ''}{answers[0]}"],
- gold_index=0,
+ choices=[f"{prefix}{ans}" for ans in answers],
+ gold_index=list(range(len(answers))),
)
diff --git a/src/lighteval/tasks/tasks/ruler.py b/src/lighteval/tasks/tasks/ruler.py
index 7b6648861..caa3a953d 100644
--- a/src/lighteval/tasks/tasks/ruler.py
+++ b/src/lighteval/tasks/tasks/ruler.py
@@ -88,12 +88,24 @@
"qa_2",
]
-DEFAULT_LENGTHS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
+DEFAULT_LENGTHS = [2048, 4096, 8192, 16384, 32768, 65536, 131072]
# NUM_SAMPLES = 500
NUM_SAMPLES = 100
RANDOM_SEED = 42
+# Smallest context length accepted by _validate_lengths.
+MIN_RULER_LENGTH = 2048
+# One token is held back from every generation budget so that special tokens
+# added at inference time (e.g. BOS under vllm) cannot push prompt + gen_size
+# past the model's context window.
+PROMPT_SAFETY_MARGIN = 1
+
+
+def _validate_lengths(lengths: list[int]) -> list[int]:
+    """Return *lengths* unchanged, or raise if any entry is below MIN_RULER_LENGTH.
+
+    Raises:
+        ValueError: listing every offending length, in input order.
+    """
+    too_short = list(filter(lambda n: n < MIN_RULER_LENGTH, lengths))
+    if not too_short:
+        return lengths
+    raise ValueError(f"RULER lengths must be >= {MIN_RULER_LENGTH}; got {too_short}")
+
# ---------------------------------------------------------------------------
# Tokenizer helpers
# ---------------------------------------------------------------------------
@@ -135,10 +147,26 @@ def _get_model_arch(tokenizer_path: str) -> str:
return f"{safe}_{h}"
+def _get_tokenizer_hash(tokenizer_path: str) -> str:
+    """Return an 8-hex-digit fingerprint of the tokenizer files in *tokenizer_path*.
+
+    The fingerprint covers the name and byte contents of whichever of
+    ``tokenizer.json``, ``tokenizer.model``, ``vocab.json`` and
+    ``tokenizer_config.json`` can be read from the directory. Hashing contents
+    rather than path + mtime keeps the value stable when files are copied or
+    touched, and changes it whenever a file's contents change.
+
+    If none of the files can be read (e.g. *tokenizer_path* is not a local
+    directory) the result is the digest of empty input, i.e. a constant.
+
+    Args:
+        tokenizer_path: Directory expected to hold the tokenizer files.
+
+    Returns:
+        The first 8 hex characters of the MD5 digest.
+    """
+    candidates = (
+        "tokenizer.json",
+        "tokenizer.model",
+        "vocab.json",
+        "tokenizer_config.json",
+    )
+    hasher = hashlib.md5()
+    for name in candidates:
+        try:
+            with open(os.path.join(tokenizer_path, name), "rb") as fh:
+                # Mix in the file name so identical bytes under different
+                # names yield different fingerprints.
+                hasher.update(name.encode())
+                # Read in 1 MiB chunks; tokenizer.json can be tens of MB.
+                for chunk in iter(lambda: fh.read(1 << 20), b""):
+                    hasher.update(chunk)
+        except OSError:
+            # Missing or unreadable candidate: skip it.
+            continue
+    return hasher.hexdigest()[:8]
+
+
def _get_cache_dir(tokenizer_path: str) -> Path:
+ """Return the RULER cache directory for *tokenizer_path*.
+
+ The path is ``<HF_HOME>/lighteval/ruler/<arch>_<tok_hash>``; ``HF_HOME``
+ falls back to ``~/.cache/huggingface`` when the variable is unset.
+ """
hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
arch = _get_model_arch(tokenizer_path)
- return Path(hf_home) / "lighteval" / "ruler" / arch
+ tok_hash = _get_tokenizer_hash(tokenizer_path)
+ return Path(hf_home) / "lighteval" / "ruler" / f"{arch}_{tok_hash}"
# ---------------------------------------------------------------------------
@@ -334,13 +362,11 @@ def _niah_generate_samples(
num_needle_q: int = 1,
random_seed: int = RANDOM_SEED,
) -> list[dict]:
- budget = max_seq_length
+ budget = max_seq_length - PROMPT_SAFETY_MARGIN
num_needle_k = max(num_needle_k, num_needle_q)
if type_haystack == "essay":
incremental = 500
- elif type_haystack in ("repeat", "needle"):
- incremental = 25 if max_seq_length >= 4096 else 5
else:
incremental = 25
@@ -383,9 +409,16 @@ def _gen_prefix(tnv, nq, nv, query):
num_haystack = max(incremental, num_haystack)
write_jsons = []
- for index in tqdm(
- range(num_samples), desc=f"Generating NIAH samples | {max_seq_length}"
- ):
+ index = 0
+ pbar = tqdm(total=num_samples, desc=f"Generating NIAH samples | {max_seq_length}")
+ MAX_ATTEMPTS = num_samples * 10
+ while len(write_jsons) < num_samples:
+ if index >= MAX_ATTEMPTS:
+ logger.warning(
+ f"Reached max attempts ({MAX_ATTEMPTS}) for NIAH at length {max_seq_length}, "
+ f"got {len(write_jsons)}/{num_samples} samples"
+ )
+ break
# Per-sample seeding for reproducibility and diversity
sample_seed = random_seed + index
random.seed(sample_seed)
@@ -394,7 +427,7 @@ def _gen_prefix(tnv, nq, nv, query):
input_text = answer = query = gen_prefix = length = None
while True:
try:
- input_text, answer, query = _niah_generate_input_output(
+ cand_text, cand_answer, cand_query = _niah_generate_input_output(
used_haystack,
haystack,
type_haystack=type_haystack,
@@ -406,10 +439,15 @@ def _gen_prefix(tnv, nq, nv, query):
num_needle_q=num_needle_q,
random_seed=sample_seed,
)
- gen_prefix = _gen_prefix(tnv_base, num_needle_q, num_needle_v, query)
- prompt = _build_runtime_prompt(input_text, gen_prefix)
- length = len(tokenizer(prompt).input_ids) + tokens_to_generate
- assert length <= budget
+ cand_prefix = _gen_prefix(tnv_base, num_needle_q, num_needle_v, cand_query)
+ cand_prompt = _build_runtime_prompt(cand_text, cand_prefix)
+ cand_length = len(tokenizer(cand_prompt).input_ids) + tokens_to_generate
+ if cand_length > budget:
+ raise RuntimeError(f"length {cand_length} exceeds budget {budget}")
+ # Commit only after validation passes
+ input_text, answer, query = cand_text, cand_answer, cand_query
+ gen_prefix = cand_prefix
+ length = cand_length
break
except Exception:
if used_haystack > incremental:
@@ -418,6 +456,7 @@ def _gen_prefix(tnv, nq, nv, query):
break
if answer is None:
+ index += 1
continue
write_jsons.append(
@@ -430,6 +469,9 @@ def _gen_prefix(tnv, nq, nv, query):
"gen_prefix": gen_prefix,
}
)
+ pbar.update(1)
+ index += 1
+ pbar.close()
return write_jsons
@@ -448,12 +490,12 @@ def _gen_prefix(tnv, nq, nv, query):
"Question: Find all variables that are assigned the value {query} in the text above."
),
"answer_prefix": (
- " Answer: According to the chain(s) of variable assignment in the text above, "
- "{num_v} variables are assigned the value {query}, they are: "
+ "Answer: According to the chain(s) of variable assignment in the text above, "
+ "{num_v} variables are assigned the value {query}, they are:"
),
}
VT_TEMPLATE = VT_CONFIG["template"] + VT_CONFIG["answer_prefix"]
-VT_ANSWER_PREFIX_BASE = " Answer: According to the chain(s) of variable assignment"
+VT_ANSWER_PREFIX_BASE = "Answer: According to the chain(s) of variable assignment"
def _vt_generate_chains(
@@ -564,7 +606,7 @@ def _vt_generate_samples(
num_hops: int = 4,
tokens_to_generate: int = VT_CONFIG["tokens_to_generate"],
) -> list[dict]:
- budget = max_seq_length
+ budget = max_seq_length - PROMPT_SAFETY_MARGIN
# --- Step 1: Generate ICL example within a 500-token budget (matches reference) ---
# Reference: get_dataset() calls sys_vartrack_w_noise_random(max_seq_length=500, incremental=5)
@@ -608,26 +650,38 @@ def _vt_generate_samples(
num_noises = max(incremental, num_noises)
write_jsons = []
- for index in tqdm(
- range(num_samples), desc=f"Generating VT samples | {max_seq_length}"
- ):
+ index = 0
+ pbar = tqdm(total=num_samples, desc=f"Generating VT samples | {max_seq_length}")
+ MAX_ATTEMPTS = num_samples * 10
+ while len(write_jsons) < num_samples:
+ if index >= MAX_ATTEMPTS:
+ logger.warning(
+ f"Reached max attempts ({MAX_ATTEMPTS}) for VT at length {max_seq_length}, "
+ f"got {len(write_jsons)}/{num_samples} samples"
+ )
+ break
# Per-sample seeding for reproducibility and diversity
random.seed(RANDOM_SEED + index)
np.random.seed(RANDOM_SEED + index)
used_noises = num_noises
- input_text = answer = length = None
+ input_text = answer = cached_input = cached_gen_prefix = length = None
while True:
try:
- input_text, answer = _vt_generate_input_output(
+ cand_text, cand_answer = _vt_generate_input_output(
used_noises, num_chains, num_hops
)
- cached_input, cached_gen_prefix = _vt_build_cached_sample(
- input_text, _vt_randomize_icl(icl_str)
+ cand_cached_input, cand_cached_gen_prefix = _vt_build_cached_sample(
+ cand_text, _vt_randomize_icl(icl_str)
)
- length = _runtime_prompt_budget_length(
- tokenizer, cached_input, cached_gen_prefix, tokens_to_generate
+ cand_length = _runtime_prompt_budget_length(
+ tokenizer, cand_cached_input, cand_cached_gen_prefix, tokens_to_generate
)
- assert length <= budget
+ if cand_length > budget:
+ raise RuntimeError(f"length {cand_length} exceeds budget {budget}")
+ # Commit only after validation passes
+ input_text, answer = cand_text, cand_answer
+ cached_input, cached_gen_prefix = cand_cached_input, cand_cached_gen_prefix
+ length = cand_length
break
except Exception:
if used_noises > incremental:
@@ -636,6 +690,7 @@ def _vt_generate_samples(
break
if answer is None:
+ index += 1
continue
write_jsons.append(
@@ -648,6 +703,9 @@ def _vt_generate_samples(
"gen_prefix": cached_gen_prefix,
}
)
+ pbar.update(1)
+ index += 1
+ pbar.close()
return write_jsons
@@ -665,7 +723,7 @@ def _vt_generate_samples(
"Memorize the ones that appear most often.\n{context}\n"
"Question: What are the 10 most common words in the above list?"
),
- "answer_prefix": " Answer: The top 10 words that appear most often in the list are:",
+ "answer_prefix": "Answer: The top 10 words that appear most often in the list are:",
}
CWE_TEMPLATE = CWE_CONFIG["template"] + CWE_CONFIG["answer_prefix"]
_CWE_RNG = random.Random(RANDOM_SEED)
@@ -708,12 +766,8 @@ def _cwe_get_example(
def _cwe_generate_input_output(num_words: int, max_seq_length: int, words: list[str]):
- if max_seq_length < 4096:
- ctx_ex, ans_ex = _cwe_get_example(20, words, 3, 1, 10)
- context, answer = _cwe_get_example(num_words, words, 6, 1, 10)
- else:
- ctx_ex, ans_ex = _cwe_get_example(40, words, 10, 3, 10)
- context, answer = _cwe_get_example(num_words, words, 30, 3, 10)
+ ctx_ex, ans_ex = _cwe_get_example(40, words, 10, 3, 10)
+ context, answer = _cwe_get_example(num_words, words, 30, 3, 10)
input_example = CWE_TEMPLATE.format(context=ctx_ex, query="") + " ".join(
[f"{i + 1}. {w}" for i, w in enumerate(ans_ex)]
@@ -730,7 +784,7 @@ def _cwe_generate_samples(
tokens_to_generate: int = CWE_CONFIG["tokens_to_generate"],
) -> list[dict]:
words = _get_cwe_words()
- budget = max_seq_length
+ budget = max_seq_length - PROMPT_SAFETY_MARGIN
num_words = incremental
total_tokens = 0
@@ -753,27 +807,40 @@ def _cwe_generate_samples(
num_words = max(incremental, num_words)
write_jsons = []
- for index in tqdm(
- range(num_samples), desc=f"Generating CWE samples | {max_seq_length}"
- ):
+ write_jsons = []
+ index = 0
+ pbar = tqdm(total=num_samples, desc=f"Generating CWE samples | {max_seq_length}")
+ MAX_ATTEMPTS = num_samples * 10
+ while len(write_jsons) < num_samples:
+ if index >= MAX_ATTEMPTS:
+ logger.warning(
+ f"Reached max attempts ({MAX_ATTEMPTS}) for CWE at length {max_seq_length}, "
+ f"got {len(write_jsons)}/{num_samples} samples"
+ )
+ break
# Per-sample seeding for reproducibility and diversity
random.seed(RANDOM_SEED + index)
np.random.seed(RANDOM_SEED + index)
_CWE_RNG.seed(RANDOM_SEED + index)
used_words = num_words
- input_example = input_text = answer = length = None
+ input_text = answer = full_input = gen_prefix = length = None
while True:
try:
- input_example, input_text, answer = _cwe_generate_input_output(
+ cand_example, cand_text, cand_answer = _cwe_generate_input_output(
used_words, max_seq_length, words
)
- full_input, gen_prefix = _cwe_build_cached_sample(
- input_example, input_text
+ cand_full_input, cand_gen_prefix = _cwe_build_cached_sample(
+ cand_example, cand_text
)
- length = _runtime_prompt_budget_length(
- tokenizer, full_input, gen_prefix, tokens_to_generate
+ cand_length = _runtime_prompt_budget_length(
+ tokenizer, cand_full_input, cand_gen_prefix, tokens_to_generate
)
- assert length <= budget
+ if cand_length > budget:
+ raise RuntimeError(f"length {cand_length} exceeds budget {budget}")
+ # Commit only after validation passes
+ input_text, answer = cand_text, cand_answer
+ full_input, gen_prefix = cand_full_input, cand_gen_prefix
+ length = cand_length
break
except Exception:
if used_words > incremental:
@@ -782,6 +849,7 @@ def _cwe_generate_samples(
break
if answer is None:
+ index += 1
continue
write_jsons.append(
@@ -794,6 +862,9 @@ def _cwe_generate_samples(
"gen_prefix": gen_prefix,
}
)
+ pbar.update(1)
+ index += 1
+ pbar.close()
return write_jsons
@@ -812,7 +883,7 @@ def _cwe_generate_samples(
"Question: Do not provide any explanation. Please ignore the dots '....'. "
"What are the three most frequently appeared words in the above coded text?"
),
- "answer_prefix": " Answer: According to the coded text above, the three most frequently appeared words are:",
+ "answer_prefix": "Answer: According to the coded text above, the three most frequently appeared words are:",
}
FWE_TEMPLATE = FWE_CONFIG["template"] + FWE_CONFIG["answer_prefix"]
FWE_SEED = RANDOM_SEED
@@ -876,42 +947,67 @@ def _fwe_generate_samples(
alpha: float = 2.0,
tokens_to_generate: int = 50,
) -> list[dict]:
- budget = max_seq_length
+ budget = max_seq_length - PROMPT_SAFETY_MARGIN
input_max_len = budget - tokens_to_generate
vs = max_seq_length // 50 if vocab_size == -1 else vocab_size
+ fwe_incremental = input_max_len // 32
_, _, num_example_words = _fwe_generate_input_output(
input_max_len,
tokenizer,
coded_wordlen=coded_wordlen,
vocab_size=vs,
- incremental=input_max_len // 32,
+ incremental=fwe_incremental,
alpha=alpha,
)
write_jsons = []
- for index in tqdm(
- range(num_samples), desc=f"Generating FWE samples | {max_seq_length}"
- ):
+ index = 0
+ pbar = tqdm(total=num_samples, desc=f"Generating FWE samples | {max_seq_length}")
+ MAX_ATTEMPTS = num_samples * 10
+ while len(write_jsons) < num_samples:
+ if index >= MAX_ATTEMPTS:
+ logger.warning(
+ f"Reached max attempts ({MAX_ATTEMPTS}) for FWE at length {max_seq_length}, "
+ f"got {len(write_jsons)}/{num_samples} samples"
+ )
+ break
# Per-sample seeding for reproducibility and diversity
sample_seed = RANDOM_SEED + index
random.seed(sample_seed)
np.random.seed(sample_seed)
- input_text, answer, _ = _fwe_generate_input_output(
- input_max_len,
- tokenizer,
- num_words=num_example_words,
- coded_wordlen=coded_wordlen,
- vocab_size=vs,
- incremental=input_max_len // 32,
- alpha=alpha,
- sample_seed=sample_seed,
- )
- input_text_clean, gen_prefix = _fwe_build_cached_sample(input_text)
- length = _runtime_prompt_budget_length(
- tokenizer, input_text_clean, gen_prefix, tokens_to_generate
- )
- assert length <= budget
+ used_words = num_example_words
+ input_text_clean = answer = gen_prefix = length = None
+ while True:
+ try:
+ cand_text, cand_answer, _ = _fwe_generate_input_output(
+ input_max_len,
+ tokenizer,
+ num_words=used_words,
+ coded_wordlen=coded_wordlen,
+ vocab_size=vs,
+ incremental=fwe_incremental,
+ alpha=alpha,
+ sample_seed=sample_seed,
+ )
+ cand_clean, cand_prefix = _fwe_build_cached_sample(cand_text)
+ cand_length = _runtime_prompt_budget_length(
+ tokenizer, cand_clean, cand_prefix, tokens_to_generate
+ )
+ if cand_length > budget:
+ raise RuntimeError(f"length {cand_length} exceeds budget {budget}")
+ input_text_clean, answer = cand_clean, cand_answer
+ gen_prefix, length = cand_prefix, cand_length
+ break
+ except Exception:
+ if used_words > fwe_incremental:
+ used_words -= fwe_incremental
+ else:
+ break
+
+ if answer is None:
+ index += 1
+ continue
write_jsons.append(
{
@@ -923,6 +1019,9 @@ def _fwe_generate_samples(
"gen_prefix": gen_prefix,
}
)
+ pbar.update(1)
+ index += 1
+ pbar.close()
return write_jsons
@@ -1066,7 +1165,7 @@ def _qa_generate_samples(
tokens_to_generate: int = QA_CONFIG["tokens_to_generate"],
incremental: int = 10,
) -> list[dict]:
- budget = max_seq_length
+ budget = max_seq_length - PROMPT_SAFETY_MARGIN
gen_prefix = QA_CONFIG["answer_prefix"]
num_docs = incremental
@@ -1086,9 +1185,16 @@ def _qa_generate_samples(
num_docs = max(incremental, num_docs)
write_jsons = []
- for index in tqdm(
- range(num_samples), desc=f"Generating QA samples | {max_seq_length}"
- ):
+ index = 0
+ pbar = tqdm(total=num_samples, desc=f"Generating QA samples | {max_seq_length}")
+ MAX_ATTEMPTS = num_samples * 10
+ while len(write_jsons) < num_samples:
+ if index >= MAX_ATTEMPTS:
+ logger.warning(
+ f"Reached max attempts ({MAX_ATTEMPTS}) for QA at length {max_seq_length}, "
+ f"got {len(write_jsons)}/{num_samples} samples"
+ )
+ break
# Per-sample seeding for reproducibility and diversity
random.seed(RANDOM_SEED + index)
np.random.seed(RANDOM_SEED + index)
@@ -1096,12 +1202,16 @@ def _qa_generate_samples(
input_text = answer = length = None
while True:
try:
- input_text, answer = _qa_generate_input_output(
+ cand_text, cand_answer = _qa_generate_input_output(
index, used_docs, qas=qas, docs=docs
)
- prompt = _build_runtime_prompt(input_text, gen_prefix)
- length = len(tokenizer(prompt).input_ids) + tokens_to_generate
- assert length <= budget
+ cand_prompt = _build_runtime_prompt(cand_text, gen_prefix)
+ cand_length = len(tokenizer(cand_prompt).input_ids) + tokens_to_generate
+ if cand_length > budget:
+ raise RuntimeError(f"length {cand_length} exceeds budget {budget}")
+ # Commit only after validation passes
+ input_text, answer = cand_text, cand_answer
+ length = cand_length
break
except Exception:
if used_docs > incremental:
@@ -1110,6 +1220,7 @@ def _qa_generate_samples(
break
if answer is None:
+ index += 1
continue
write_jsons.append(
@@ -1122,6 +1233,9 @@ def _qa_generate_samples(
"gen_prefix": gen_prefix,
}
)
+ pbar.update(1)
+ index += 1
+ pbar.close()
return write_jsons
@@ -1223,7 +1337,7 @@ def _generate_subset(subset: str, length: int, tokenizer) -> list[dict]:
type_needle_v="numbers",
num_needle_v=4,
template=NIAH_TEMPLATE_MULTI,
- tokens_to_generate=300 if length == 4096 else 250,
+ tokens_to_generate=300 if length <= 4096 else 250,
)
elif subset == "vt":
return _vt_generate_samples(tokenizer, max_seq_length=length)
@@ -1254,6 +1368,7 @@ def ensure_ruler_cache(
tokenizer_name: str,
lengths: list[int] | None = None,
subsets: list[str] | None = None,
+ allow_partial: bool = False,
) -> Path:
"""Generate and cache RULER data for the given tokenizer.
@@ -1263,7 +1378,7 @@ def ensure_ruler_cache(
Returns the cache base directory.
"""
- lengths = lengths or DEFAULT_LENGTHS
+ lengths = _validate_lengths(lengths or DEFAULT_LENGTHS)
subsets = subsets or SUBSETS
cache_base = _get_cache_dir(tokenizer_name)
cache_base.mkdir(parents=True, exist_ok=True)
@@ -1282,6 +1397,7 @@ def ensure_ruler_cache(
)
tokenizer = _load_tokenizer(tokenizer_name)
+ failed = []
for length, subset in missing:
logger.info(f"Generating RULER data: length={length} subset={subset} ...")
try:
@@ -1295,8 +1411,14 @@ def ensure_ruler_cache(
)
logger.info(f"Saved {len(samples)} samples to {subset_dir}")
except Exception as e:
+ failed.append((length, subset, str(e)))
logger.warning(f"Skipping subset {subset} at length {length}: {e}")
+ if failed and not allow_partial:
+ raise RuntimeError(
+ f"RULER cache generation failed for {len(failed)} subset(s): {failed}"
+ )
+
return cache_base
@@ -1344,7 +1466,7 @@ def get_ruler_tasks(
Data is generated on demand in download_dataset_worker (keyed off
TOKENIZER_PATH) so only the lengths actually evaluated are generated.
"""
- lengths = lengths or DEFAULT_LENGTHS
+ lengths = _validate_lengths(lengths or DEFAULT_LENGTHS)
subsets = subsets or SUBSETS
cache_base = _get_cache_dir(tokenizer_name)
@@ -1359,7 +1481,6 @@ def get_ruler_tasks(
_niah_gen_sizes = {
"niah_multikey_3": 100,
"niah_multiquery": 100,
- "niah_multivalue": 300 if length == 4096 else 250,
}
gen_size = (
_niah_gen_sizes.get(subset, 50)
@@ -1368,6 +1489,8 @@ def get_ruler_tasks(
50 if subset == "vt" else 100 if subset == "cwe" else 50
) # fwe and qa
)
+ if subset == "niah_multivalue":
+ gen_size = 300 if length <= 4096 else 250
tasks.append(
LightevalTaskConfig(
name=f"ruler_{length}:{subset}",
diff --git a/src/lighteval/tasks/tasks/squad.py b/src/lighteval/tasks/tasks/squad.py
index 67c6a9895..6baf16d6b 100644
--- a/src/lighteval/tasks/tasks/squad.py
+++ b/src/lighteval/tasks/tasks/squad.py
@@ -31,12 +31,12 @@ def squad_gen_prompt(line, task_name: str = None):
if not answers_text:
return None
is_few_shots = line.get("__few_shots", False)
- gold_text = f"{' ' if is_few_shots else ''}{answers_text[0]}"
+ prefix = " " if is_few_shots else ""
return Doc(
task_name=task_name,
query=f"Title: {line['title']}\n\nBackground: {line['context']}\n\nQuestion: {line['question']}\n\nAnswer:",
- choices=[gold_text],
- gold_index=0,
+ choices=[f"{prefix}{ans}" for ans in answers_text],
+ gold_index=list(range(len(answers_text))),
)
diff --git a/src/lighteval/tasks/tasks/winogrande.py b/src/lighteval/tasks/tasks/winogrande.py
index be13d4210..72c2099c2 100644
--- a/src/lighteval/tasks/tasks/winogrande.py
+++ b/src/lighteval/tasks/tasks/winogrande.py
@@ -42,7 +42,7 @@ def winogrande_cf_prompt(line, task_name: str = None):
BPB = logprob(suffix | prefix + gold_option) / bytes(suffix).
Accuracy is not meaningful in this single-choice format (always 1.0), so
- only BPB is reported. Use winogrande:mcf for accuracy scoring.
+ only BPB is reported (this prompt backs the winogrande:bpb task). Use winogrande:cf for accuracy scoring.
"""
sentence = line["sentence"]
blank_pos = sentence.index("_")
@@ -100,17 +100,17 @@ def winogrande_rc_prompt(line, task_name: str = None):
)
-# CF variant: partial evaluation (OLMO-style), BPB only
+# BPB variant: partial evaluation (OLMO-style), BPB only
# Accuracy is not reported because partial eval uses one fixed gold choice.
-winogrande_cf = LightevalTaskConfig(
- name="winogrande:cf",
+winogrande_bpb = LightevalTaskConfig(
+ name="winogrande:bpb",
prompt_function=winogrande_cf_prompt,
hf_repo="allenai/winogrande",
hf_subset="winogrande_xl",
hf_avail_splits=["train", "test", "validation"],
evaluation_splits=["validation"],
- few_shots_split=None,
- few_shots_select="random_sampling",
+ few_shots_split="train",
+ few_shots_select="random_sampling_from_train",
generation_size=-1,
metrics=[Metrics.target_bits_per_byte],
stop_sequence=["\n"],
@@ -149,11 +149,11 @@ def winogrande_rc_prompt(line, task_name: str = None):
version=0,
)
-# RC variant: standard cloze-form for accuracy (approximates Table 46 RC_none).
+# CF variant: standard cloze-form for accuracy (approximates Table 46 RC_none).
# Scores P(option+suffix | prefix) rather than OLMO's P(suffix | prefix+option),
# but gives meaningful accuracy using lighteval's standard Doc framework.
-winogrande_rc = LightevalTaskConfig(
- name="winogrande:rc",
+winogrande_cf = LightevalTaskConfig(
+ name="winogrande:cf",
prompt_function=winogrande_rc_prompt,
hf_repo="allenai/winogrande",
hf_subset="winogrande_xl",
@@ -164,15 +164,14 @@ def winogrande_rc_prompt(line, task_name: str = None):
generation_size=-1,
metrics=[
LogLikelihoodAccMetric(),
- LogLikelihoodAccMetric(normalization=LogProbCharNorm()),
],
stop_sequence=["\n"],
version=0,
)
TASKS_TABLE = [
+ winogrande_bpb,
winogrande_cf,
winogrande_mcf,
winogrande_mcf_em,
- winogrande_rc,
]