
Commit f74b05b

feat(benchmarks): implement Kaggle client (push/run functionality) (#960)
# Benchmarks CLI Reference (push & run)

The benchmarks CLI manages benchmark tasks: registering evaluation code, scheduling runs against models, monitoring progress, and downloading results.

**Aliases**: `kaggle benchmarks` or `kaggle b`

All task subcommands live under `kaggle benchmarks tasks` (alias: `kaggle b t`).

---

## Commands

### `push` — Register a task

Upload a Python source file as a benchmark task definition. The file must be a `.py` file in percent format (cells delimited by `# %%`); the CLI converts it to an `.ipynb` notebook before uploading. If the task already exists, the push creates a new version. A sketch of such a task file follows this section.

```
kaggle b t push <task> -f <file>
```

| Parameter | Flag | Required | Description |
|-----------|------|----------|-------------|
| `task` | positional | **Yes** | Task name (e.g. `math-eval`) |
| `file` | `-f`, `--file` | **Yes** | Path to the Python source file defining the task |

**Behavior**:

1. Validates that the file exists and has a `.py` extension.
2. Reads the source file and parses it with Python's `ast` module to extract task names from `@task` decorators (supports both `@task` and `@kbench.task` styles, as well as `@task(name="...")` with explicit names).
   - When an explicit `name=` keyword is provided, that name is used.
   - When no explicit name is provided, the function name is title-cased with underscores replaced by spaces (e.g. `my_test_task` → `"My Test Task"`).
3. Validates that the file contains at least one `@task` decorator. If none are found, raises `ValueError` and stops.
4. Validates that the given task name matches one of the task names extracted from the file.
5. Converts the `.py` file content to `.ipynb` (Jupyter Notebook) format using `jupytext` (assuming percent format), and adds a Python 3 kernelspec to the notebook metadata.
6. Checks the server for an existing task with the same slug:
   - If the task exists and its `creation_state` is `QUEUED` or `RUNNING` (i.e. a previous version is still being built), the push is **rejected** with `ValueError`.
   - If the task exists in `COMPLETED` or `ERRORED` state, the push proceeds (creates a new version).
   - If the task does not exist (404), the push proceeds (creates a new task).
7. Sends the notebook content (JSON string) to `create_benchmark_task`.
8. If the server returns an error message in the response, raises `ValueError` with the error details.
9. On success, prints the task slug and its URL.

**Errors**:

- `ValueError: File <path> does not exist` — file path is invalid.
- `ValueError: File <path> must be a .py file` — file is not a Python file.
- `ValueError: No @task decorators found in file <path>. The file must define at least one task.` — the file does not contain any `@task`-decorated functions.
- `ValueError: Task '<name>' not found in file <path>. Found tasks: ...` — the task name doesn't match any `@task`-decorated function in the file.
- `ValueError: Task '<name>' is currently being created (pending). Cannot push now.` — a previous version of this task is still being processed by the server.
- `ValueError: Failed to push task: <error>` — the server returned an error message in the response.
- `HTTPError` — server-side error (e.g. authentication failure, permission denied).

**Example**:

```bash
kaggle b t push math-eval -f tasks/math_eval.py
```
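For illustration, a minimal task file that would pass the `push` validations above might look like the sketch below. The `kbench` module name and the function bodies are hypothetical; `push` only parses the file with `ast` to read the decorators, it never imports or executes it.

```python
# %% [markdown]
# tasks/math_eval.py: a hypothetical percent-format task file.

# %%
import kbench  # hypothetical module; push inspects decorator names, it does not import this

# Explicit name: `kaggle b t push math-eval -f tasks/math_eval.py` matches "math-eval".
@kbench.task(name="math-eval")
def math_eval():
    ...

# %%
# No explicit name: the derived task name is "My Extra Check"
# (function name title-cased, underscores replaced with spaces).
@kbench.task
def my_extra_check():
    ...
```

Either decorator spelling (`@task` or `@kbench.task`) is recognized, and `jupytext` splits the file into notebook cells at each `# %%` marker.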
### `run` — Schedule task runs

Schedule benchmark task execution against one or more models.

```
kaggle b t run <task> [-m <model> ...] [--wait [<timeout>]] [--poll-interval <seconds>]
```

| Parameter | Flag | Required | Description |
|-----------|------|----------|-------------|
| `task` | positional | **Yes** | Task name (e.g. `math-eval`) |
| `model` | `-m`, `--model` | No | Model slug(s) to run against. Accepts multiple space-separated values |
| `wait` | `--wait` | No | Wait for runs to complete. Takes an optional timeout in seconds (`0` or no value = wait indefinitely) |
| `poll_interval` | `--poll-interval` | No | Seconds between status polls when using `--wait` (default: 10) |

**Behavior**:

1. **Task readiness check**: Before scheduling, verifies that the task exists and its `creation_state` is `COMPLETED`. If the task is not ready:
   - For `ERRORED` tasks, the error message includes the task info for debugging.
   - For other non-completed states (e.g. `QUEUED`, `RUNNING`), raises `ValueError` indicating the task is not ready to run.
2. **Model selection**: If no `-m` is provided, fetches the list of available benchmark models via `list_benchmark_models` and prompts the user interactively:

   ```
   No model specified. 5 model(s) available:
     1. gemini-pro (Gemini Pro)
     2. gemma-2b (Gemma 2B)
   Enter model numbers (comma-separated), 'all':
   ```

   - Enter comma-separated numbers (e.g. `1,3`) to select specific models.
   - Enter `all` to run against every available model.
   - When there are more than 20 models, the list is paginated. Use `n` for the next page and `p` for the previous page.
   - Invalid input (non-numeric, out-of-range index) raises `ValueError`.
   - If no benchmark models exist on the server, raises `ValueError: No benchmark models available. Cannot schedule runs.`
3. **Scheduling**: Calls `batch_schedule_benchmark_task_runs` with the task slug and the selected model slugs. Output:

   ```
   Submitted run(s) for task 'math-eval'.
     gemini-pro: Scheduled
     gemma-2b: Scheduled
     gemini-flash: Skipped (<reason>)
   ```

4. **Waiting** (`--wait`): After scheduling, if `--wait` is specified, polls `list_benchmark_task_runs` at a fixed interval (default **10 seconds**, configurable via `--poll-interval`) until all runs reach a terminal state (`COMPLETED` or `ERRORED`) or the timeout is reached. Output while waiting:

   ```
   Waiting for run(s) to complete...
     2 run(s) still in progress...
     1 run(s) still in progress...
   All runs completed:
     gemini-pro: COMPLETED
     gemma-2b: ERRORED
   ```

   - If a timeout (in seconds) is specified and reached, it stops waiting and prints: `Timed out waiting for runs after <timeout> seconds.`
   - If `0` or no value is given for `--wait`, it waits indefinitely.

**Errors**:

- `ValueError: Task '<name>' is not ready to run (status: <state>). Only completed tasks can be run.` — the task has not finished building (or errored during the build).
- `ValueError: No benchmark models available. Cannot schedule runs.` — no models exist on the server and none were specified via `-m`.
- `ValueError: Invalid selection: <input>` — the user entered non-numeric or out-of-range input during interactive model selection.
- `HTTPError` — server-side error (task not found, authentication failure, etc.).

**Examples**:

```bash
# Run against specific models
kaggle b t run math-eval -m gemini-pro gemma-2b

# Run and wait for completion
kaggle b t run math-eval -m gemini-pro --wait

# Wait with a custom poll interval (30 seconds)
kaggle b t run math-eval -m gemini-pro --wait --poll-interval 30

# Wait with a timeout (60 seconds)
kaggle b t run math-eval -m gemini-pro --wait 60

# Interactive model selection (prompts user)
kaggle b t run math-eval
```

---

### End-to-end test

https://paste.googleplex.com/6483737513689088
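### Scripting from Python (sketch)

The CLI handlers are plain methods on the extended API object, so the same flow can be scripted. A minimal sketch, assuming the package's usual `KaggleApi` entry point and credentials in `~/.kaggle/kaggle.json` (or `KAGGLE_USERNAME`/`KAGGLE_KEY`); the method names and signatures come from the implementation diff below:

```python
from kaggle.api.kaggle_api_extended import KaggleApi

api = KaggleApi()
api.authenticate()  # standard Kaggle API credential lookup

# Register (or version) the task from a percent-format .py file.
api.benchmarks_tasks_push_cli("math-eval", "tasks/math_eval.py")

# Schedule runs against two models, then poll every 30 s for up to 10 minutes.
api.benchmarks_tasks_run_cli(
    "math-eval",
    model=["gemini-pro", "gemma-2b"],
    wait=600,
    poll_interval=30,
)
```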
1 parent a08af17 commit f74b05b

7 files changed

Lines changed: 814 additions & 5 deletions


pyproject.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -24,14 +24,15 @@ keywords = ["Kaggle", "API"]
 requires-python = ">= 3.11"
 dependencies = [
   "bleach",
-  "kagglesdk >= 0.1.17, < 1.0", # sync with kagglehub
+  "kagglesdk >= 0.1.18, < 1.0", # sync with kagglehub
   "python-slugify",
   "requests",
   "python-dateutil",
   "tqdm",
   "urllib3 >= 1.15.1",
   "packaging",
   "protobuf",
+  "jupytext",
 ]
 
 [project.scripts]
```

requirements-test.lock

Lines changed: 6 additions & 1 deletion
```diff
@@ -1,4 +1,9 @@
-# Regenerate with: pip-compile requirements-test.in -o requirements-test.lock
+#
+# This file is autogenerated by pip-compile with Python 3.13
+# by the following command:
+#
+#    pip-compile --output-file=requirements-test.lock requirements-test.in
+#
 iniconfig==2.3.0
     # via pytest
 packaging==26.0
```

requirements.lock

Lines changed: 50 additions & 3 deletions
```diff
@@ -1,16 +1,49 @@
-# Regenerate with: pip-compile pyproject.toml -o requirements.lock
+#
+# This file is autogenerated by pip-compile with Python 3.13
+# by the following command:
+#
+#    pip-compile --output-file=requirements.lock pyproject.toml
+#
+attrs==26.1.0
+    # via
+    #   jsonschema
+    #   referencing
 bleach==6.3.0
     # via kaggle (pyproject.toml)
 certifi==2026.2.25
     # via requests
 charset-normalizer==3.4.6
     # via requests
+fastjsonschema==2.21.2
+    # via nbformat
 idna==3.11
     # via requests
-kagglesdk==0.1.17
+jsonschema==4.26.0
+    # via nbformat
+jsonschema-specifications==2025.9.1
+    # via jsonschema
+jupyter-core==5.9.1
+    # via nbformat
+jupytext==1.19.1
     # via kaggle (pyproject.toml)
-packaging==26.0
+kagglesdk==0.1.18
     # via kaggle (pyproject.toml)
+markdown-it-py==4.0.0
+    # via
+    #   jupytext
+    #   mdit-py-plugins
+mdit-py-plugins==0.5.0
+    # via jupytext
+mdurl==0.1.2
+    # via markdown-it-py
+nbformat==5.10.4
+    # via jupytext
+packaging==26.0
+    # via
+    #   jupytext
+    #   kaggle (pyproject.toml)
+platformdirs==4.9.6
+    # via jupyter-core
 protobuf==7.34.1
     # via
     #   kaggle (pyproject.toml)
@@ -19,16 +52,30 @@ python-dateutil==2.9.0.post0
     # via kaggle (pyproject.toml)
 python-slugify==8.0.4
     # via kaggle (pyproject.toml)
+pyyaml==6.0.3
+    # via jupytext
+referencing==0.37.0
+    # via
+    #   jsonschema
+    #   jsonschema-specifications
 requests==2.33.1
     # via
     #   kaggle (pyproject.toml)
     #   kagglesdk
+rpds-py==0.30.0
+    # via
+    #   jsonschema
+    #   referencing
 six==1.17.0
     # via python-dateutil
 text-unidecode==1.3
     # via python-slugify
 tqdm==4.67.3
     # via kaggle (pyproject.toml)
+traitlets==5.14.3
+    # via
+    #   jupyter-core
+    #   nbformat
 urllib3==2.6.3
     # via
     #   kaggle (pyproject.toml)
```

src/kaggle/api/kaggle_api_extended.py

Lines changed: 271 additions & 0 deletions
```diff
@@ -57,6 +57,15 @@
 from kagglesdk import get_access_token_from_env, KaggleClient, KaggleCredentials, KaggleEnv, KaggleOAuth  # type: ignore[attr-defined]
 from kagglesdk.admin.types.inbox_file_service import CreateInboxFileRequest
 from kagglesdk.blobs.types.blob_api_service import ApiStartBlobUploadRequest, ApiStartBlobUploadResponse, ApiBlobType
+from kagglesdk.benchmarks.types.benchmark_enums import BenchmarkTaskRunState, BenchmarkTaskVersionCreationState
+from kagglesdk.benchmarks.types.benchmark_tasks_api_service import (
+    ApiCreateBenchmarkTaskRequest,
+    ApiGetBenchmarkTaskRequest,
+    ApiListBenchmarkTaskRunsRequest,
+    ApiBenchmarkTaskSlug,
+    ApiBatchScheduleBenchmarkTaskRunsRequest,
+)
+from kagglesdk.benchmarks.types.benchmarks_api_service import ApiListBenchmarkModelsRequest
 from kagglesdk.competitions.types.competition_api_service import (
     ApiListCompetitionsRequest,
     ApiCreateCodeSubmissionRequest,
@@ -5429,6 +5438,268 @@ def _check_response_version(self, response: Response):
     def get_response_processor(self):
         return self._check_response_version
 
+    # ---- Benchmarks CLI ----
+
+    _TERMINAL_RUN_STATES = {
+        BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED,
+        BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_ERRORED,
+    }
+
+    _PENDING_CREATION_STATES = {
+        BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED,
+        BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING,
+    }
+
+    @staticmethod
+    def _make_task_slug(task: str) -> ApiBenchmarkTaskSlug:
+        """Build an ApiBenchmarkTaskSlug from a task name string."""
+        slug = ApiBenchmarkTaskSlug()
+        slug.task_slug = task
+        return slug
+
+    @staticmethod
+    def _normalize_model_list(model) -> list:
+        """Normalize a model argument (str, list, or None) into a list."""
+        if isinstance(model, list):
+            return model
+        return [model] if model else []
+
+    @staticmethod
+    def _paginate(fetch_page, get_items):
+        """Exhaust a paginated API, returning all items."""
+        items = []
+        page_token = ""
+        while True:
+            response = fetch_page(page_token)
+            items.extend(get_items(response))
+            page_token = response.next_page_token or ""
+            if not page_token:
+                break
+        return items
+
+    def _get_task_names_from_file(self, file_content: str) -> List[str]:
+        """Extract task names from a Python file."""
+        import ast
+
+        task_names = []
+        try:
+            tree = ast.parse(file_content)
+        except SyntaxError:
+            return []
+
+        for node in ast.walk(tree):
+            if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
+                continue
+
+            for decorator in node.decorator_list:
+                func = decorator.func if isinstance(decorator, ast.Call) else decorator
+
+                if not (
+                    (isinstance(func, ast.Name) and func.id == "task")
+                    or (isinstance(func, ast.Attribute) and func.attr == "task")
+                ):
+                    continue
+
+                name = None
+                if isinstance(decorator, ast.Call):
+                    name = next(
+                        (
+                            k.value.value
+                            for k in decorator.keywords
+                            if k.arg == "name" and isinstance(k.value, ast.Constant)
+                        ),
+                        None,
+                    )
+
+                task_names.append(name if name else node.name.title().replace("_", " "))
+
+        return task_names
+
+    def _get_benchmark_task(self, task: str, kaggle):
+        """Get benchmark task details from the server."""
+        request = ApiGetBenchmarkTaskRequest()
+        request.slug = self._make_task_slug(task)
+        return kaggle.benchmarks.benchmark_tasks_api_client.get_benchmark_task(request)
+
+    def _validate_task_in_file(self, task: str, file: str, file_content: str):
+        """Validate that the task name is defined in the Python file."""
+        task_names = self._get_task_names_from_file(file_content)
+        if not task_names:
+            raise ValueError(f"No @task decorators found in file {file}. The file must define at least one task.")
+        if task not in task_names:
+            raise ValueError(f"Task '{task}' not found in file {file}. Found tasks: {', '.join(task_names)}")
+
+    def benchmarks_tasks_push_cli(self, task, file):
+        if not os.path.isfile(file):
+            raise ValueError(f"File {file} does not exist")
+        if not file.endswith(".py"):
+            raise ValueError(f"File {file} must be a .py file")
+
+        with open(file) as f:
+            content = f.read()
+
+        self._validate_task_in_file(task, file, content)
+
+        # Convert .py file with percent delimiters to .ipynb
+        import jupytext
+
+        notebook = jupytext.reads(content, fmt="py:percent")
+        # Add kernelspec metadata so papermill can execute it on the server
+        notebook.metadata["kernelspec"] = {
+            "display_name": "Python 3",
+            "language": "python",
+            "name": "python3",
+        }
+        notebook_content = jupytext.writes(notebook, fmt="ipynb")
+
+        with self.build_kaggle_client() as kaggle:
+            try:
+                task_info = self._get_benchmark_task(task, kaggle)
+                if task_info.creation_state in self._PENDING_CREATION_STATES:
+                    raise ValueError(f"Task '{task}' is currently being created (pending). Cannot push now.")
+            except HTTPError as e:
+                if e.response.status_code != 404:
+                    raise
+
+            request = ApiCreateBenchmarkTaskRequest()
+            request.slug = task
+            request.text = notebook_content
+
+            response = kaggle.benchmarks.benchmark_tasks_api_client.create_benchmark_task(request)
+            error = getattr(response, "error_message", None) or getattr(response, "errorMessage", None)
+            if error:
+                raise ValueError(f"Failed to push task: {error}")
+            print(f"Task '{task}' pushed.")
+            url = response.url
+            if url.startswith("/"):
+                url = "https://www.kaggle.com" + url
+            print(f"Task URL: {url}")
+
+    def _select_models_interactively(self, kaggle, page_size=20):
+        """Prompt the user to pick benchmark models from a paginated list."""
+
+        def _fetch_models(page_token):
+            req = ApiListBenchmarkModelsRequest()
+            if page_token:
+                req.page_token = page_token
+            return kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(req)
+
+        available = self._paginate(_fetch_models, lambda r: r.benchmark_models)
+        if not available:
+            raise ValueError("No benchmark models available. Cannot schedule runs.")
+
+        total = len(available)
+        total_pages = -(-total // page_size)  # ceiling division
+        current_page = 0
+
+        print(f"No model specified. {total} model(s) available:")
+        while True:
+            start = current_page * page_size
+            for i, m in enumerate(available[start : start + page_size], start=start + 1):
+                print(f" {i}. {m.version.slug} ({m.display_name})")
+
+            nav_hints = []
+            if total_pages > 1:
+                print(f" [Page {current_page + 1}/{total_pages}]")
+            if current_page < total_pages - 1:
+                nav_hints.append("'n'=next")
+            if current_page > 0:
+                nav_hints.append("'p'=prev")
+
+            prompt_parts = ["Enter model numbers (comma-separated)", "'all'"]
+            if nav_hints:
+                prompt_parts.extend(nav_hints)
+            selection = input(", ".join(prompt_parts) + ": ").strip().lower()
+
+            if selection == "n" and current_page < total_pages - 1:
+                current_page += 1
+            elif selection == "p" and current_page > 0:
+                current_page -= 1
+            elif selection == "all":
+                return [m.version.slug for m in available]
+            else:
+                try:
+                    indices = [int(s) for s in selection.split(",")]
+                    return [available[i - 1].version.slug for i in indices]
+                except (ValueError, IndexError):
+                    raise ValueError(f"Invalid selection: {selection}")
+
+    def _poll_runs(self, kaggle, task_slug_obj, models, wait, poll_interval):
+        """Poll run status until all runs are terminal or timeout."""
+
+        def _fetch_runs(page_token):
+            req = ApiListBenchmarkTaskRunsRequest()
+            req.task_slug = task_slug_obj
+            if models:
+                req.model_version_slugs = models
+            if page_token:
+                req.page_token = page_token
+            return kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(req)
+
+        print("Waiting for run(s) to complete...")
+        start_time = time.time()
+        while True:
+            all_runs = self._paginate(_fetch_runs, lambda r: r.runs)
+
+            if all_runs and all(r.state in self._TERMINAL_RUN_STATES for r in all_runs):
+                print("All runs completed:")
+                for r in all_runs:
+                    label = (
+                        "COMPLETED"
+                        if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED
+                        else "ERRORED"
+                    )
+                    print(f" {r.model_version_slug}: {label}")
+                return
+
+            pending = sum(1 for r in all_runs if r.state not in self._TERMINAL_RUN_STATES)
+            print(f" {pending} run(s) still in progress...")
+
+            if wait > 0 and (time.time() - start_time) > wait:
+                print(f"Timed out waiting for runs after {wait} seconds.")
+                return
+
+            time.sleep(poll_interval)
+
+    def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10):
+        models = self._normalize_model_list(model)
+        task_slug_obj = self._make_task_slug(task)
+
+        with self.build_kaggle_client() as kaggle:
+            # Verify the task exists and is ready to run
+            task_info = self._get_benchmark_task(task, kaggle)
+            if (
+                task_info.creation_state
+                != BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED
+            ):
+                error_msg = f"Task '{task}' is not ready to run (status: {task_info.creation_state})."
+                if (
+                    task_info.creation_state
+                    == BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_ERRORED
+                ):
+                    error_msg += f" Task Info: {task_info}."
+                error_msg += " Only completed tasks can be run."
+                raise ValueError(error_msg)
+
+            if not models:
+                models = self._select_models_interactively(kaggle)
+                print(f"Selected models: {models}")
+
+            request = ApiBatchScheduleBenchmarkTaskRunsRequest()
+            request.task_slugs = [task_slug_obj]
+            request.model_version_slugs = models
+
+            response = kaggle.benchmarks.benchmark_tasks_api_client.batch_schedule_benchmark_task_runs(request)
+            print(f"Submitted run(s) for task '{task}'.")
+            for model_slug, res in zip(models, response.results):
+                if res.run_scheduled:
+                    print(f" {model_slug}: Scheduled")
+                else:
+                    print(f" {model_slug}: Skipped ({res.run_skipped_reason})")
+
+            if wait is not None:
+                self._poll_runs(kaggle, task_slug_obj, models, wait, poll_interval)
+
 
 class TqdmBufferedReader(io.BufferedReader):
```
