Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ keywords = ["Kaggle", "API"]
requires-python = ">= 3.11"
dependencies = [
"bleach",
"kagglesdk >= 0.1.17, < 1.0", # sync with kagglehub
"kagglesdk >= 0.1.18, < 1.0", # sync with kagglehub
"python-slugify",
"requests",
"python-dateutil",
"tqdm",
"urllib3 >= 1.15.1",
"packaging",
"protobuf",
"jupytext",
]

[project.scripts]
Expand Down
7 changes: 6 additions & 1 deletion requirements-test.lock
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# Regenerate with: pip-compile requirements-test.in -o requirements-test.lock
#
# This file is autogenerated by pip-compile with Python 3.13
# by the following command:
#
# pip-compile --output-file=requirements-test.lock requirements-test.in
#
iniconfig==2.3.0
# via pytest
packaging==26.0
Expand Down
53 changes: 50 additions & 3 deletions requirements.lock
Original file line number Diff line number Diff line change
@@ -1,16 +1,49 @@
# Regenerate with: pip-compile pyproject.toml -o requirements.lock
#
# This file is autogenerated by pip-compile with Python 3.13
# by the following command:
#
# pip-compile --output-file=requirements.lock pyproject.toml
#
attrs==26.1.0
# via
# jsonschema
# referencing
bleach==6.3.0
# via kaggle (pyproject.toml)
certifi==2026.2.25
# via requests
charset-normalizer==3.4.6
# via requests
fastjsonschema==2.21.2
# via nbformat
idna==3.11
# via requests
kagglesdk==0.1.17
jsonschema==4.26.0
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.19.1
# via kaggle (pyproject.toml)
packaging==26.0
kagglesdk==0.1.18
# via kaggle (pyproject.toml)
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
nbformat==5.10.4
# via jupytext
packaging==26.0
# via
# jupytext
# kaggle (pyproject.toml)
platformdirs==4.9.6
# via jupyter-core
protobuf==7.34.1
# via
# kaggle (pyproject.toml)
Expand All @@ -19,16 +52,30 @@ python-dateutil==2.9.0.post0
# via kaggle (pyproject.toml)
python-slugify==8.0.4
# via kaggle (pyproject.toml)
pyyaml==6.0.3
# via jupytext
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
requests==2.33.1
# via
# kaggle (pyproject.toml)
# kagglesdk
rpds-py==0.30.0
# via
# jsonschema
# referencing
six==1.17.0
# via python-dateutil
text-unidecode==1.3
# via python-slugify
tqdm==4.67.3
# via kaggle (pyproject.toml)
traitlets==5.14.3
# via
# jupyter-core
# nbformat
urllib3==2.6.3
# via
# kaggle (pyproject.toml)
Expand Down
271 changes: 271 additions & 0 deletions src/kaggle/api/kaggle_api_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@
from kagglesdk import get_access_token_from_env, KaggleClient, KaggleCredentials, KaggleEnv, KaggleOAuth # type: ignore[attr-defined]
from kagglesdk.admin.types.inbox_file_service import CreateInboxFileRequest
from kagglesdk.blobs.types.blob_api_service import ApiStartBlobUploadRequest, ApiStartBlobUploadResponse, ApiBlobType
from kagglesdk.benchmarks.types.benchmark_enums import BenchmarkTaskRunState, BenchmarkTaskVersionCreationState
from kagglesdk.benchmarks.types.benchmark_tasks_api_service import (
ApiCreateBenchmarkTaskRequest,
ApiGetBenchmarkTaskRequest,
ApiListBenchmarkTaskRunsRequest,
ApiBenchmarkTaskSlug,
ApiBatchScheduleBenchmarkTaskRunsRequest,
)
from kagglesdk.benchmarks.types.benchmarks_api_service import ApiListBenchmarkModelsRequest
from kagglesdk.competitions.types.competition_api_service import (
ApiListCompetitionsRequest,
ApiCreateCodeSubmissionRequest,
Expand Down Expand Up @@ -5429,6 +5438,268 @@ def _check_response_version(self, response: Response):
def get_response_processor(self):
return self._check_response_version

# ---- Benchmarks CLI ----

# Run states in which a benchmark task run is finished (either successfully
# or with an error); polling stops once every run is in one of these states.
_TERMINAL_RUN_STATES = {
    BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED,
    BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_ERRORED,
}

# Creation states meaning the task version is still being built server-side;
# pushes are rejected while the task is in one of these states.
_PENDING_CREATION_STATES = {
    BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED,
    BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING,
}

@staticmethod
def _make_task_slug(task: str) -> ApiBenchmarkTaskSlug:
    """Wrap a plain task-name string in an ApiBenchmarkTaskSlug message."""
    task_slug = ApiBenchmarkTaskSlug()
    task_slug.task_slug = task
    return task_slug

@staticmethod
def _normalize_model_list(model) -> list:
"""Normalize a model argument (str, list, or None) into a list."""
if isinstance(model, list):
return model
return [model] if model else []

@staticmethod
def _paginate(fetch_page, get_items):
"""Exhaust a paginated API, returning all items."""
items = []
page_token = ""
while True:
response = fetch_page(page_token)
items.extend(get_items(response))
page_token = response.next_page_token or ""
if not page_token:
break
return items

def _get_task_names_from_file(self, file_content: str) -> List[str]:
"""Extract task names from a Python file."""
import ast

task_names = []
try:
tree = ast.parse(file_content)
except SyntaxError:
return []

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue

for decorator in node.decorator_list:
func = decorator.func if isinstance(decorator, ast.Call) else decorator

if not (
(isinstance(func, ast.Name) and func.id == "task")
or (isinstance(func, ast.Attribute) and func.attr == "task")
):
continue

name = None
if isinstance(decorator, ast.Call):
name = next(
(
k.value.value
for k in decorator.keywords
if k.arg == "name" and isinstance(k.value, ast.Constant)
),
None,
)

task_names.append(name if name else node.name.title().replace("_", " "))

return task_names

def _get_benchmark_task(self, task: str, kaggle):
    """Fetch benchmark task details for *task* from the Kaggle API."""
    req = ApiGetBenchmarkTaskRequest()
    req.slug = self._make_task_slug(task)
    client = kaggle.benchmarks.benchmark_tasks_api_client
    return client.get_benchmark_task(req)

def _validate_task_in_file(self, task: str, file: str, file_content: str):
    """Raise ValueError unless *task* matches a @task name defined in *file*."""
    found = self._get_task_names_from_file(file_content)
    if not found:
        raise ValueError(f"No @task decorators found in file {file}. The file must define at least one task.")
    if task not in found:
        raise ValueError(f"Task '{task}' not found in file {file}. Found tasks: {', '.join(found)}")

def benchmarks_tasks_push_cli(self, task, file):
    """Create a benchmark task on Kaggle from a local percent-format .py file.

    Args:
        task: Name of the task to push. A file may define several @task
            functions, so the caller must name the one to create; it is
            validated against the decorators found in *file*.
        file: Path to a Python file using jupytext percent-format cells.

    Raises:
        ValueError: if the file is missing or not a .py file, the task is
            not defined in it, the task is mid-creation server-side, or the
            server reports an error on creation.
    """
    if not os.path.isfile(file):
        raise ValueError(f"File {file} does not exist")
    if not file.endswith(".py"):
        raise ValueError(f"File {file} must be a .py file")

    # Read explicitly as UTF-8 so pushes behave identically on every
    # platform (the default locale encoding differs, notably on Windows).
    with open(file, encoding="utf-8") as f:
        content = f.read()

    self._validate_task_in_file(task, file, content)

    # Convert .py file with percent delimiters to .ipynb
    import jupytext

    notebook = jupytext.reads(content, fmt="py:percent")
    # Add kernelspec metadata so papermill can execute it on the server
    notebook.metadata["kernelspec"] = {
        "display_name": "Python 3",
        "language": "python",
        "name": "python3",
    }
    notebook_content = jupytext.writes(notebook, fmt="ipynb")

    with self.build_kaggle_client() as kaggle:
        try:
            task_info = self._get_benchmark_task(task, kaggle)
            if task_info.creation_state in self._PENDING_CREATION_STATES:
                raise ValueError(f"Task '{task}' is currently being created (pending). Cannot push now.")
        except HTTPError as e:
            # 404 means the task does not exist yet — the normal first-push
            # case. HTTPError.response can be None, so guard before reading
            # status_code; anything other than 404 is a real failure.
            if e.response is None or e.response.status_code != 404:
                raise

        request = ApiCreateBenchmarkTaskRequest()
        request.slug = task
        request.text = notebook_content

        response = kaggle.benchmarks.benchmark_tasks_api_client.create_benchmark_task(request)
        error = getattr(response, "error_message", None) or getattr(response, "errorMessage", None)
        if error:
            raise ValueError(f"Failed to push task: {error}")
        print(f"Task '{task}' pushed.")
        # Guard against a missing URL; server may return a site-relative path.
        url = response.url or ""
        if url.startswith("/"):
            url = "https://www.kaggle.com" + url
        print(f"Task URL: {url}")

def _select_models_interactively(self, kaggle, page_size=20):
    """Prompt the user to pick benchmark models from a paginated list.

    Args:
        kaggle: An open Kaggle API client.
        page_size: Number of models displayed per interactive page.

    Returns:
        The list of selected model version slugs.

    Raises:
        ValueError: if no models exist or the user's selection is invalid.
    """

    def _fetch_models(page_token):
        req = ApiListBenchmarkModelsRequest()
        if page_token:
            req.page_token = page_token
        return kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(req)

    available = self._paginate(_fetch_models, lambda r: r.benchmark_models)
    if not available:
        raise ValueError("No benchmark models available. Cannot schedule runs.")

    total = len(available)
    total_pages = -(-total // page_size)  # ceiling division
    current_page = 0

    print(f"No model specified. {total} model(s) available:")
    while True:
        start = current_page * page_size
        for i, m in enumerate(available[start : start + page_size], start=start + 1):
            print(f"  {i}. {m.version.slug} ({m.display_name})")

        nav_hints = []
        if total_pages > 1:
            print(f"  [Page {current_page + 1}/{total_pages}]")
        if current_page < total_pages - 1:
            nav_hints.append("'n'=next")
        if current_page > 0:
            nav_hints.append("'p'=prev")

        prompt_parts = ["Enter model numbers (comma-separated)", "'all'"]
        if nav_hints:
            prompt_parts.extend(nav_hints)
        selection = input(", ".join(prompt_parts) + ": ").strip().lower()

        if selection == "n" and current_page < total_pages - 1:
            current_page += 1
        elif selection == "p" and current_page > 0:
            current_page -= 1
        elif selection == "all":
            return [m.version.slug for m in available]
        else:
            try:
                indices = [int(s) for s in selection.split(",")]
            except ValueError:
                raise ValueError(f"Invalid selection: {selection}")
            # Validate the range explicitly: Python's negative indexing
            # would otherwise silently accept entries like 0 or -1 and
            # select the wrong model instead of rejecting the input.
            if not all(1 <= i <= total for i in indices):
                raise ValueError(f"Invalid selection: {selection}")
            return [available[i - 1].version.slug for i in indices]

def _poll_runs(self, kaggle, task_slug_obj, models, wait, poll_interval):
    """Poll benchmark run status until every run is terminal or *wait* expires.

    Args:
        kaggle: An open Kaggle API client.
        task_slug_obj: ApiBenchmarkTaskSlug identifying the task.
        models: Model version slugs to filter on (empty means all runs).
        wait: Timeout in seconds; values <= 0 poll without a deadline.
        poll_interval: Seconds to sleep between polls.

    NOTE(review): if the server reports zero runs this keeps polling until
    the timeout (or forever when wait <= 0) — presumably runs appear shortly
    after scheduling; confirm against the API's scheduling latency.
    """

    def _fetch_runs(page_token):
        req = ApiListBenchmarkTaskRunsRequest()
        req.task_slug = task_slug_obj
        if models:
            req.model_version_slugs = models
        if page_token:
            req.page_token = page_token
        return kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(req)

    print("Waiting for run(s) to complete...")
    started_at = time.time()
    while True:
        runs = self._paginate(_fetch_runs, lambda r: r.runs)

        finished = bool(runs) and all(r.state in self._TERMINAL_RUN_STATES for r in runs)
        if finished:
            print("All runs completed:")
            for run in runs:
                ok = run.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED
                print(f"  {run.model_version_slug}: {'COMPLETED' if ok else 'ERRORED'}")
            return

        in_progress = sum(1 for run in runs if run.state not in self._TERMINAL_RUN_STATES)
        print(f"  {in_progress} run(s) still in progress...")

        if wait > 0 and (time.time() - started_at) > wait:
            print(f"Timed out waiting for runs after {wait} seconds.")
            return

        time.sleep(poll_interval)

def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10):
    """Schedule benchmark task run(s) and optionally wait for them to finish.

    Args:
        task: Slug of the benchmark task to run.
        model: A model version slug, a list of slugs, or None to choose
            interactively from the available benchmark models.
        wait: If not None, poll until runs finish or *wait* seconds elapse.
        poll_interval: Seconds between status polls while waiting.

    Raises:
        ValueError: if the task is not in a completed (runnable) state.
    """
    models = self._normalize_model_list(model)
    task_slug_obj = self._make_task_slug(task)

    with self.build_kaggle_client() as kaggle:
        # The task must exist and have finished creation before it can run.
        task_info = self._get_benchmark_task(task, kaggle)
        state = task_info.creation_state
        if state != BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED:
            error_msg = f"Task '{task}' is not ready to run (status: {task_info.creation_state})."
            if state == BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_ERRORED:
                error_msg += f" Task Info: {task_info}."
            error_msg += " Only completed tasks can be run."
            raise ValueError(error_msg)

        if not models:
            models = self._select_models_interactively(kaggle)
            print(f"Selected models: {models}")

        request = ApiBatchScheduleBenchmarkTaskRunsRequest()
        request.task_slugs = [task_slug_obj]
        request.model_version_slugs = models

        response = kaggle.benchmarks.benchmark_tasks_api_client.batch_schedule_benchmark_task_runs(request)
        print(f"Submitted run(s) for task '{task}'.")
        # NOTE(review): assumes response.results is ordered like *models* —
        # confirm against the batch-schedule API contract.
        for model_slug, result in zip(models, response.results):
            if result.run_scheduled:
                print(f"  {model_slug}: Scheduled")
            else:
                print(f"  {model_slug}: Skipped ({result.run_skipped_reason})")

        if wait is not None:
            self._poll_runs(kaggle, task_slug_obj, models, wait, poll_interval)


class TqdmBufferedReader(io.BufferedReader):

Expand Down
Loading
Loading