diff --git a/pyproject.toml b/pyproject.toml index 793d0125..062c41c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ keywords = ["Kaggle", "API"] requires-python = ">= 3.11" dependencies = [ "bleach", - "kagglesdk >= 0.1.17, < 1.0", # sync with kagglehub + "kagglesdk >= 0.1.18, < 1.0", # sync with kagglehub "python-slugify", "requests", "python-dateutil", @@ -32,6 +32,7 @@ dependencies = [ "urllib3 >= 1.15.1", "packaging", "protobuf", + "jupytext", ] [project.scripts] diff --git a/requirements-test.lock b/requirements-test.lock index ef288660..7591b959 100644 --- a/requirements-test.lock +++ b/requirements-test.lock @@ -1,4 +1,9 @@ -# Regenerate with: pip-compile requirements-test.in -o requirements-test.lock +# +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: +# +# pip-compile --output-file=requirements-test.lock requirements-test.in +# iniconfig==2.3.0 # via pytest packaging==26.0 diff --git a/requirements.lock b/requirements.lock index b15a458c..be56b94f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -1,16 +1,49 @@ -# Regenerate with: pip-compile pyproject.toml -o requirements.lock +# +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: +# +# pip-compile --output-file=requirements.lock pyproject.toml +# +attrs==26.1.0 + # via + # jsonschema + # referencing bleach==6.3.0 # via kaggle (pyproject.toml) certifi==2026.2.25 # via requests charset-normalizer==3.4.6 # via requests +fastjsonschema==2.21.2 + # via nbformat idna==3.11 # via requests -kagglesdk==0.1.17 +jsonschema==4.26.0 + # via nbformat +jsonschema-specifications==2025.9.1 + # via jsonschema +jupyter-core==5.9.1 + # via nbformat +jupytext==1.19.1 # via kaggle (pyproject.toml) -packaging==26.0 +kagglesdk==0.1.18 # via kaggle (pyproject.toml) +markdown-it-py==4.0.0 + # via + # jupytext + # mdit-py-plugins +mdit-py-plugins==0.5.0 + # via jupytext +mdurl==0.1.2 + # via markdown-it-py +nbformat==5.10.4 + # via jupytext +packaging==26.0 + # via + # jupytext + # kaggle (pyproject.toml) +platformdirs==4.9.6 + # via jupyter-core protobuf==7.34.1 # via # kaggle (pyproject.toml) @@ -19,16 +52,30 @@ python-dateutil==2.9.0.post0 # via kaggle (pyproject.toml) python-slugify==8.0.4 # via kaggle (pyproject.toml) +pyyaml==6.0.3 + # via jupytext +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications requests==2.33.1 # via # kaggle (pyproject.toml) # kagglesdk +rpds-py==0.30.0 + # via + # jsonschema + # referencing six==1.17.0 # via python-dateutil text-unidecode==1.3 # via python-slugify tqdm==4.67.3 # via kaggle (pyproject.toml) +traitlets==5.14.3 + # via + # jupyter-core + # nbformat urllib3==2.6.3 # via # kaggle (pyproject.toml) diff --git a/src/kaggle/api/kaggle_api_extended.py b/src/kaggle/api/kaggle_api_extended.py index 90e0f7b8..78124f86 100644 --- a/src/kaggle/api/kaggle_api_extended.py +++ b/src/kaggle/api/kaggle_api_extended.py @@ -57,6 +57,15 @@ from kagglesdk import get_access_token_from_env, KaggleClient, KaggleCredentials, KaggleEnv, KaggleOAuth # type: ignore[attr-defined] from kagglesdk.admin.types.inbox_file_service import CreateInboxFileRequest from kagglesdk.blobs.types.blob_api_service import ApiStartBlobUploadRequest, ApiStartBlobUploadResponse, ApiBlobType +from kagglesdk.benchmarks.types.benchmark_enums import BenchmarkTaskRunState, BenchmarkTaskVersionCreationState +from kagglesdk.benchmarks.types.benchmark_tasks_api_service import ( + ApiCreateBenchmarkTaskRequest, + ApiGetBenchmarkTaskRequest, + ApiListBenchmarkTaskRunsRequest, + ApiBenchmarkTaskSlug, + ApiBatchScheduleBenchmarkTaskRunsRequest, +) +from kagglesdk.benchmarks.types.benchmarks_api_service import ApiListBenchmarkModelsRequest from kagglesdk.competitions.types.competition_api_service import ( ApiListCompetitionsRequest, ApiCreateCodeSubmissionRequest, @@ -5429,6 +5438,268 @@ def _check_response_version(self, response: Response): def get_response_processor(self): return self._check_response_version + # ---- Benchmarks CLI ---- + + _TERMINAL_RUN_STATES = { + BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED, + BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_ERRORED, + } + + _PENDING_CREATION_STATES = { + BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED, + BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING, + } + + @staticmethod + def _make_task_slug(task: str) -> ApiBenchmarkTaskSlug: + """Build an ApiBenchmarkTaskSlug from a task name string.""" + slug = ApiBenchmarkTaskSlug() + slug.task_slug = task + return slug + + @staticmethod + def _normalize_model_list(model) -> list: + """Normalize a model argument (str, list, or None) into a list.""" + if isinstance(model, list): + return model + return [model] if model else [] + + @staticmethod + def _paginate(fetch_page, get_items): + """Exhaust a paginated API, returning all items.""" + items = [] + page_token = "" + while True: + response = fetch_page(page_token) + items.extend(get_items(response)) + page_token = response.next_page_token or "" + if not page_token: + break + return items + + def _get_task_names_from_file(self, file_content: str) -> List[str]: + """Extract task names from a Python file.""" + import ast + + task_names = [] + try: + tree = ast.parse(file_content) + except SyntaxError: + return [] + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + + for decorator in node.decorator_list: + func = decorator.func if isinstance(decorator, ast.Call) else decorator + + if not ( + (isinstance(func, ast.Name) and func.id == "task") + or (isinstance(func, ast.Attribute) and func.attr == "task") + ): + continue + + name = None + if isinstance(decorator, ast.Call): + name = next( + ( + k.value.value + for k in decorator.keywords + if k.arg == "name" and isinstance(k.value, ast.Constant) + ), + None, + ) + + task_names.append(name if name else node.name.title().replace("_", " ")) + + return task_names + + def _get_benchmark_task(self, task: str, kaggle): + """Get benchmark task details from the server.""" + request = ApiGetBenchmarkTaskRequest() + request.slug = self._make_task_slug(task) + return kaggle.benchmarks.benchmark_tasks_api_client.get_benchmark_task(request) + + def _validate_task_in_file(self, task: str, file: str, file_content: str): + """Validate that the task name is defined in the Python file.""" + task_names = self._get_task_names_from_file(file_content) + if not task_names: + raise ValueError(f"No @task decorators found in file {file}. The file must define at least one task.") + if task not in task_names: + raise ValueError(f"Task '{task}' not found in file {file}. Found tasks: {', '.join(task_names)}") + + def benchmarks_tasks_push_cli(self, task, file): + if not os.path.isfile(file): + raise ValueError(f"File {file} does not exist") + if not file.endswith(".py"): + raise ValueError(f"File {file} must be a .py file") + + with open(file) as f: + content = f.read() + + self._validate_task_in_file(task, file, content) + + # Convert .py file with percent delimiters to .ipynb + import jupytext + + notebook = jupytext.reads(content, fmt="py:percent") + # Add kernelspec metadata so papermill can execute it on the server + notebook.metadata["kernelspec"] = { + "display_name": "Python 3", + "language": "python", + "name": "python3", + } + notebook_content = jupytext.writes(notebook, fmt="ipynb") + + with self.build_kaggle_client() as kaggle: + try: + task_info = self._get_benchmark_task(task, kaggle) + if task_info.creation_state in self._PENDING_CREATION_STATES: + raise ValueError(f"Task '{task}' is currently being created (pending). Cannot push now.") + except HTTPError as e: + if e.response.status_code != 404: + raise + + request = ApiCreateBenchmarkTaskRequest() + request.slug = task + request.text = notebook_content + + response = kaggle.benchmarks.benchmark_tasks_api_client.create_benchmark_task(request) + error = getattr(response, "error_message", None) or getattr(response, "errorMessage", None) + if error: + raise ValueError(f"Failed to push task: {error}") + print(f"Task '{task}' pushed.") + url = response.url + if url.startswith("/"): + url = "https://www.kaggle.com" + url + print(f"Task URL: {url}") + + def _select_models_interactively(self, kaggle, page_size=20): + """Prompt the user to pick benchmark models from a paginated list.""" + + def _fetch_models(page_token): + req = ApiListBenchmarkModelsRequest() + if page_token: + req.page_token = page_token + return kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(req) + + available = self._paginate(_fetch_models, lambda r: r.benchmark_models) + if not available: + raise ValueError("No benchmark models available. Cannot schedule runs.") + + total = len(available) + total_pages = -(-total // page_size) # ceiling division + current_page = 0 + + print(f"No model specified. {total} model(s) available:") + while True: + start = current_page * page_size + for i, m in enumerate(available[start : start + page_size], start=start + 1): + print(f" {i}. {m.version.slug} ({m.display_name})") + + nav_hints = [] + if total_pages > 1: + print(f" [Page {current_page + 1}/{total_pages}]") + if current_page < total_pages - 1: + nav_hints.append("'n'=next") + if current_page > 0: + nav_hints.append("'p'=prev") + + prompt_parts = ["Enter model numbers (comma-separated)", "'all'"] + if nav_hints: + prompt_parts.extend(nav_hints) + selection = input(", ".join(prompt_parts) + ": ").strip().lower() + + if selection == "n" and current_page < total_pages - 1: + current_page += 1 + elif selection == "p" and current_page > 0: + current_page -= 1 + elif selection == "all": + return [m.version.slug for m in available] + else: + try: + indices = [int(s) for s in selection.split(",")] + return [available[i - 1].version.slug for i in indices] + except (ValueError, IndexError): + raise ValueError(f"Invalid selection: {selection}") + + def _poll_runs(self, kaggle, task_slug_obj, models, wait, poll_interval): + """Poll run status until all runs are terminal or timeout.""" + + def _fetch_runs(page_token): + req = ApiListBenchmarkTaskRunsRequest() + req.task_slug = task_slug_obj + if models: + req.model_version_slugs = models + if page_token: + req.page_token = page_token + return kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(req) + + print("Waiting for run(s) to complete...") + start_time = time.time() + while True: + all_runs = self._paginate(_fetch_runs, lambda r: r.runs) + + if all_runs and all(r.state in self._TERMINAL_RUN_STATES for r in all_runs): + print("All runs completed:") + for r in all_runs: + label = ( + "COMPLETED" + if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED + else "ERRORED" + ) + print(f" {r.model_version_slug}: {label}") + return + + pending = sum(1 for r in all_runs if r.state not in self._TERMINAL_RUN_STATES) + print(f" {pending} run(s) still in progress...") + + if wait > 0 and (time.time() - start_time) > wait: + print(f"Timed out waiting for runs after {wait} seconds.") + return + + time.sleep(poll_interval) + + def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10): + models = self._normalize_model_list(model) + task_slug_obj = self._make_task_slug(task) + + with self.build_kaggle_client() as kaggle: + # Verify the task exists and is ready to run + task_info = self._get_benchmark_task(task, kaggle) + if ( + task_info.creation_state + != BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED + ): + error_msg = f"Task '{task}' is not ready to run (status: {task_info.creation_state})." + if ( + task_info.creation_state + == BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_ERRORED + ): + error_msg += f" Task Info: {task_info}." + error_msg += " Only completed tasks can be run." + raise ValueError(error_msg) + + if not models: + models = self._select_models_interactively(kaggle) + print(f"Selected models: {models}") + + request = ApiBatchScheduleBenchmarkTaskRunsRequest() + request.task_slugs = [task_slug_obj] + request.model_version_slugs = models + + response = kaggle.benchmarks.benchmark_tasks_api_client.batch_schedule_benchmark_task_runs(request) + print(f"Submitted run(s) for task '{task}'.") + for model_slug, res in zip(models, response.results): + if res.run_scheduled: + print(f" {model_slug}: Scheduled") + else: + print(f" {model_slug}: Skipped ({res.run_skipped_reason})") + + if wait is not None: + self._poll_runs(kaggle, task_slug_obj, models, wait, poll_interval) + class TqdmBufferedReader(io.BufferedReader): diff --git a/src/kaggle/cli.py b/src/kaggle/cli.py index 754b4237..5a4567c0 100644 --- a/src/kaggle/cli.py +++ b/src/kaggle/cli.py @@ -54,6 +54,7 @@ def main() -> None: parse_kernels(subparsers) parse_models(subparsers) parse_files(subparsers) + parse_benchmarks(subparsers) parse_config(subparsers) if api.enable_oauth: parse_auth(subparsers) @@ -979,6 +980,48 @@ def parse_files(subparsers) -> None: parser_files_upload.set_defaults(func=api.files_upload_cli) +def parse_benchmarks(subparsers) -> None: + parser_benchmarks = subparsers.add_parser( + "benchmarks", formatter_class=argparse.RawTextHelpFormatter, help=Help.group_benchmarks, aliases=["b"] + ) + subparsers_benchmarks = parser_benchmarks.add_subparsers(title="commands", dest="command") + subparsers_benchmarks.required = True + subparsers_benchmarks.choices = Help.benchmarks_choices + + parse_benchmark_tasks(subparsers_benchmarks) + + +def parse_benchmark_tasks(subparsers) -> None: + parser_tasks = subparsers.add_parser( + "tasks", formatter_class=argparse.RawTextHelpFormatter, help=Help.group_benchmarks_tasks, aliases=["t"] + ) + subparsers_tasks = parser_tasks.add_subparsers(title="commands", dest="command") + subparsers_tasks.required = True + subparsers_tasks.choices = Help.benchmarks_tasks_choices + + # push + parser_push = subparsers_tasks.add_parser( + "push", formatter_class=argparse.RawTextHelpFormatter, help=Help.command_benchmarks_tasks_push + ) + parser_push_optional = parser_push._action_groups.pop() + parser_push_optional.add_argument("task", help=Help.param_benchmarks_task) + parser_push_optional.add_argument("-f", "--file", dest="file", required=True, help=Help.param_benchmarks_file) + parser_push._action_groups.append(parser_push_optional) + parser_push.set_defaults(func=api.benchmarks_tasks_push_cli) + + # run + parser_run = subparsers_tasks.add_parser( + "run", formatter_class=argparse.RawTextHelpFormatter, help=Help.command_benchmarks_tasks_run + ) + parser_run_optional = parser_run._action_groups.pop() + parser_run_optional.add_argument("task", help=Help.param_benchmarks_task) + parser_run_optional.add_argument("-m", "--model", dest="model", nargs="+", required=False, help=Help.param_benchmarks_model) + parser_run_optional.add_argument("--wait", dest="wait", type=int, nargs="?", const=0, default=None, required=False, help=Help.param_benchmarks_wait) + parser_run_optional.add_argument("--poll-interval", dest="poll_interval", type=int, default=10, required=False, help=Help.param_benchmarks_poll_interval) + parser_run._action_groups.append(parser_run_optional) + parser_run.set_defaults(func=api.benchmarks_tasks_run_cli) + + def parse_config(subparsers) -> None: parser_config = subparsers.add_parser( "config", formatter_class=argparse.RawTextHelpFormatter, help=Help.group_config @@ -1068,6 +1111,8 @@ class Help(object): "m", "files", "f", + "benchmarks", + "b", "config", "auth", ] @@ -1078,6 +1123,8 @@ class Help(object): model_instances_choices = ["versions", "v", "get", "files", "list", "init", "create", "delete", "update"] model_instance_versions_choices = ["init", "create", "download", "delete", "files", "list"] files_choices = ["upload"] + benchmarks_choices = ["tasks", "t"] + benchmarks_tasks_choices = ["push", "run"] config_choices = ["view", "set", "unset"] auth_choices = ["login", "print-access-token", "revoke"] @@ -1094,6 +1141,8 @@ class Help(object): + ", ".join(model_instances_choices) + "}\nmodels variations versions {" + ", ".join(model_instance_versions_choices) + + "}\nbenchmarks {" + + ", ".join(benchmarks_choices) + "}\nconfig {" + ", ".join(config_choices) + "}" @@ -1108,6 +1157,8 @@ class Help(object): group_model_instances = "Commands related to Kaggle model variations" group_model_instance_versions = "Commands related to Kaggle model variations versions" group_files = "Commands related files" + group_benchmarks = "Commands related to Kaggle benchmarks" + group_benchmarks_tasks = "Commands related to benchmark tasks" group_config = "Configuration settings" group_auth = "Commands related to authentication" @@ -1149,6 +1200,10 @@ class Help(object): command_models_delete = "Delete a model" command_models_update = "Update a model" + # Benchmarks commands + command_benchmarks_tasks_push = "Create or update a task from a Python source file" + command_benchmarks_tasks_run = "Run a task against model(s)" + # Files commands command_files_upload = "Upload files" @@ -1365,6 +1420,13 @@ class Help(object): param_files_upload_no_compress = "Whether to compress directories (zip) or not (tar)" param_files_upload_no_resume = "Whether to skip resumable uploads." + # Benchmarks params + param_benchmarks_task = "Task name" + param_benchmarks_file = "Python source file containing the task definition" + param_benchmarks_model = "Model slug(s) to run the task against" + param_benchmarks_wait = "Wait for runs to complete (seconds). 0 means wait indefinitely." + param_benchmarks_poll_interval = "Polling interval in seconds when waiting for runs" + # Config params param_config_name = "Name of the configuration parameter\n(one of " "competition, path, proxy)" param_config_value = ( diff --git a/src/kaggle/test/conftest.py b/src/kaggle/test/conftest.py new file mode 100644 index 00000000..9595981d --- /dev/null +++ b/src/kaggle/test/conftest.py @@ -0,0 +1,18 @@ +"""Shared test configuration for kaggle CLI tests. + +Must be set before importing kaggle, which calls api.authenticate() at +module level. Fake legacy credentials keep authenticate() off the network; +removing KAGGLE_API_TOKEN prevents _introspect_token() from being called. +We also patch get_access_token_from_env so the ~/.kaggle/access_token file +doesn't trigger token introspection. +""" + +import os +from unittest.mock import patch + +os.environ.pop("KAGGLE_API_TOKEN", None) +os.environ["KAGGLE_USERNAME"] = "testuser" +os.environ["KAGGLE_KEY"] = "testkey" + +with patch("kagglesdk.get_access_token_from_env", return_value=(None, None)): + import kaggle # noqa: F401 — triggers authenticate() diff --git a/src/kaggle/test/test_benchmarks_cli.py b/src/kaggle/test/test_benchmarks_cli.py new file mode 100644 index 00000000..ddf3a0df --- /dev/null +++ b/src/kaggle/test/test_benchmarks_cli.py @@ -0,0 +1,405 @@ +"""Tests for ``kaggle benchmarks tasks`` CLI commands. + +Organized by user journey: + TestPush – ``kaggle benchmarks tasks push -f `` + TestRun – ``kaggle benchmarks tasks run [-m ...] [--wait]`` + TestCliArgParsing – argparse wiring verification + +Uses pure pytest (capsys, tmp_path) instead of unittest.TestCase so that +stdout capture and temp-file creation are handled by built-in fixtures +rather than manual patching. +""" + +import argparse +from unittest.mock import patch, MagicMock + +import pytest +from requests.exceptions import HTTPError + +from kaggle.api.kaggle_api_extended import KaggleApi +from kagglesdk.benchmarks.types.benchmark_enums import ( + BenchmarkTaskVersionCreationState, + BenchmarkTaskRunState, +) + +# Short aliases for the verbose enum members used throughout the tests. +QUEUED = BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED +RUNNING = BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING +COMPLETED = BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED + +RUN_RUNNING = BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_RUNNING +RUN_COMPLETED = BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED +RUN_ERRORED = BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_ERRORED + +DEFAULT_TASK_CONTENT = '@task(name="my-task")\ndef evaluate(): pass\n' + + +# ---- Fixtures & helpers ---- + + +@pytest.fixture +def api(): + """A KaggleApi with mocked auth and client — no network calls.""" + a = KaggleApi() + a.authenticate = MagicMock() + mock_client = MagicMock() + a.build_kaggle_client = MagicMock() + a.build_kaggle_client.return_value.__enter__.return_value = mock_client + # Expose internals so helpers can wire up responses. + a._mock_client = mock_client + a._mock_benchmarks = mock_client.benchmarks.benchmark_tasks_api_client + return a + + +def _write_task_file(tmp_path, content=DEFAULT_TASK_CONTENT, name="task.py"): + """Write *content* to a .py file under *tmp_path* and return its path str.""" + p = tmp_path / name + p.write_text(content) + return str(p) + + +def _mock_jupytext(): + """Return ``(mock_jupytext_module, context_manager)``.""" + jt = MagicMock() + notebook = MagicMock() + notebook.metadata = {} + jt.reads.return_value = notebook + jt.writes.return_value = '{"cells": []}' + return jt, patch.dict("sys.modules", {"jupytext": jt}) + + +def _push(api, task, filepath): + """Call ``benchmarks_tasks_push_cli`` with jupytext mocked. + + Returns the mock jupytext module so callers can assert on calls. + """ + jt, ctx = _mock_jupytext() + with ctx: + api.benchmarks_tasks_push_cli(task, filepath) + return jt + + +def _make_task(slug="my-task", state=COMPLETED): + t = MagicMock() + t.slug.task_slug = slug + t.creation_state = state + t.create_time = "2026-04-06" + return t + + +def _make_run_result(scheduled=True, skipped_reason=None): + r = MagicMock() + r.run_scheduled = scheduled + r.benchmark_task_version_id = 1 + r.benchmark_model_version_id = 10 + r.run_skipped_reason = skipped_reason + return r + + +def _make_run(model="gemini-pro", state=RUN_COMPLETED): + r = MagicMock() + r.model_version_slug = model + r.state = state + return r + + +def _setup_create_response(api, task_slug="my-task"): + resp = MagicMock() + resp.slug.task_slug = task_slug + resp.url = f"https://kaggle.com/benchmarks/{task_slug}" + resp.error_message = None + resp.errorMessage = None + api._mock_benchmarks.create_benchmark_task.return_value = resp + + +def _setup_completed_task(api, slug="my-task"): + task = _make_task(slug=slug, state=COMPLETED) + api._mock_benchmarks.get_benchmark_task.return_value = task + + +def _setup_batch_schedule(api, results): + resp = MagicMock() + resp.results = results + api._mock_benchmarks.batch_schedule_benchmark_task_runs.return_value = resp + + +def _setup_available_models(api, slugs): + models = [] + for s in slugs: + m = MagicMock() + m.version.slug = s + m.display_name = s.title() + models.append(m) + resp = MagicMock() + resp.benchmark_models = models + resp.next_page_token = "" + api._mock_client.benchmarks.benchmarks_api_client.list_benchmark_models.return_value = resp + + +# ============================================================ +# Push Journey +# ============================================================ + + +class TestPush: + """``kaggle benchmarks tasks push -f ``""" + + # -- Input validation (before any server call) -- + + @pytest.mark.parametrize( + "task, filename, content, expected_error", + [ + ("my-task", None, None, "does not exist"), + ("my-task", "task.txt", "hello", "must be a .py"), + ("any-task", "task.py", "def f(): pass\n", "No @task decorators"), + ("wrong", "task.py", '@task(name="real")\ndef f(llm): pass\n', "not found"), + ("any-task", "task.py", "def broken(\n", "No @task decorators"), + ], + ids=["missing_file", "wrong_extension", "no_decorators", "wrong_name", "syntax_error"], + ) + def test_push_rejects_invalid_input(self, api, tmp_path, task, filename, content, expected_error): + if filename is None: + filepath = "/nonexistent/task.py" + else: + filepath = _write_task_file(tmp_path, content, name=filename) + with pytest.raises(ValueError, match=expected_error): + api.benchmarks_tasks_push_cli(task, filepath) + + # -- Happy path -- + + @pytest.mark.parametrize( + "content, task_name", + [ + ('@task(name="my-task")\ndef evaluate(): pass\n', "my-task"), + ("@task\ndef my_task(llm): pass\n", "My Task"), + ("@task\nasync def my_task(llm): pass\n", "My Task"), + ], + ids=["explicit_name", "title_cased", "async_function"], + ) + def test_push_creates_task(self, api, tmp_path, capsys, content, task_name): + """Push converts .py → ipynb via jupytext and creates the task.""" + filepath = _write_task_file(tmp_path, content) + _setup_create_response(api, task_name) + + jt = _push(api, task_name, filepath) + + # Verify jupytext conversion happened + jt.reads.assert_called_once() + jt.writes.assert_called_once() + request = api._mock_benchmarks.create_benchmark_task.call_args[0][0] + assert request.text == '{"cells": []}' + + assert f"Task '{task_name}' pushed." in capsys.readouterr().out + + def test_push_creates_new_task_on_404(self, api, tmp_path, capsys): + """A 404 from get_benchmark_task means new task — still creates successfully.""" + filepath = _write_task_file(tmp_path) + api._mock_benchmarks.get_benchmark_task.side_effect = HTTPError(response=MagicMock(status_code=404)) + _setup_create_response(api) + _push(api, "my-task", filepath) + assert "Task 'my-task' pushed." in capsys.readouterr().out + + # -- Server edge cases -- + + @pytest.mark.parametrize("state", [QUEUED, RUNNING], ids=["queued", "running"]) + def test_push_rejects_pending_task(self, api, tmp_path, state): + """Push rejects when the task version is still being created.""" + filepath = _write_task_file(tmp_path) + api._mock_benchmarks.get_benchmark_task.return_value = _make_task(state=state) + with pytest.raises(ValueError, match="currently being created"): + _push(api, "my-task", filepath) + + def test_push_propagates_server_error(self, api, tmp_path): + """Non-404 HTTP errors (e.g. 500) are re-raised, not swallowed.""" + filepath = _write_task_file(tmp_path) + api._mock_benchmarks.get_benchmark_task.side_effect = HTTPError(response=MagicMock(status_code=500)) + with pytest.raises(HTTPError): + _push(api, "my-task", filepath) + + def test_push_handles_api_error(self, api, tmp_path): + """Push raises ValueError when response contains error_message.""" + filepath = _write_task_file(tmp_path) + _setup_completed_task(api) + + resp = MagicMock() + resp.error_message = "Some backend error" + api._mock_benchmarks.create_benchmark_task.return_value = resp + + with pytest.raises(ValueError, match="Failed to push task: Some backend error"): + _push(api, "my-task", filepath) + + +# ============================================================ +# Run Journey +# ============================================================ + + +class TestRun: + """``kaggle benchmarks tasks run [-m ...] [--wait]``""" + + # -- Pre-conditions -- + + def test_run_rejects_non_completed_task(self, api): + """Run errors when the task creation state is not COMPLETED.""" + api._mock_benchmarks.get_benchmark_task.return_value = _make_task(state=QUEUED) + with pytest.raises(ValueError, match="not ready to run"): + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) + api._mock_benchmarks.batch_schedule_benchmark_task_runs.assert_not_called() + + # -- Model scheduling -- + + @pytest.mark.parametrize( + "models", + [["gemini-pro"], ["gemini-pro", "gemma-2b"]], + ids=["single_model", "multiple_models"], + ) + def test_run_schedules_models(self, api, capsys, models): + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result() for _ in models]) + api.benchmarks_tasks_run_cli("my-task", models) + output = capsys.readouterr().out + assert "Submitted run(s) for task 'my-task'" in output + for m in models: + assert f"{m}: Scheduled" in output + + def test_run_reports_skipped_with_reason(self, api, capsys): + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result(scheduled=False, skipped_reason="Already running")]) + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) + output = capsys.readouterr().out + assert "gemini-pro: Skipped" in output + assert "Already running" in output + + # -- Interactive model selection -- + + def test_run_prompts_model_selection(self, api): + """No model specified → user picks from a numbered list.""" + _setup_completed_task(api) + _setup_available_models(api, ["gemini-pro", "gemma-2b"]) + _setup_batch_schedule(api, [_make_run_result()]) + with patch("builtins.input", return_value="1"): + api.benchmarks_tasks_run_cli("my-task") + request = api._mock_benchmarks.batch_schedule_benchmark_task_runs.call_args[0][0] + assert request.model_version_slugs == ["gemini-pro"] + + def test_run_selects_all_models(self, api): + _setup_completed_task(api) + _setup_available_models(api, ["gemini-pro", "gemma-2b"]) + _setup_batch_schedule(api, []) + with patch("builtins.input", return_value="all"): + api.benchmarks_tasks_run_cli("my-task") + request = api._mock_benchmarks.batch_schedule_benchmark_task_runs.call_args[0][0] + assert request.model_version_slugs == ["gemini-pro", "gemma-2b"] + + def test_run_rejects_empty_model_list(self, api): + """No models available on server → ValueError.""" + _setup_completed_task(api) + _setup_available_models(api, []) + with pytest.raises(ValueError, match="No benchmark models available"): + api.benchmarks_tasks_run_cli("my-task") + + def test_run_rejects_invalid_model_selection(self, api): + """Bad input during interactive model selection → ValueError.""" + _setup_completed_task(api) + _setup_available_models(api, ["gemini-pro"]) + with patch("builtins.input", return_value="abc"): + with pytest.raises(ValueError, match="Invalid selection"): + api.benchmarks_tasks_run_cli("my-task") + + # -- Wait / polling -- + + def test_run_wait_polls_until_completion(self, api, capsys): + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result()]) + api._mock_benchmarks.list_benchmark_task_runs.side_effect = [ + MagicMock(runs=[_make_run(state=RUN_RUNNING)], next_page_token=""), + MagicMock(runs=[_make_run(state=RUN_COMPLETED)], next_page_token=""), + ] + with patch("time.sleep"): + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=0) + output = capsys.readouterr().out + assert "Waiting for run(s) to complete" in output + assert "All runs completed" in output + assert "gemini-pro: COMPLETED" in output + + def test_run_wait_times_out(self, api, capsys): + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result()]) + api._mock_benchmarks.list_benchmark_task_runs.return_value = MagicMock( + runs=[_make_run(state=RUN_RUNNING)], next_page_token="" + ) + with patch("time.sleep"), patch("time.time", side_effect=[1000, 1060]): + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=30) + output = capsys.readouterr().out + assert "Timed out waiting for runs after 30 seconds" in output + + def test_run_wait_shows_errored_runs(self, api, capsys): + """ERRORED runs display with ERRORED label.""" + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result()]) + api._mock_benchmarks.list_benchmark_task_runs.return_value = MagicMock( + runs=[_make_run(state=RUN_ERRORED)], next_page_token="" + ) + with patch("time.sleep"): + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=0) + assert "gemini-pro: ERRORED" in capsys.readouterr().out + + +# ============================================================ +# CLI Arg Parsing +# ============================================================ + + +class TestCliArgParsing: + """Tests that argparse wiring for ``kaggle benchmarks tasks`` is correct. + + Verifies argument names, aliases (b/t), required flags, and nargs + constraints are properly configured in cli.py. + """ + + def setup_method(self): + self.parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + subparsers = self.parser.add_subparsers(title="commands", dest="command") + subparsers.required = True + from kaggle.cli import parse_benchmarks + + parse_benchmarks(subparsers) + + def _parse(self, arg_string): + return self.parser.parse_args(arg_string.split()) + + @pytest.mark.parametrize( + "cmd, expected", + [ + ("benchmarks tasks push my-task -f ./task.py", {"task": "my-task", "file": "./task.py"}), + ("b t push my-task -f ./task.py", {"task": "my-task", "file": "./task.py"}), + ("benchmarks tasks run my-task", {"task": "my-task", "model": None, "wait": None}), + ("benchmarks tasks run my-task -m gemini-3 --wait", {"model": ["gemini-3"], "wait": 0}), + ( + "benchmarks tasks run my-task -m gemini-3 --wait 60", + {"model": ["gemini-3"], "wait": 60}, + ), + ( + "benchmarks tasks run my-task -m gemini-3 gpt-5 claude-4", + {"model": ["gemini-3", "gpt-5", "claude-4"]}, + ), + ("b t run my-task -m gemini-3", {"task": "my-task", "model": ["gemini-3"]}), + ], + ) + def test_parse_success(self, cmd, expected): + args = self._parse(cmd) + for key, val in expected.items(): + assert getattr(args, key) == val + + @pytest.mark.parametrize( + "cmd", + [ + "benchmarks tasks push my-task", + "benchmarks tasks run my-task -m", + ], + ) + def test_parse_error(self, cmd): + with pytest.raises(SystemExit): + self._parse(cmd)