From 46c5a9bd966edf252cff4dfb648e83654f44912c Mon Sep 17 00:00:00 2001 From: Li Ma Date: Wed, 8 Apr 2026 22:10:32 +0000 Subject: [PATCH 1/9] Extract push and run commands to new branch --- src/kaggle/api/kaggle_api_extended.py | 175 ++++++++++++++++++++++++++ src/kaggle/cli.py | 62 +++++++++ 2 files changed, 237 insertions(+) diff --git a/src/kaggle/api/kaggle_api_extended.py b/src/kaggle/api/kaggle_api_extended.py index d3952a64..05971076 100644 --- a/src/kaggle/api/kaggle_api_extended.py +++ b/src/kaggle/api/kaggle_api_extended.py @@ -55,6 +55,17 @@ from kagglesdk import get_access_token_from_env, KaggleClient, KaggleCredentials, KaggleEnv, KaggleOAuth # type: ignore[attr-defined] from kagglesdk.admin.types.inbox_file_service import CreateInboxFileRequest from kagglesdk.blobs.types.blob_api_service import ApiStartBlobUploadRequest, ApiStartBlobUploadResponse, ApiBlobType +from kagglesdk.benchmarks.types.benchmark_enums import BenchmarkTaskRunState, BenchmarkTaskVersionCreationState +from kagglesdk.benchmarks.types.benchmark_tasks_api_service import ( + ApiCreateBenchmarkTaskRequest, + ApiListBenchmarkTasksRequest, + ApiGetBenchmarkTaskRequest, + ApiListBenchmarkTaskRunsRequest, + ApiBenchmarkTaskSlug, + ApiBatchScheduleBenchmarkTaskRunsRequest, + ApiDownloadBenchmarkTaskRunOutputRequest, +) +from kagglesdk.benchmarks.types.benchmarks_api_service import ApiListBenchmarkModelsRequest from kagglesdk.competitions.types.competition_api_service import ( ApiListCompetitionsRequest, ApiCreateCodeSubmissionRequest, @@ -5356,6 +5367,170 @@ def _check_response_version(self, response: Response): def get_response_processor(self): return self._check_response_version + # ---- Benchmarks CLI ---- + + _TERMINAL_RUN_STATES = { + BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED, + BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_ERRORED, + } + + @staticmethod + def _make_task_slug(task: str) -> ApiBenchmarkTaskSlug: + """Build an ApiBenchmarkTaskSlug from a task name string.""" + slug = ApiBenchmarkTaskSlug() + slug.task_slug = task + return slug + + @staticmethod + def _normalize_model_list(model) -> list: + """Normalize a model argument (str, list, or None) into a list.""" + if isinstance(model, list): + return model + return [model] if model else [] + + def _get_task_names_from_file(self, file_content: str) -> List[str]: + """Extract task names from a Python file.""" + import ast + task_names = [] + try: + tree = ast.parse(file_content) + except SyntaxError: + return [] + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + + for decorator in node.decorator_list: + func = decorator.func if isinstance(decorator, ast.Call) else decorator + + if not ((isinstance(func, ast.Name) and func.id == 'task') or + (isinstance(func, ast.Attribute) and func.attr == 'task')): + continue + + name = None + if isinstance(decorator, ast.Call): + name = next((k.value.value for k in decorator.keywords if k.arg == 'name' and isinstance(k.value, ast.Constant)), None) + + task_names.append(name if name else node.name.title().replace("_", " ")) + + return task_names + + def _get_benchmark_task(self, task: str, kaggle): + """Get benchmark task details from the server.""" + request = ApiGetBenchmarkTaskRequest() + request.slug = self._make_task_slug(task) + return kaggle.benchmarks.benchmark_tasks_api_client.get_benchmark_task(request) + + def _validate_task_in_file(self, task: str, file: str, file_content: str): + """Validate that the task name is defined in the Python file.""" + task_names = self._get_task_names_from_file(file_content) + if not task_names: + raise ValueError(f"No @task decorators found in file {file}. The file must define at least one task.") + if task not in task_names: + raise ValueError(f"Task '{task}' not found in file {file}. Found tasks: {', '.join(task_names)}") + + def benchmarks_tasks_push_cli(self, task, file): + if not os.path.isfile(file): + raise ValueError(f"File {file} does not exist") + if not file.endswith(".py"): + raise ValueError(f"File {file} must be a .py file") + + with open(file, 'r') as f: + content = f.read() + + self._validate_task_in_file(task, file, content) + + # Convert .py file with percent delimiters to .ipynb + import jupytext + notebook = jupytext.reads(content, fmt="py:percent") + notebook_content = jupytext.writes(notebook, fmt="ipynb") + + with self.build_kaggle_client() as kaggle: + try: + task_info = self._get_benchmark_task(task, kaggle) + if task_info.creation_state in [ + BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED, + BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING + ]: + raise ValueError(f"Task '{task}' is currently being created (pending). Cannot push now.") + except HTTPError as e: + if e.response.status_code != 404: + raise + + request = ApiCreateBenchmarkTaskRequest() + request.slug = task + # Assume create_benchmark_task accepts ipynb content (JSON string) + request.text = notebook_content + + response = kaggle.benchmarks.benchmark_tasks_api_client.create_benchmark_task(request) + print(f"Task '{task}' pushed.") + print(f"Task URL: {response.url}") + + def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10): + models = self._normalize_model_list(model) + task_slug_obj = self._make_task_slug(task) + + with self.build_kaggle_client() as kaggle: + # If no models specified, prompt the user to select from available models + if not models: + models_request = ApiListBenchmarkModelsRequest() + models_response = kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(models_request) + available = models_response.benchmark_models + if not available: + raise ValueError("No benchmark models available. Cannot schedule runs.") + print("No model specified. Available models:") + for i, m in enumerate(available, 1): + print(f" {i}. {m.slug} ({m.display_name})") + selection = input("Enter model numbers (comma-separated), or 'all': ").strip() + if selection.lower() == "all": + models = [m.slug for m in available] + else: + try: + indices = [int(s.strip()) for s in selection.split(",")] + models = [available[i - 1].slug for i in indices] + except (ValueError, IndexError): + raise ValueError(f"Invalid selection: {selection}") + + request = ApiBatchScheduleBenchmarkTaskRunsRequest() + request.task_slugs = [task_slug_obj] + request.model_slugs = models + + response = kaggle.benchmarks.benchmark_tasks_api_client.batch_schedule_benchmark_task_runs(request) + print(f"Submitted run(s) for task '{task}'.") + for model_slug, res in zip(models, response.results): + if res.run_scheduled: + print(f" {model_slug}: Scheduled") + else: + print(f" {model_slug}: Skipped ({res.run_skipped_reason})") + + if wait is not None: + import time + print("Waiting for run(s) to complete...") + start_time = time.time() + while True: + runs_request = ApiListBenchmarkTaskRunsRequest() + runs_request.task_slugs = [task_slug_obj] + if models: + runs_request.model_slugs = models + runs_resp = kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(runs_request) + all_done = runs_resp.runs and all(r.state in self._TERMINAL_RUN_STATES for r in runs_resp.runs) + if all_done: + print("All runs completed:") + for r in runs_resp.runs: + state_label = "COMPLETED" if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED else "ERRORED" + print(f" {r.model_slug}: {state_label}") + break + + pending = sum(1 for r in runs_resp.runs if r.state not in self._TERMINAL_RUN_STATES) + print(f" {pending} run(s) still in progress...") + + if wait > 0 and (time.time() - start_time) > wait: + print(f"Timed out waiting for runs after {wait} seconds.") + break + + time.sleep(poll_interval) + class TqdmBufferedReader(io.BufferedReader): diff --git a/src/kaggle/cli.py b/src/kaggle/cli.py index 87558529..37430638 100644 --- a/src/kaggle/cli.py +++ b/src/kaggle/cli.py @@ -54,6 +54,7 @@ def main() -> None: parse_kernels(subparsers) parse_models(subparsers) parse_files(subparsers) + parse_benchmarks(subparsers) parse_config(subparsers) if api.enable_oauth: parse_auth(subparsers) @@ -979,6 +980,48 @@ def parse_files(subparsers) -> None: parser_files_upload.set_defaults(func=api.files_upload_cli) +def parse_benchmarks(subparsers) -> None: + parser_benchmarks = subparsers.add_parser( + "benchmarks", formatter_class=argparse.RawTextHelpFormatter, help=Help.group_benchmarks, aliases=["b"] + ) + subparsers_benchmarks = parser_benchmarks.add_subparsers(title="commands", dest="command") + subparsers_benchmarks.required = True + subparsers_benchmarks.choices = Help.benchmarks_choices + + parse_benchmark_tasks(subparsers_benchmarks) + + +def parse_benchmark_tasks(subparsers) -> None: + parser_tasks = subparsers.add_parser( + "tasks", formatter_class=argparse.RawTextHelpFormatter, help=Help.group_benchmarks_tasks, aliases=["t"] + ) + subparsers_tasks = parser_tasks.add_subparsers(title="commands", dest="command") + subparsers_tasks.required = True + subparsers_tasks.choices = Help.benchmarks_tasks_choices + + # push + parser_push = subparsers_tasks.add_parser( + "push", formatter_class=argparse.RawTextHelpFormatter, help=Help.command_benchmarks_tasks_push + ) + parser_push_optional = parser_push._action_groups.pop() + parser_push_optional.add_argument("task", help=Help.param_benchmarks_task) + parser_push_optional.add_argument("-f", "--file", dest="file", required=True, help=Help.param_benchmarks_file) + parser_push._action_groups.append(parser_push_optional) + parser_push.set_defaults(func=api.benchmarks_tasks_push_cli) + + # run + parser_run = subparsers_tasks.add_parser( + "run", formatter_class=argparse.RawTextHelpFormatter, help=Help.command_benchmarks_tasks_run + ) + parser_run_optional = parser_run._action_groups.pop() + parser_run_optional.add_argument("task", help=Help.param_benchmarks_task) + parser_run_optional.add_argument("-m", "--model", dest="model", nargs="+", required=False, help=Help.param_benchmarks_model) + parser_run_optional.add_argument("--wait", dest="wait", type=int, nargs="?", const=0, default=None, required=False, help=Help.param_benchmarks_wait) + parser_run_optional.add_argument("--poll-interval", dest="poll_interval", type=int, default=10, required=False, help=Help.param_benchmarks_poll_interval) + parser_run._action_groups.append(parser_run_optional) + parser_run.set_defaults(func=api.benchmarks_tasks_run_cli) + + def parse_config(subparsers) -> None: parser_config = subparsers.add_parser( "config", formatter_class=argparse.RawTextHelpFormatter, help=Help.group_config @@ -1068,6 +1111,8 @@ class Help(object): "m", "files", "f", + "benchmarks", + "b", "config", "auth", ] @@ -1078,6 +1123,8 @@ class Help(object): model_instances_choices = ["versions", "v", "get", "files", "list", "init", "create", "delete", "update"] model_instance_versions_choices = ["init", "create", "download", "delete", "files", "list"] files_choices = ["upload"] + benchmarks_choices = ["tasks", "t"] + benchmarks_tasks_choices = ["push", "run"] config_choices = ["view", "set", "unset"] auth_choices = ["login", "print-access-token", "revoke"] @@ -1094,6 +1141,8 @@ class Help(object): + ", ".join(model_instances_choices) + "}\nmodels variations versions {" + ", ".join(model_instance_versions_choices) + + "}\nbenchmarks {" + + ", ".join(benchmarks_choices) + "}\nconfig {" + ", ".join(config_choices) + "}" @@ -1108,6 +1157,8 @@ class Help(object): group_model_instances = "Commands related to Kaggle model variations" group_model_instance_versions = "Commands related to Kaggle model variations versions" group_files = "Commands related files" + group_benchmarks = "Commands related to Kaggle benchmarks" + group_benchmarks_tasks = "Commands related to benchmark tasks" group_config = "Configuration settings" group_auth = "Commands related to authentication" @@ -1149,6 +1200,10 @@ class Help(object): command_models_delete = "Delete a model" command_models_update = "Update a model" + # Benchmarks commands + command_benchmarks_tasks_push = "Register a task from a Python source file" + command_benchmarks_tasks_run = "Run a task against model(s)" + # Files commands command_files_upload = "Upload files" @@ -1367,6 +1422,13 @@ class Help(object): param_files_upload_no_compress = "Whether to compress directories (zip) or not (tar)" param_files_upload_no_resume = "Whether to skip resumable uploads." + # Benchmarks params + param_benchmarks_task = "Task name" + param_benchmarks_file = "Python source file containing the task definition" + param_benchmarks_model = "Model slug(s) to run the task against" + param_benchmarks_wait = "Wait for runs to complete (seconds). 0 means wait indefinitely." + param_benchmarks_poll_interval = "Polling interval in seconds when waiting for runs" + # Config params param_config_name = "Name of the configuration parameter\n(one of " "competition, path, proxy)" param_config_value = ( From 46b090f90f1bf11e7eedb1a826e2f3e28ce36199 Mon Sep 17 00:00:00 2001 From: Li Ma Date: Wed, 8 Apr 2026 22:27:12 +0000 Subject: [PATCH 2/9] Add tests for push and run commands --- src/kaggle/test/test_benchmarks_cli.py | 411 +++++++++++++++++++++++++ 1 file changed, 411 insertions(+) create mode 100644 src/kaggle/test/test_benchmarks_cli.py diff --git a/src/kaggle/test/test_benchmarks_cli.py b/src/kaggle/test/test_benchmarks_cli.py new file mode 100644 index 00000000..f99ddbbb --- /dev/null +++ b/src/kaggle/test/test_benchmarks_cli.py @@ -0,0 +1,411 @@ +import os +from unittest.mock import patch as _patch + +# Must be set before importing kaggle, which calls api.authenticate() at +# module level. Fake legacy credentials keep authenticate() off the network; +# removing KAGGLE_API_TOKEN prevents _introspect_token() from being called. +# We also patch get_access_token_from_env so the ~/.kaggle/access_token file +# doesn't trigger token introspection. +os.environ.pop("KAGGLE_API_TOKEN", None) +os.environ["KAGGLE_USERNAME"] = "testuser" +os.environ["KAGGLE_KEY"] = "testkey" + +with _patch("kagglesdk.get_access_token_from_env", return_value=(None, None)): + import kaggle # noqa: F401 — triggers authenticate() + +import unittest +from unittest.mock import patch, MagicMock +import argparse +import io +import tempfile +import pytest +from requests.exceptions import HTTPError +from kaggle.api.kaggle_api_extended import KaggleApi +from kagglesdk.benchmarks.types.benchmark_enums import BenchmarkTaskVersionCreationState, BenchmarkTaskRunState + + +class TestBenchmarksCli(unittest.TestCase): + """Tests for `kaggle benchmarks tasks ` CLI methods. + + Each test exercises one API method (e.g. benchmarks_tasks_push_cli) with + mocked SDK calls, verifying the printed output and request arguments match + the expected user experience. + """ + + TASK_FILE_CONTENT = '@task(name="my-task")\ndef evaluate(): pass\n' + + def setUp(self): + self.api = KaggleApi() + # Mock authenticate to avoid real network/creds check during unit tests + self.api.authenticate = MagicMock() + + # Mock build_kaggle_client to avoid real network calls + self.mock_client = MagicMock() + self.api.build_kaggle_client = MagicMock() + self.api.build_kaggle_client.return_value.__enter__.return_value = self.mock_client + self.mock_benchmarks = self.mock_client.benchmarks.benchmark_tasks_api_client + + # -- Helpers -- + + def _make_task_file(self, content=None): + """Create a temp .py file with task content. Caller must call .close().""" + f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=True) + f.write(content or self.TASK_FILE_CONTENT) + f.flush() + return f + + def _mock_jupytext(self): + """Return a mock jupytext module and a context manager that patches it in.""" + jt = MagicMock() + jt.reads.return_value = "mock_notebook" + jt.writes.return_value = '{"cells": []}' + return jt, patch.dict("sys.modules", {"jupytext": jt}) + + def _setup_create_response(self, task_slug="my-task"): + mock_resp = MagicMock() + mock_resp.slug.task_slug = task_slug + mock_resp.url = f"https://kaggle.com/benchmarks/{task_slug}" + self.mock_benchmarks.create_benchmark_task.return_value = mock_resp + return mock_resp + + def _make_mock_task(self, slug="my-task", state="COMPLETED", create_time="2026-04-06"): + t = MagicMock() + t.slug.task_slug = slug + t.creation_state = state + t.create_time = create_time + return t + + def _make_mock_run( + self, + model="gemini-pro", + state=BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED, + run_id=1, + start_time=None, + end_time=None, + error_message=None, + ): + r = MagicMock() + r.model_slug = model + r.state = state + r.id = run_id + r.start_time = start_time + r.end_time = end_time + r.error_message = error_message + return r + + def _setup_runs_response(self, runs): + resp = MagicMock() + resp.runs = runs + self.mock_benchmarks.list_benchmark_task_runs.return_value = resp + + def _make_run_result(self, scheduled=True, skipped_reason=None): + r = MagicMock() + r.run_scheduled = scheduled + r.benchmark_task_version_id = 1 + r.benchmark_model_version_id = 10 + r.run_skipped_reason = skipped_reason + return r + + # ---- kaggle benchmarks tasks push -f ---- + + @patch("sys.stdout", new_callable=io.StringIO) + def test_push_success(self, mock_stdout): + """Happy path: push a valid .py file creates the task on the server.""" + with self._make_task_file() as f: + self._setup_create_response() + jt, jt_ctx = self._mock_jupytext() + with jt_ctx: + self.api.benchmarks_tasks_push_cli("my-task", f.name) + self.assertIn("Task 'my-task' pushed.", mock_stdout.getvalue()) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_push_failure_pending(self, mock_stdout): + """Push rejected when the task version is still being created (QUEUED).""" + with self._make_task_file() as f: + task = self._make_mock_task( + state=BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED + ) + self.mock_benchmarks.get_benchmark_task.return_value = task + _, jt_ctx = self._mock_jupytext() + with jt_ctx, self.assertRaises(ValueError) as cm: + self.api.benchmarks_tasks_push_cli("my-task", f.name) + self.assertIn("is currently being created", str(cm.exception)) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_push_success_404(self, mock_stdout): + """A 404 on get means new task — should still create successfully.""" + mock_response = MagicMock() + mock_response.status_code = 404 + self.mock_benchmarks.get_benchmark_task.side_effect = HTTPError(response=mock_response) + self._setup_create_response() + with self._make_task_file() as f: + _, jt_ctx = self._mock_jupytext() + with jt_ctx: + self.api.benchmarks_tasks_push_cli("my-task", f.name) + self.assertIn("Task 'my-task' pushed.", mock_stdout.getvalue()) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_push_converts_to_ipynb(self, mock_stdout): + """Push converts the .py file to ipynb via jupytext before uploading.""" + with self._make_task_file() as f: + self._setup_create_response() + jt, jt_ctx = self._mock_jupytext() + with jt_ctx: + self.api.benchmarks_tasks_push_cli("my-task", f.name) + jt.reads.assert_called_once() + jt.writes.assert_called_once() + request = self.mock_benchmarks.create_benchmark_task.call_args[0][0] + self.assertEqual(request.text, '{"cells": []}') + + # ---- kaggle benchmarks tasks run [-m ...] [--wait] ---- + + def _setup_batch_schedule(self, results): + resp = MagicMock() + resp.results = results + self.mock_benchmarks.batch_schedule_benchmark_task_runs.return_value = resp + + def _setup_available_models(self, slugs): + models = [] + for s in slugs: + m = MagicMock() + m.slug = s + m.display_name = s.title() + models.append(m) + resp = MagicMock() + resp.benchmark_models = models + self.mock_client.benchmarks.benchmarks_api_client.list_benchmark_models.return_value = resp + + @patch("sys.stdout", new_callable=io.StringIO) + def test_run_single_model(self, mock_stdout): + """'kaggle b t run my-task -m gemini-pro' schedules one run.""" + self._setup_batch_schedule([self._make_run_result()]) + self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) + self.assertIn("Submitted run(s) for task 'my-task'", mock_stdout.getvalue()) + self.assertIn("gemini-pro: Scheduled", mock_stdout.getvalue()) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_run_multiple_models(self, mock_stdout): + """'kaggle b t run my-task -m gemini-pro gemma-2b' schedules two runs.""" + self._setup_batch_schedule([self._make_run_result(), self._make_run_result()]) + self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro", "gemma-2b"]) + output = mock_stdout.getvalue() + self.assertIn("gemini-pro: Scheduled", output) + self.assertIn("gemma-2b: Scheduled", output) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_run_skipped_result(self, mock_stdout): + """When the server skips a run (e.g. already running), show reason.""" + self._setup_batch_schedule([self._make_run_result(scheduled=False, skipped_reason="Already running")]) + self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) + output = mock_stdout.getvalue() + self.assertIn("gemini-pro: Skipped", output) + self.assertIn("Already running", output) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_run_no_model_prompts_selection(self, mock_stdout): + """When no model is specified, user is prompted to select from available models.""" + self._setup_available_models(["gemini-pro", "gemma-2b"]) + self._setup_batch_schedule([self._make_run_result()]) + with patch("builtins.input", return_value="1"): + self.api.benchmarks_tasks_run_cli("my-task") + request = self.mock_benchmarks.batch_schedule_benchmark_task_runs.call_args[0][0] + self.assertEqual(request.model_slugs, ["gemini-pro"]) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_run_no_model_select_all(self, mock_stdout): + """When no model is specified and user selects 'all'.""" + self._setup_available_models(["gemini-pro", "gemma-2b"]) + self._setup_batch_schedule([]) + with patch("builtins.input", return_value="all"): + self.api.benchmarks_tasks_run_cli("my-task") + request = self.mock_benchmarks.batch_schedule_benchmark_task_runs.call_args[0][0] + self.assertEqual(request.model_slugs, ["gemini-pro", "gemma-2b"]) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_run_with_wait(self, mock_stdout): + """Test --wait polls until runs complete.""" + self._setup_batch_schedule([self._make_run_result()]) + running = self._make_mock_run(state=BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_RUNNING) + done = self._make_mock_run(state=BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED) + resp1 = MagicMock(runs=[running]) + resp2 = MagicMock(runs=[done]) + self.mock_benchmarks.list_benchmark_task_runs.side_effect = [resp1, resp2] + with patch("time.sleep"): + self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=0) + output = mock_stdout.getvalue() + self.assertIn("Waiting for run(s) to complete", output) + self.assertIn("All runs completed", output) + self.assertIn("gemini-pro: COMPLETED", output) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_run_with_timeout(self, mock_stdout): + """Test --wait with timeout stops waiting.""" + self._setup_batch_schedule([self._make_run_result()]) + running = self._make_mock_run(state=BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_RUNNING) + self._setup_runs_response([running]) + with patch("time.sleep"), patch("time.time", side_effect=[1000, 1060]): + self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=30) + output = mock_stdout.getvalue() + self.assertIn("Waiting for run(s) to complete", output) + self.assertIn("Timed out waiting for runs after 30 seconds", output) + + # ---- push input validation (before any server call) ---- + + @patch("sys.stdout", new_callable=io.StringIO) + def test_push_validation_failure(self, mock_stdout): + """Push errors when the requested task name doesn't match any @task in the file.""" + with self._make_task_file('@task(name="real-task")\ndef my_task(llm): pass\n') as f: + with self.assertRaises(ValueError) as cm: + self.api.benchmarks_tasks_push_cli("wrong-task", f.name) + self.assertIn("Task 'wrong-task' not found", str(cm.exception)) + + @patch("sys.stdout", new_callable=io.StringIO) + def test_push_validation_no_tasks(self, mock_stdout): + """Push errors when the file has no @task decorators at all.""" + with self._make_task_file("def regular_function(): pass\n") as f: + with self.assertRaises(ValueError) as cm: + self.api.benchmarks_tasks_push_cli("any-task", f.name) + self.assertIn("No @task decorators found", str(cm.exception)) + + # ---- edge-case coverage ---- + + def test_push_file_not_found(self): + """Push errors immediately when the source file doesn't exist.""" + with self.assertRaises(ValueError) as cm: + self.api.benchmarks_tasks_push_cli("my-task", "/nonexistent/task.py") + self.assertIn("does not exist", str(cm.exception)) + + def test_push_not_py_file(self): + """Push errors when the file is not a .py file.""" + with tempfile.NamedTemporaryFile(suffix=".txt", mode="w") as f: + f.write("hello") + f.flush() + with self.assertRaises(ValueError) as cm: + self.api.benchmarks_tasks_push_cli("my-task", f.name) + self.assertIn("must be a .py file", str(cm.exception)) + + +class TestBenchmarksCliParsing: + """Tests that argparse wiring for `kaggle benchmarks tasks` is correct. + + These verify that argument names, aliases (b/t), required flags, + and nargs constraints are properly configured in cli.py. + """ + + def setup_method(self): + self.parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + ) + subparsers = self.parser.add_subparsers( + title="commands", + dest="command", + ) + subparsers.required = True + + from kaggle.cli import parse_benchmarks + + parse_benchmarks(subparsers) + + def _parse(self, arg_string): + return self.parser.parse_args(arg_string.split()) + + @pytest.mark.parametrize( + "cmd, expected", + [ + ("benchmarks tasks push my-task -f ./task.py", {"task": "my-task", "file": "./task.py"}), + ("b t push my-task -f ./task.py", {"task": "my-task", "file": "./task.py"}), + ("benchmarks tasks run my-task", {"task": "my-task", "model": None, "wait": None}), + ("benchmarks tasks run my-task -m gemini-3 --wait", {"model": ["gemini-3"], "wait": 0}), + ("benchmarks tasks run my-task -m gemini-3 --wait 60", {"model": ["gemini-3"], "wait": 60}), + ("benchmarks tasks run my-task -m gemini-3 gpt-5 claude-4", {"model": ["gemini-3", "gpt-5", "claude-4"]}), + ("b t run my-task -m gemini-3", {"task": "my-task", "model": ["gemini-3"]}), + ], + ) + def test_parse_success(self, cmd, expected): + args = self._parse(cmd) + for key, val in expected.items(): + assert getattr(args, key) == val + + @pytest.mark.parametrize( + "cmd", + [ + "benchmarks tasks push my-task", + "benchmarks tasks run my-task -m", + ], + ) + def test_parse_error(self, cmd): + with pytest.raises(SystemExit): + self._parse(cmd) + + +class TestTaskNameExtraction(unittest.TestCase): + """Tests for _get_task_names_from_file(), which parses @task decorators + from Python source code using AST to validate push inputs. + """ + + def setUp(self): + self.api = KaggleApi() + + def test_extract_simple_task(self): + """@kbench.task with no name= arg uses the function name as title case.""" + code = """ +import kaggle_benchmarks as kbench +@kbench.task +def my_task(llm): + pass +""" + task_names = self.api._get_task_names_from_file(code) + self.assertEqual(task_names, ["My Task"]) + + def test_extract_task_with_name(self): + """@kbench.task(name='custom_name') uses the explicit name.""" + code = """ +import kaggle_benchmarks as kbench +@kbench.task(name="custom_name") +def my_task(llm): + pass +""" + task_names = self.api._get_task_names_from_file(code) + self.assertEqual(task_names, ["custom_name"]) + + def test_extract_multiple_tasks(self): + """Multiple @task decorators in one file are all extracted.""" + code = """ +@task +def task1(llm): pass + +@task(name="task2_custom") +def task2(llm): pass +""" + task_names = self.api._get_task_names_from_file(code) + self.assertEqual(set(task_names), {"Task1", "task2_custom"}) + + def test_extract_no_tasks(self): + """File with no @task decorators returns empty list.""" + code = """ +def regular_function(): pass +""" + task_names = self.api._get_task_names_from_file(code) + self.assertEqual(task_names, []) + + def test_extract_syntax_error(self): + """Files with syntax errors return empty list instead of crashing.""" + code = """ +def broken_function( +""" + task_names = self.api._get_task_names_from_file(code) + self.assertEqual(task_names, []) + + def test_extract_async_task(self): + """Async function definitions with @task are also extracted.""" + code = """ +@task +async def my_async_task(llm): + pass +""" + task_names = self.api._get_task_names_from_file(code) + self.assertEqual(task_names, ["My Async Task"]) + + +if __name__ == "__main__": + unittest.main() From bdbef8765f2eeeaab2f031eb3ef81f6ddca55383 Mon Sep 17 00:00:00 2001 From: Li Ma Date: Wed, 8 Apr 2026 23:09:14 +0000 Subject: [PATCH 3/9] re-sync with sdk --- pyproject.toml | 1 + src/kaggle/api/kaggle_api_extended.py | 37 +- src/kaggle/test/conftest.py | 18 + src/kaggle/test/test_benchmarks_cli.py | 701 ++++++++++++------------- 4 files changed, 388 insertions(+), 369 deletions(-) create mode 100644 src/kaggle/test/conftest.py diff --git a/pyproject.toml b/pyproject.toml index 2ffc61db..a2defec8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "urllib3 >= 1.15.1", "packaging", "protobuf", + "jupytext", ] [project.scripts] diff --git a/src/kaggle/api/kaggle_api_extended.py b/src/kaggle/api/kaggle_api_extended.py index 05971076..b2be11e0 100644 --- a/src/kaggle/api/kaggle_api_extended.py +++ b/src/kaggle/api/kaggle_api_extended.py @@ -58,12 +58,10 @@ from kagglesdk.benchmarks.types.benchmark_enums import BenchmarkTaskRunState, BenchmarkTaskVersionCreationState from kagglesdk.benchmarks.types.benchmark_tasks_api_service import ( ApiCreateBenchmarkTaskRequest, - ApiListBenchmarkTasksRequest, ApiGetBenchmarkTaskRequest, ApiListBenchmarkTaskRunsRequest, ApiBenchmarkTaskSlug, ApiBatchScheduleBenchmarkTaskRunsRequest, - ApiDownloadBenchmarkTaskRunOutputRequest, ) from kagglesdk.benchmarks.types.benchmarks_api_service import ApiListBenchmarkModelsRequest from kagglesdk.competitions.types.competition_api_service import ( @@ -5444,6 +5442,12 @@ def benchmarks_tasks_push_cli(self, task, file): # Convert .py file with percent delimiters to .ipynb import jupytext notebook = jupytext.reads(content, fmt="py:percent") + # Add kernelspec metadata so papermill can execute it on the server + notebook.metadata["kernelspec"] = { + "display_name": "Python 3", + "language": "python", + "name": "python3", + } notebook_content = jupytext.writes(notebook, fmt="ipynb") with self.build_kaggle_client() as kaggle: @@ -5465,13 +5469,25 @@ def benchmarks_tasks_push_cli(self, task, file): response = kaggle.benchmarks.benchmark_tasks_api_client.create_benchmark_task(request) print(f"Task '{task}' pushed.") - print(f"Task URL: {response.url}") + url = response.url + if url.startswith("/"): + url = "https://www.kaggle.com" + url + print(f"Task URL: {url}") def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10): models = self._normalize_model_list(model) task_slug_obj = self._make_task_slug(task) with self.build_kaggle_client() as kaggle: + # Verify the task exists and is ready to run + task_info = self._get_benchmark_task(task, kaggle) + if task_info.creation_state != BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED: + error_msg = f"Task '{task}' is not ready to run (status: {task_info.creation_state})." + if task_info.creation_state == BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_ERRORED: + error_msg += f" Task Info: {task_info}." + error_msg += " Only completed tasks can be run." + raise ValueError(error_msg) + # If no models specified, prompt the user to select from available models if not models: models_request = ApiListBenchmarkModelsRequest() @@ -5481,20 +5497,21 @@ def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10 raise ValueError("No benchmark models available. Cannot schedule runs.") print("No model specified. Available models:") for i, m in enumerate(available, 1): - print(f" {i}. {m.slug} ({m.display_name})") + print(f" {i}. {m.version.slug} ({m.display_name})") selection = input("Enter model numbers (comma-separated), or 'all': ").strip() if selection.lower() == "all": - models = [m.slug for m in available] + models = [m.version.slug for m in available] else: try: indices = [int(s.strip()) for s in selection.split(",")] - models = [available[i - 1].slug for i in indices] + models = [available[i - 1].version.slug for i in indices] except (ValueError, IndexError): raise ValueError(f"Invalid selection: {selection}") + print(f"Selected models: {models}") request = ApiBatchScheduleBenchmarkTaskRunsRequest() request.task_slugs = [task_slug_obj] - request.model_slugs = models + request.model_version_slugs = models response = kaggle.benchmarks.benchmark_tasks_api_client.batch_schedule_benchmark_task_runs(request) print(f"Submitted run(s) for task '{task}'.") @@ -5510,16 +5527,16 @@ def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10 start_time = time.time() while True: runs_request = ApiListBenchmarkTaskRunsRequest() - runs_request.task_slugs = [task_slug_obj] + runs_request.task_slug = task_slug_obj if models: - runs_request.model_slugs = models + runs_request.model_version_slugs = models runs_resp = kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(runs_request) all_done = runs_resp.runs and all(r.state in self._TERMINAL_RUN_STATES for r in runs_resp.runs) if all_done: print("All runs completed:") for r in runs_resp.runs: state_label = "COMPLETED" if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED else "ERRORED" - print(f" {r.model_slug}: {state_label}") + print(f" {r.model_version_slug}: {state_label}") break pending = sum(1 for r in runs_resp.runs if r.state not in self._TERMINAL_RUN_STATES) diff --git a/src/kaggle/test/conftest.py b/src/kaggle/test/conftest.py new file mode 100644 index 00000000..9595981d --- /dev/null +++ b/src/kaggle/test/conftest.py @@ -0,0 +1,18 @@ +"""Shared test configuration for kaggle CLI tests. + +Must be set before importing kaggle, which calls api.authenticate() at +module level. Fake legacy credentials keep authenticate() off the network; +removing KAGGLE_API_TOKEN prevents _introspect_token() from being called. +We also patch get_access_token_from_env so the ~/.kaggle/access_token file +doesn't trigger token introspection. +""" + +import os +from unittest.mock import patch + +os.environ.pop("KAGGLE_API_TOKEN", None) +os.environ["KAGGLE_USERNAME"] = "testuser" +os.environ["KAGGLE_KEY"] = "testkey" + +with patch("kagglesdk.get_access_token_from_env", return_value=(None, None)): + import kaggle # noqa: F401 — triggers authenticate() diff --git a/src/kaggle/test/test_benchmarks_cli.py b/src/kaggle/test/test_benchmarks_cli.py index f99ddbbb..6b6ffed0 100644 --- a/src/kaggle/test/test_benchmarks_cli.py +++ b/src/kaggle/test/test_benchmarks_cli.py @@ -1,307 +1,357 @@ -import os -from unittest.mock import patch as _patch - -# Must be set before importing kaggle, which calls api.authenticate() at -# module level. Fake legacy credentials keep authenticate() off the network; -# removing KAGGLE_API_TOKEN prevents _introspect_token() from being called. -# We also patch get_access_token_from_env so the ~/.kaggle/access_token file -# doesn't trigger token introspection. -os.environ.pop("KAGGLE_API_TOKEN", None) -os.environ["KAGGLE_USERNAME"] = "testuser" -os.environ["KAGGLE_KEY"] = "testkey" - -with _patch("kagglesdk.get_access_token_from_env", return_value=(None, None)): - import kaggle # noqa: F401 — triggers authenticate() - -import unittest -from unittest.mock import patch, MagicMock +"""Tests for ``kaggle benchmarks tasks`` CLI commands. + +Organized by user journey: + TestPush – ``kaggle benchmarks tasks push -f `` + TestRun – ``kaggle benchmarks tasks run [-m ...] [--wait]`` + TestCliArgParsing – argparse wiring verification + +Uses pure pytest (capsys, tmp_path) instead of unittest.TestCase so that +stdout capture and temp-file creation are handled by built-in fixtures +rather than manual patching. +""" + import argparse -import io -import tempfile +from unittest.mock import patch, MagicMock + import pytest from requests.exceptions import HTTPError + from kaggle.api.kaggle_api_extended import KaggleApi -from kagglesdk.benchmarks.types.benchmark_enums import BenchmarkTaskVersionCreationState, BenchmarkTaskRunState +from kagglesdk.benchmarks.types.benchmark_enums import ( + BenchmarkTaskVersionCreationState, + BenchmarkTaskRunState, +) + +# Short aliases for the verbose enum members used throughout the tests. +QUEUED = BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED +RUNNING = BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING +COMPLETED = BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED + +RUN_RUNNING = BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_RUNNING +RUN_COMPLETED = BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED +RUN_ERRORED = BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_ERRORED + +DEFAULT_TASK_CONTENT = '@task(name="my-task")\ndef evaluate(): pass\n' -class TestBenchmarksCli(unittest.TestCase): - """Tests for `kaggle benchmarks tasks ` CLI methods. +# ---- Fixtures & helpers ---- - Each test exercises one API method (e.g. benchmarks_tasks_push_cli) with - mocked SDK calls, verifying the printed output and request arguments match - the expected user experience. + +@pytest.fixture +def api(): + """A KaggleApi with mocked auth and client — no network calls.""" + a = KaggleApi() + a.authenticate = MagicMock() + mock_client = MagicMock() + a.build_kaggle_client = MagicMock() + a.build_kaggle_client.return_value.__enter__.return_value = mock_client + # Expose internals so helpers can wire up responses. + a._mock_client = mock_client + a._mock_benchmarks = mock_client.benchmarks.benchmark_tasks_api_client + return a + + +def _write_task_file(tmp_path, content=DEFAULT_TASK_CONTENT, name="task.py"): + """Write *content* to a .py file under *tmp_path* and return its path str.""" + p = tmp_path / name + p.write_text(content) + return str(p) + + +def _mock_jupytext(): + """Return ``(mock_jupytext_module, context_manager)``.""" + jt = MagicMock() + jt.reads.return_value = "mock_notebook" + jt.writes.return_value = '{"cells": []}' + return jt, patch.dict("sys.modules", {"jupytext": jt}) + + +def _push(api, task, filepath): + """Call ``benchmarks_tasks_push_cli`` with jupytext mocked. + + Returns the mock jupytext module so callers can assert on calls. """ + jt, ctx = _mock_jupytext() + with ctx: + api.benchmarks_tasks_push_cli(task, filepath) + return jt + + +def _make_task(slug="my-task", state=COMPLETED): + t = MagicMock() + t.slug.task_slug = slug + t.creation_state = state + t.create_time = "2026-04-06" + return t + + +def _make_run_result(scheduled=True, skipped_reason=None): + r = MagicMock() + r.run_scheduled = scheduled + r.benchmark_task_version_id = 1 + r.benchmark_model_version_id = 10 + r.run_skipped_reason = skipped_reason + return r + + +def _make_run(model="gemini-pro", state=RUN_COMPLETED): + r = MagicMock() + r.model_version_slug = model + r.state = state + return r + + +def _setup_create_response(api, task_slug="my-task"): + resp = MagicMock() + resp.slug.task_slug = task_slug + resp.url = f"https://kaggle.com/benchmarks/{task_slug}" + api._mock_benchmarks.create_benchmark_task.return_value = resp + + +def _setup_completed_task(api, slug="my-task"): + task = _make_task(slug=slug, state=COMPLETED) + api._mock_benchmarks.get_benchmark_task.return_value = task + + +def _setup_batch_schedule(api, results): + resp = MagicMock() + resp.results = results + api._mock_benchmarks.batch_schedule_benchmark_task_runs.return_value = resp + + +def _setup_available_models(api, slugs): + models = [] + for s in slugs: + m = MagicMock() + m.version.slug = s + m.display_name = s.title() + models.append(m) + resp = MagicMock() + resp.benchmark_models = models + api._mock_client.benchmarks.benchmarks_api_client.list_benchmark_models.return_value = resp + + +# ============================================================ +# Push Journey +# ============================================================ + + +class TestPush: + """``kaggle benchmarks tasks push -f ``""" + + # -- Input validation (before any server call) -- + + @pytest.mark.parametrize( + "task, filename, content, expected_error", + [ + ("my-task", None, None, "does not exist"), + ("my-task", "task.txt", "hello", "must be a .py"), + ("any-task", "task.py", "def f(): pass\n", "No @task decorators"), + ("wrong", "task.py", '@task(name="real")\ndef f(llm): pass\n', "not found"), + ("any-task", "task.py", "def broken(\n", "No @task decorators"), + ], + ids=["missing_file", "wrong_extension", "no_decorators", "wrong_name", "syntax_error"], + ) + def test_push_rejects_invalid_input(self, api, tmp_path, task, filename, content, expected_error): + if filename is None: + filepath = "/nonexistent/task.py" + else: + filepath = _write_task_file(tmp_path, content, name=filename) + with pytest.raises(ValueError, match=expected_error): + api.benchmarks_tasks_push_cli(task, filepath) + + # -- Happy path -- - TASK_FILE_CONTENT = '@task(name="my-task")\ndef evaluate(): pass\n' - - def setUp(self): - self.api = KaggleApi() - # Mock authenticate to avoid real network/creds check during unit tests - self.api.authenticate = MagicMock() - - # Mock build_kaggle_client to avoid real network calls - self.mock_client = MagicMock() - self.api.build_kaggle_client = MagicMock() - self.api.build_kaggle_client.return_value.__enter__.return_value = self.mock_client - self.mock_benchmarks = self.mock_client.benchmarks.benchmark_tasks_api_client - - # -- Helpers -- - - def _make_task_file(self, content=None): - """Create a temp .py file with task content. Caller must call .close().""" - f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=True) - f.write(content or self.TASK_FILE_CONTENT) - f.flush() - return f - - def _mock_jupytext(self): - """Return a mock jupytext module and a context manager that patches it in.""" - jt = MagicMock() - jt.reads.return_value = "mock_notebook" - jt.writes.return_value = '{"cells": []}' - return jt, patch.dict("sys.modules", {"jupytext": jt}) - - def _setup_create_response(self, task_slug="my-task"): - mock_resp = MagicMock() - mock_resp.slug.task_slug = task_slug - mock_resp.url = f"https://kaggle.com/benchmarks/{task_slug}" - self.mock_benchmarks.create_benchmark_task.return_value = mock_resp - return mock_resp - - def _make_mock_task(self, slug="my-task", state="COMPLETED", create_time="2026-04-06"): - t = MagicMock() - t.slug.task_slug = slug - t.creation_state = state - t.create_time = create_time - return t - - def _make_mock_run( - self, - model="gemini-pro", - state=BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED, - run_id=1, - start_time=None, - end_time=None, - error_message=None, - ): - r = MagicMock() - r.model_slug = model - r.state = state - r.id = run_id - r.start_time = start_time - r.end_time = end_time - r.error_message = error_message - return r - - def _setup_runs_response(self, runs): - resp = MagicMock() - resp.runs = runs - self.mock_benchmarks.list_benchmark_task_runs.return_value = resp - - def _make_run_result(self, scheduled=True, skipped_reason=None): - r = MagicMock() - r.run_scheduled = scheduled - r.benchmark_task_version_id = 1 - r.benchmark_model_version_id = 10 - r.run_skipped_reason = skipped_reason - return r - - # ---- kaggle benchmarks tasks push -f ---- - - @patch("sys.stdout", new_callable=io.StringIO) - def test_push_success(self, mock_stdout): - """Happy path: push a valid .py file creates the task on the server.""" - with self._make_task_file() as f: - self._setup_create_response() - jt, jt_ctx = self._mock_jupytext() - with jt_ctx: - self.api.benchmarks_tasks_push_cli("my-task", f.name) - self.assertIn("Task 'my-task' pushed.", mock_stdout.getvalue()) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_push_failure_pending(self, mock_stdout): - """Push rejected when the task version is still being created (QUEUED).""" - with self._make_task_file() as f: - task = self._make_mock_task( - state=BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED - ) - self.mock_benchmarks.get_benchmark_task.return_value = task - _, jt_ctx = self._mock_jupytext() - with jt_ctx, self.assertRaises(ValueError) as cm: - self.api.benchmarks_tasks_push_cli("my-task", f.name) - self.assertIn("is currently being created", str(cm.exception)) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_push_success_404(self, mock_stdout): - """A 404 on get means new task — should still create successfully.""" - mock_response = MagicMock() - mock_response.status_code = 404 - self.mock_benchmarks.get_benchmark_task.side_effect = HTTPError(response=mock_response) - self._setup_create_response() - with self._make_task_file() as f: - _, jt_ctx = self._mock_jupytext() - with jt_ctx: - self.api.benchmarks_tasks_push_cli("my-task", f.name) - self.assertIn("Task 'my-task' pushed.", mock_stdout.getvalue()) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_push_converts_to_ipynb(self, mock_stdout): - """Push converts the .py file to ipynb via jupytext before uploading.""" - with self._make_task_file() as f: - self._setup_create_response() - jt, jt_ctx = self._mock_jupytext() - with jt_ctx: - self.api.benchmarks_tasks_push_cli("my-task", f.name) - jt.reads.assert_called_once() - jt.writes.assert_called_once() - request = self.mock_benchmarks.create_benchmark_task.call_args[0][0] - self.assertEqual(request.text, '{"cells": []}') - - # ---- kaggle benchmarks tasks run [-m ...] [--wait] ---- - - def _setup_batch_schedule(self, results): - resp = MagicMock() - resp.results = results - self.mock_benchmarks.batch_schedule_benchmark_task_runs.return_value = resp - - def _setup_available_models(self, slugs): - models = [] - for s in slugs: - m = MagicMock() - m.slug = s - m.display_name = s.title() - models.append(m) - resp = MagicMock() - resp.benchmark_models = models - self.mock_client.benchmarks.benchmarks_api_client.list_benchmark_models.return_value = resp - - @patch("sys.stdout", new_callable=io.StringIO) - def test_run_single_model(self, mock_stdout): - """'kaggle b t run my-task -m gemini-pro' schedules one run.""" - self._setup_batch_schedule([self._make_run_result()]) - self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) - self.assertIn("Submitted run(s) for task 'my-task'", mock_stdout.getvalue()) - self.assertIn("gemini-pro: Scheduled", mock_stdout.getvalue()) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_run_multiple_models(self, mock_stdout): - """'kaggle b t run my-task -m gemini-pro gemma-2b' schedules two runs.""" - self._setup_batch_schedule([self._make_run_result(), self._make_run_result()]) - self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro", "gemma-2b"]) - output = mock_stdout.getvalue() - self.assertIn("gemini-pro: Scheduled", output) - self.assertIn("gemma-2b: Scheduled", output) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_run_skipped_result(self, mock_stdout): - """When the server skips a run (e.g. already running), show reason.""" - self._setup_batch_schedule([self._make_run_result(scheduled=False, skipped_reason="Already running")]) - self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) - output = mock_stdout.getvalue() - self.assertIn("gemini-pro: Skipped", output) - self.assertIn("Already running", output) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_run_no_model_prompts_selection(self, mock_stdout): - """When no model is specified, user is prompted to select from available models.""" - self._setup_available_models(["gemini-pro", "gemma-2b"]) - self._setup_batch_schedule([self._make_run_result()]) + @pytest.mark.parametrize( + "content, task_name", + [ + ('@task(name="my-task")\ndef evaluate(): pass\n', "my-task"), + ("@task\ndef my_task(llm): pass\n", "My Task"), + ("@task\nasync def my_task(llm): pass\n", "My Task"), + ], + ids=["explicit_name", "title_cased", "async_function"], + ) + def test_push_creates_task(self, api, tmp_path, capsys, content, task_name): + """Push converts .py → ipynb via jupytext and creates the task.""" + filepath = _write_task_file(tmp_path, content) + _setup_create_response(api, task_name) + + jt = _push(api, task_name, filepath) + + # Verify jupytext conversion happened + jt.reads.assert_called_once() + jt.writes.assert_called_once() + request = api._mock_benchmarks.create_benchmark_task.call_args[0][0] + assert request.text == '{"cells": []}' + + assert f"Task '{task_name}' pushed." in capsys.readouterr().out + + def test_push_creates_new_task_on_404(self, api, tmp_path, capsys): + """A 404 from get_benchmark_task means new task — still creates successfully.""" + filepath = _write_task_file(tmp_path) + api._mock_benchmarks.get_benchmark_task.side_effect = HTTPError( + response=MagicMock(status_code=404) + ) + _setup_create_response(api) + _push(api, "my-task", filepath) + assert "Task 'my-task' pushed." in capsys.readouterr().out + + # -- Server edge cases -- + + @pytest.mark.parametrize("state", [QUEUED, RUNNING], ids=["queued", "running"]) + def test_push_rejects_pending_task(self, api, tmp_path, state): + """Push rejects when the task version is still being created.""" + filepath = _write_task_file(tmp_path) + api._mock_benchmarks.get_benchmark_task.return_value = _make_task(state=state) + with pytest.raises(ValueError, match="currently being created"): + _push(api, "my-task", filepath) + + def test_push_propagates_server_error(self, api, tmp_path): + """Non-404 HTTP errors (e.g. 500) are re-raised, not swallowed.""" + filepath = _write_task_file(tmp_path) + api._mock_benchmarks.get_benchmark_task.side_effect = HTTPError( + response=MagicMock(status_code=500) + ) + with pytest.raises(HTTPError): + _push(api, "my-task", filepath) + + +# ============================================================ +# Run Journey +# ============================================================ + + +class TestRun: + """``kaggle benchmarks tasks run [-m ...] [--wait]``""" + + # -- Pre-conditions -- + + def test_run_rejects_non_completed_task(self, api): + """Run errors when the task creation state is not COMPLETED.""" + api._mock_benchmarks.get_benchmark_task.return_value = _make_task(state=QUEUED) + with pytest.raises(ValueError, match="not ready to run"): + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) + api._mock_benchmarks.batch_schedule_benchmark_task_runs.assert_not_called() + + # -- Model scheduling -- + + @pytest.mark.parametrize( + "models", + [["gemini-pro"], ["gemini-pro", "gemma-2b"]], + ids=["single_model", "multiple_models"], + ) + def test_run_schedules_models(self, api, capsys, models): + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result() for _ in models]) + api.benchmarks_tasks_run_cli("my-task", models) + output = capsys.readouterr().out + assert "Submitted run(s) for task 'my-task'" in output + for m in models: + assert f"{m}: Scheduled" in output + + def test_run_reports_skipped_with_reason(self, api, capsys): + _setup_completed_task(api) + _setup_batch_schedule( + api, [_make_run_result(scheduled=False, skipped_reason="Already running")] + ) + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) + output = capsys.readouterr().out + assert "gemini-pro: Skipped" in output + assert "Already running" in output + + # -- Interactive model selection -- + + def test_run_prompts_model_selection(self, api): + """No model specified → user picks from a numbered list.""" + _setup_completed_task(api) + _setup_available_models(api, ["gemini-pro", "gemma-2b"]) + _setup_batch_schedule(api, [_make_run_result()]) with patch("builtins.input", return_value="1"): - self.api.benchmarks_tasks_run_cli("my-task") - request = self.mock_benchmarks.batch_schedule_benchmark_task_runs.call_args[0][0] - self.assertEqual(request.model_slugs, ["gemini-pro"]) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_run_no_model_select_all(self, mock_stdout): - """When no model is specified and user selects 'all'.""" - self._setup_available_models(["gemini-pro", "gemma-2b"]) - self._setup_batch_schedule([]) + api.benchmarks_tasks_run_cli("my-task") + request = api._mock_benchmarks.batch_schedule_benchmark_task_runs.call_args[0][0] + assert request.model_version_slugs == ["gemini-pro"] + + def test_run_selects_all_models(self, api): + _setup_completed_task(api) + _setup_available_models(api, ["gemini-pro", "gemma-2b"]) + _setup_batch_schedule(api, []) with patch("builtins.input", return_value="all"): - self.api.benchmarks_tasks_run_cli("my-task") - request = self.mock_benchmarks.batch_schedule_benchmark_task_runs.call_args[0][0] - self.assertEqual(request.model_slugs, ["gemini-pro", "gemma-2b"]) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_run_with_wait(self, mock_stdout): - """Test --wait polls until runs complete.""" - self._setup_batch_schedule([self._make_run_result()]) - running = self._make_mock_run(state=BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_RUNNING) - done = self._make_mock_run(state=BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED) - resp1 = MagicMock(runs=[running]) - resp2 = MagicMock(runs=[done]) - self.mock_benchmarks.list_benchmark_task_runs.side_effect = [resp1, resp2] + api.benchmarks_tasks_run_cli("my-task") + request = api._mock_benchmarks.batch_schedule_benchmark_task_runs.call_args[0][0] + assert request.model_version_slugs == ["gemini-pro", "gemma-2b"] + + def test_run_rejects_empty_model_list(self, api): + """No models available on server → ValueError.""" + _setup_completed_task(api) + _setup_available_models(api, []) + with pytest.raises(ValueError, match="No benchmark models available"): + api.benchmarks_tasks_run_cli("my-task") + + def test_run_rejects_invalid_model_selection(self, api): + """Bad input during interactive model selection → ValueError.""" + _setup_completed_task(api) + _setup_available_models(api, ["gemini-pro"]) + with patch("builtins.input", return_value="abc"): + with pytest.raises(ValueError, match="Invalid selection"): + api.benchmarks_tasks_run_cli("my-task") + + # -- Wait / polling -- + + def test_run_wait_polls_until_completion(self, api, capsys): + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result()]) + api._mock_benchmarks.list_benchmark_task_runs.side_effect = [ + MagicMock(runs=[_make_run(state=RUN_RUNNING)]), + MagicMock(runs=[_make_run(state=RUN_COMPLETED)]), + ] with patch("time.sleep"): - self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=0) - output = mock_stdout.getvalue() - self.assertIn("Waiting for run(s) to complete", output) - self.assertIn("All runs completed", output) - self.assertIn("gemini-pro: COMPLETED", output) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_run_with_timeout(self, mock_stdout): - """Test --wait with timeout stops waiting.""" - self._setup_batch_schedule([self._make_run_result()]) - running = self._make_mock_run(state=BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_RUNNING) - self._setup_runs_response([running]) + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=0) + output = capsys.readouterr().out + assert "Waiting for run(s) to complete" in output + assert "All runs completed" in output + assert "gemini-pro: COMPLETED" in output + + def test_run_wait_times_out(self, api, capsys): + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result()]) + api._mock_benchmarks.list_benchmark_task_runs.return_value = MagicMock( + runs=[_make_run(state=RUN_RUNNING)] + ) with patch("time.sleep"), patch("time.time", side_effect=[1000, 1060]): - self.api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=30) - output = mock_stdout.getvalue() - self.assertIn("Waiting for run(s) to complete", output) - self.assertIn("Timed out waiting for runs after 30 seconds", output) - - # ---- push input validation (before any server call) ---- - - @patch("sys.stdout", new_callable=io.StringIO) - def test_push_validation_failure(self, mock_stdout): - """Push errors when the requested task name doesn't match any @task in the file.""" - with self._make_task_file('@task(name="real-task")\ndef my_task(llm): pass\n') as f: - with self.assertRaises(ValueError) as cm: - self.api.benchmarks_tasks_push_cli("wrong-task", f.name) - self.assertIn("Task 'wrong-task' not found", str(cm.exception)) - - @patch("sys.stdout", new_callable=io.StringIO) - def test_push_validation_no_tasks(self, mock_stdout): - """Push errors when the file has no @task decorators at all.""" - with self._make_task_file("def regular_function(): pass\n") as f: - with self.assertRaises(ValueError) as cm: - self.api.benchmarks_tasks_push_cli("any-task", f.name) - self.assertIn("No @task decorators found", str(cm.exception)) - - # ---- edge-case coverage ---- - - def test_push_file_not_found(self): - """Push errors immediately when the source file doesn't exist.""" - with self.assertRaises(ValueError) as cm: - self.api.benchmarks_tasks_push_cli("my-task", "/nonexistent/task.py") - self.assertIn("does not exist", str(cm.exception)) - - def test_push_not_py_file(self): - """Push errors when the file is not a .py file.""" - with tempfile.NamedTemporaryFile(suffix=".txt", mode="w") as f: - f.write("hello") - f.flush() - with self.assertRaises(ValueError) as cm: - self.api.benchmarks_tasks_push_cli("my-task", f.name) - self.assertIn("must be a .py file", str(cm.exception)) - - -class TestBenchmarksCliParsing: - """Tests that argparse wiring for `kaggle benchmarks tasks` is correct. - - These verify that argument names, aliases (b/t), required flags, - and nargs constraints are properly configured in cli.py. + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=30) + output = capsys.readouterr().out + assert "Timed out waiting for runs after 30 seconds" in output + + def test_run_wait_shows_errored_runs(self, api, capsys): + """ERRORED runs display with ERRORED label.""" + _setup_completed_task(api) + _setup_batch_schedule(api, [_make_run_result()]) + api._mock_benchmarks.list_benchmark_task_runs.return_value = MagicMock( + runs=[_make_run(state=RUN_ERRORED)] + ) + with patch("time.sleep"): + api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=0) + assert "gemini-pro: ERRORED" in capsys.readouterr().out + + +# ============================================================ +# CLI Arg Parsing +# ============================================================ + + +class TestCliArgParsing: + """Tests that argparse wiring for ``kaggle benchmarks tasks`` is correct. + + Verifies argument names, aliases (b/t), required flags, and nargs + constraints are properly configured in cli.py. """ def setup_method(self): self.parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, ) - subparsers = self.parser.add_subparsers( - title="commands", - dest="command", - ) + subparsers = self.parser.add_subparsers(title="commands", dest="command") subparsers.required = True - from kaggle.cli import parse_benchmarks parse_benchmarks(subparsers) @@ -316,8 +366,14 @@ def _parse(self, arg_string): ("b t push my-task -f ./task.py", {"task": "my-task", "file": "./task.py"}), ("benchmarks tasks run my-task", {"task": "my-task", "model": None, "wait": None}), ("benchmarks tasks run my-task -m gemini-3 --wait", {"model": ["gemini-3"], "wait": 0}), - ("benchmarks tasks run my-task -m gemini-3 --wait 60", {"model": ["gemini-3"], "wait": 60}), - ("benchmarks tasks run my-task -m gemini-3 gpt-5 claude-4", {"model": ["gemini-3", "gpt-5", "claude-4"]}), + ( + "benchmarks tasks run my-task -m gemini-3 --wait 60", + {"model": ["gemini-3"], "wait": 60}, + ), + ( + "benchmarks tasks run my-task -m gemini-3 gpt-5 claude-4", + {"model": ["gemini-3", "gpt-5", "claude-4"]}, + ), ("b t run my-task -m gemini-3", {"task": "my-task", "model": ["gemini-3"]}), ], ) @@ -336,76 +392,3 @@ def test_parse_success(self, cmd, expected): def test_parse_error(self, cmd): with pytest.raises(SystemExit): self._parse(cmd) - - -class TestTaskNameExtraction(unittest.TestCase): - """Tests for _get_task_names_from_file(), which parses @task decorators - from Python source code using AST to validate push inputs. - """ - - def setUp(self): - self.api = KaggleApi() - - def test_extract_simple_task(self): - """@kbench.task with no name= arg uses the function name as title case.""" - code = """ -import kaggle_benchmarks as kbench -@kbench.task -def my_task(llm): - pass -""" - task_names = self.api._get_task_names_from_file(code) - self.assertEqual(task_names, ["My Task"]) - - def test_extract_task_with_name(self): - """@kbench.task(name='custom_name') uses the explicit name.""" - code = """ -import kaggle_benchmarks as kbench -@kbench.task(name="custom_name") -def my_task(llm): - pass -""" - task_names = self.api._get_task_names_from_file(code) - self.assertEqual(task_names, ["custom_name"]) - - def test_extract_multiple_tasks(self): - """Multiple @task decorators in one file are all extracted.""" - code = """ -@task -def task1(llm): pass - -@task(name="task2_custom") -def task2(llm): pass -""" - task_names = self.api._get_task_names_from_file(code) - self.assertEqual(set(task_names), {"Task1", "task2_custom"}) - - def test_extract_no_tasks(self): - """File with no @task decorators returns empty list.""" - code = """ -def regular_function(): pass -""" - task_names = self.api._get_task_names_from_file(code) - self.assertEqual(task_names, []) - - def test_extract_syntax_error(self): - """Files with syntax errors return empty list instead of crashing.""" - code = """ -def broken_function( -""" - task_names = self.api._get_task_names_from_file(code) - self.assertEqual(task_names, []) - - def test_extract_async_task(self): - """Async function definitions with @task are also extracted.""" - code = """ -@task -async def my_async_task(llm): - pass -""" - task_names = self.api._get_task_names_from_file(code) - self.assertEqual(task_names, ["My Async Task"]) - - -if __name__ == "__main__": - unittest.main() From 398c9c6f83c44e71399ce9172d60d27eed8f58dd Mon Sep 17 00:00:00 2001 From: Li Ma Date: Thu, 9 Apr 2026 19:54:34 +0000 Subject: [PATCH 4/9] add pagination --- src/kaggle/api/kaggle_api_extended.py | 27 ++++++++++++++++++-------- src/kaggle/test/test_benchmarks_cli.py | 12 +++++++----- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/kaggle/api/kaggle_api_extended.py b/src/kaggle/api/kaggle_api_extended.py index b2be11e0..6e46e112 100644 --- a/src/kaggle/api/kaggle_api_extended.py +++ b/src/kaggle/api/kaggle_api_extended.py @@ -5526,20 +5526,31 @@ def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10 print("Waiting for run(s) to complete...") start_time = time.time() while True: - runs_request = ApiListBenchmarkTaskRunsRequest() - runs_request.task_slug = task_slug_obj - if models: - runs_request.model_version_slugs = models - runs_resp = kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(runs_request) - all_done = runs_resp.runs and all(r.state in self._TERMINAL_RUN_STATES for r in runs_resp.runs) + # Paginate through all runs + all_runs = [] + page_token = "" + while True: + runs_request = ApiListBenchmarkTaskRunsRequest() + runs_request.task_slug = task_slug_obj + if models: + runs_request.model_version_slugs = models + if page_token: + runs_request.page_token = page_token + runs_resp = kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(runs_request) + all_runs.extend(runs_resp.runs) + if not runs_resp.next_page_token: + break + page_token = runs_resp.next_page_token + + all_done = all_runs and all(r.state in self._TERMINAL_RUN_STATES for r in all_runs) if all_done: print("All runs completed:") - for r in runs_resp.runs: + for r in all_runs: state_label = "COMPLETED" if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED else "ERRORED" print(f" {r.model_version_slug}: {state_label}") break - pending = sum(1 for r in runs_resp.runs if r.state not in self._TERMINAL_RUN_STATES) + pending = sum(1 for r in all_runs if r.state not in self._TERMINAL_RUN_STATES) print(f" {pending} run(s) still in progress...") if wait > 0 and (time.time() - start_time) > wait: diff --git a/src/kaggle/test/test_benchmarks_cli.py b/src/kaggle/test/test_benchmarks_cli.py index 6b6ffed0..84539989 100644 --- a/src/kaggle/test/test_benchmarks_cli.py +++ b/src/kaggle/test/test_benchmarks_cli.py @@ -61,7 +61,9 @@ def _write_task_file(tmp_path, content=DEFAULT_TASK_CONTENT, name="task.py"): def _mock_jupytext(): """Return ``(mock_jupytext_module, context_manager)``.""" jt = MagicMock() - jt.reads.return_value = "mock_notebook" + notebook = MagicMock() + notebook.metadata = {} + jt.reads.return_value = notebook jt.writes.return_value = '{"cells": []}' return jt, patch.dict("sys.modules", {"jupytext": jt}) @@ -301,8 +303,8 @@ def test_run_wait_polls_until_completion(self, api, capsys): _setup_completed_task(api) _setup_batch_schedule(api, [_make_run_result()]) api._mock_benchmarks.list_benchmark_task_runs.side_effect = [ - MagicMock(runs=[_make_run(state=RUN_RUNNING)]), - MagicMock(runs=[_make_run(state=RUN_COMPLETED)]), + MagicMock(runs=[_make_run(state=RUN_RUNNING)], next_page_token=""), + MagicMock(runs=[_make_run(state=RUN_COMPLETED)], next_page_token=""), ] with patch("time.sleep"): api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=0) @@ -315,7 +317,7 @@ def test_run_wait_times_out(self, api, capsys): _setup_completed_task(api) _setup_batch_schedule(api, [_make_run_result()]) api._mock_benchmarks.list_benchmark_task_runs.return_value = MagicMock( - runs=[_make_run(state=RUN_RUNNING)] + runs=[_make_run(state=RUN_RUNNING)], next_page_token="" ) with patch("time.sleep"), patch("time.time", side_effect=[1000, 1060]): api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=30) @@ -327,7 +329,7 @@ def test_run_wait_shows_errored_runs(self, api, capsys): _setup_completed_task(api) _setup_batch_schedule(api, [_make_run_result()]) api._mock_benchmarks.list_benchmark_task_runs.return_value = MagicMock( - runs=[_make_run(state=RUN_ERRORED)] + runs=[_make_run(state=RUN_ERRORED)], next_page_token="" ) with patch("time.sleep"): api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"], wait=0) From 2504099146f0529a0e75661fa2e26f72de5c4124 Mon Sep 17 00:00:00 2001 From: Li Ma Date: Thu, 9 Apr 2026 20:39:18 +0000 Subject: [PATCH 5/9] list model pagination --- src/kaggle/api/kaggle_api_extended.py | 69 ++++++++++++++++++++------ src/kaggle/test/test_benchmarks_cli.py | 1 + 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/kaggle/api/kaggle_api_extended.py b/src/kaggle/api/kaggle_api_extended.py index 6e46e112..ca0b9f36 100644 --- a/src/kaggle/api/kaggle_api_extended.py +++ b/src/kaggle/api/kaggle_api_extended.py @@ -5490,23 +5490,62 @@ def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10 # If no models specified, prompt the user to select from available models if not models: - models_request = ApiListBenchmarkModelsRequest() - models_response = kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(models_request) - available = models_response.benchmark_models + available = [] + page_token = "" + while True: + models_request = ApiListBenchmarkModelsRequest() + if page_token: + models_request.page_token = page_token + models_response = kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(models_request) + available.extend(models_response.benchmark_models) + if not models_response.next_page_token: + break + page_token = models_response.next_page_token if not available: raise ValueError("No benchmark models available. Cannot schedule runs.") - print("No model specified. Available models:") - for i, m in enumerate(available, 1): - print(f" {i}. {m.version.slug} ({m.display_name})") - selection = input("Enter model numbers (comma-separated), or 'all': ").strip() - if selection.lower() == "all": - models = [m.version.slug for m in available] - else: - try: - indices = [int(s.strip()) for s in selection.split(",")] - models = [available[i - 1].version.slug for i in indices] - except (ValueError, IndexError): - raise ValueError(f"Invalid selection: {selection}") + + PAGE_SIZE = 20 + total = len(available) + total_pages = (total + PAGE_SIZE - 1) // PAGE_SIZE + current_page = 0 + + print(f"No model specified. {total} model(s) available:") + while True: + start = current_page * PAGE_SIZE + end = min(start + PAGE_SIZE, total) + for i in range(start, end): + m = available[i] + print(f" {i + 1}. {m.version.slug} ({m.display_name})") + + nav_hints = [] + if total_pages > 1: + print(f" [Page {current_page + 1}/{total_pages}]") + if current_page < total_pages - 1: + nav_hints.append("'n'=next") + if current_page > 0: + nav_hints.append("'p'=prev") + + prompt = "Enter model numbers (comma-separated), 'all'" + if nav_hints: + prompt += ", " + ", ".join(nav_hints) + selection = input(prompt + ": ").strip() + + if selection.lower() == "n" and current_page < total_pages - 1: + current_page += 1 + continue + elif selection.lower() == "p" and current_page > 0: + current_page -= 1 + continue + elif selection.lower() == "all": + models = [m.version.slug for m in available] + break + else: + try: + indices = [int(s.strip()) for s in selection.split(",")] + models = [available[i - 1].version.slug for i in indices] + break + except (ValueError, IndexError): + raise ValueError(f"Invalid selection: {selection}") print(f"Selected models: {models}") request = ApiBatchScheduleBenchmarkTaskRunsRequest() diff --git a/src/kaggle/test/test_benchmarks_cli.py b/src/kaggle/test/test_benchmarks_cli.py index 84539989..e30374c8 100644 --- a/src/kaggle/test/test_benchmarks_cli.py +++ b/src/kaggle/test/test_benchmarks_cli.py @@ -130,6 +130,7 @@ def _setup_available_models(api, slugs): models.append(m) resp = MagicMock() resp.benchmark_models = models + resp.next_page_token = "" api._mock_client.benchmarks.benchmarks_api_client.list_benchmark_models.return_value = resp From d422b0cca0d55d5d69942840cf0f009aa04a5aff Mon Sep 17 00:00:00 2001 From: Li Ma Date: Thu, 9 Apr 2026 21:44:49 +0000 Subject: [PATCH 6/9] address comments and improve readability --- src/kaggle/api/kaggle_api_extended.py | 206 +++++++++++++------------ src/kaggle/test/test_benchmarks_cli.py | 14 ++ 2 files changed, 120 insertions(+), 100 deletions(-) diff --git a/src/kaggle/api/kaggle_api_extended.py b/src/kaggle/api/kaggle_api_extended.py index ca0b9f36..4cf5adfa 100644 --- a/src/kaggle/api/kaggle_api_extended.py +++ b/src/kaggle/api/kaggle_api_extended.py @@ -5372,6 +5372,11 @@ def get_response_processor(self): BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_ERRORED, } + _PENDING_CREATION_STATES = { + BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED, + BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING, + } + @staticmethod def _make_task_slug(task: str) -> ApiBenchmarkTaskSlug: """Build an ApiBenchmarkTaskSlug from a task name string.""" @@ -5386,6 +5391,19 @@ def _normalize_model_list(model) -> list: return model return [model] if model else [] + @staticmethod + def _paginate(fetch_page, get_items): + """Exhaust a paginated API, returning all items.""" + items = [] + page_token = "" + while True: + response = fetch_page(page_token) + items.extend(get_items(response)) + page_token = response.next_page_token or "" + if not page_token: + break + return items + def _get_task_names_from_file(self, file_content: str) -> List[str]: """Extract task names from a Python file.""" import ast @@ -5433,8 +5451,8 @@ def benchmarks_tasks_push_cli(self, task, file): raise ValueError(f"File {file} does not exist") if not file.endswith(".py"): raise ValueError(f"File {file} must be a .py file") - - with open(file, 'r') as f: + + with open(file) as f: content = f.read() self._validate_task_in_file(task, file, content) @@ -5453,10 +5471,7 @@ def benchmarks_tasks_push_cli(self, task, file): with self.build_kaggle_client() as kaggle: try: task_info = self._get_benchmark_task(task, kaggle) - if task_info.creation_state in [ - BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED, - BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING - ]: + if task_info.creation_state in self._PENDING_CREATION_STATES: raise ValueError(f"Task '{task}' is currently being created (pending). Cannot push now.") except HTTPError as e: if e.response.status_code != 404: @@ -5464,16 +5479,98 @@ def benchmarks_tasks_push_cli(self, task, file): request = ApiCreateBenchmarkTaskRequest() request.slug = task - # Assume create_benchmark_task accepts ipynb content (JSON string) request.text = notebook_content response = kaggle.benchmarks.benchmark_tasks_api_client.create_benchmark_task(request) + error = getattr(response, "error_message", None) or getattr(response, "errorMessage", None) + if error: + raise ValueError(f"Failed to push task: {error}") print(f"Task '{task}' pushed.") url = response.url if url.startswith("/"): url = "https://www.kaggle.com" + url print(f"Task URL: {url}") + def _select_models_interactively(self, kaggle, page_size=20): + """Prompt the user to pick benchmark models from a paginated list.""" + def _fetch_models(page_token): + req = ApiListBenchmarkModelsRequest() + if page_token: + req.page_token = page_token + return kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(req) + + available = self._paginate(_fetch_models, lambda r: r.benchmark_models) + if not available: + raise ValueError("No benchmark models available. Cannot schedule runs.") + + total = len(available) + total_pages = -(-total // page_size) # ceiling division + current_page = 0 + + print(f"No model specified. {total} model(s) available:") + while True: + start = current_page * page_size + for i, m in enumerate(available[start : start + page_size], start=start + 1): + print(f" {i}. {m.version.slug} ({m.display_name})") + + nav_hints = [] + if total_pages > 1: + print(f" [Page {current_page + 1}/{total_pages}]") + if current_page < total_pages - 1: + nav_hints.append("'n'=next") + if current_page > 0: + nav_hints.append("'p'=prev") + + prompt_parts = ["Enter model numbers (comma-separated)", "'all'"] + if nav_hints: + prompt_parts.extend(nav_hints) + selection = input(", ".join(prompt_parts) + ": ").strip().lower() + + if selection == "n" and current_page < total_pages - 1: + current_page += 1 + elif selection == "p" and current_page > 0: + current_page -= 1 + elif selection == "all": + return [m.version.slug for m in available] + else: + try: + indices = [int(s) for s in selection.split(",")] + return [available[i - 1].version.slug for i in indices] + except (ValueError, IndexError): + raise ValueError(f"Invalid selection: {selection}") + + def _poll_runs(self, kaggle, task_slug_obj, models, wait, poll_interval): + """Poll run status until all runs are terminal or timeout.""" + def _fetch_runs(page_token): + req = ApiListBenchmarkTaskRunsRequest() + req.task_slug = task_slug_obj + if models: + req.model_version_slugs = models + if page_token: + req.page_token = page_token + return kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(req) + + print("Waiting for run(s) to complete...") + start_time = time.time() + while True: + all_runs = self._paginate(_fetch_runs, lambda r: r.runs) + + if all_runs and all(r.state in self._TERMINAL_RUN_STATES for r in all_runs): + print("All runs completed:") + for r in all_runs: + label = "COMPLETED" if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED else "ERRORED" + print(f" {r.model_version_slug}: {label}") + return + + pending = sum(1 for r in all_runs if r.state not in self._TERMINAL_RUN_STATES) + print(f" {pending} run(s) still in progress...") + + if wait > 0 and (time.time() - start_time) > wait: + print(f"Timed out waiting for runs after {wait} seconds.") + return + + time.sleep(poll_interval) + def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10): models = self._normalize_model_list(model) task_slug_obj = self._make_task_slug(task) @@ -5488,64 +5585,8 @@ def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10 error_msg += " Only completed tasks can be run." raise ValueError(error_msg) - # If no models specified, prompt the user to select from available models if not models: - available = [] - page_token = "" - while True: - models_request = ApiListBenchmarkModelsRequest() - if page_token: - models_request.page_token = page_token - models_response = kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(models_request) - available.extend(models_response.benchmark_models) - if not models_response.next_page_token: - break - page_token = models_response.next_page_token - if not available: - raise ValueError("No benchmark models available. Cannot schedule runs.") - - PAGE_SIZE = 20 - total = len(available) - total_pages = (total + PAGE_SIZE - 1) // PAGE_SIZE - current_page = 0 - - print(f"No model specified. {total} model(s) available:") - while True: - start = current_page * PAGE_SIZE - end = min(start + PAGE_SIZE, total) - for i in range(start, end): - m = available[i] - print(f" {i + 1}. {m.version.slug} ({m.display_name})") - - nav_hints = [] - if total_pages > 1: - print(f" [Page {current_page + 1}/{total_pages}]") - if current_page < total_pages - 1: - nav_hints.append("'n'=next") - if current_page > 0: - nav_hints.append("'p'=prev") - - prompt = "Enter model numbers (comma-separated), 'all'" - if nav_hints: - prompt += ", " + ", ".join(nav_hints) - selection = input(prompt + ": ").strip() - - if selection.lower() == "n" and current_page < total_pages - 1: - current_page += 1 - continue - elif selection.lower() == "p" and current_page > 0: - current_page -= 1 - continue - elif selection.lower() == "all": - models = [m.version.slug for m in available] - break - else: - try: - indices = [int(s.strip()) for s in selection.split(",")] - models = [available[i - 1].version.slug for i in indices] - break - except (ValueError, IndexError): - raise ValueError(f"Invalid selection: {selection}") + models = self._select_models_interactively(kaggle) print(f"Selected models: {models}") request = ApiBatchScheduleBenchmarkTaskRunsRequest() @@ -5561,42 +5602,7 @@ def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10 print(f" {model_slug}: Skipped ({res.run_skipped_reason})") if wait is not None: - import time - print("Waiting for run(s) to complete...") - start_time = time.time() - while True: - # Paginate through all runs - all_runs = [] - page_token = "" - while True: - runs_request = ApiListBenchmarkTaskRunsRequest() - runs_request.task_slug = task_slug_obj - if models: - runs_request.model_version_slugs = models - if page_token: - runs_request.page_token = page_token - runs_resp = kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(runs_request) - all_runs.extend(runs_resp.runs) - if not runs_resp.next_page_token: - break - page_token = runs_resp.next_page_token - - all_done = all_runs and all(r.state in self._TERMINAL_RUN_STATES for r in all_runs) - if all_done: - print("All runs completed:") - for r in all_runs: - state_label = "COMPLETED" if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED else "ERRORED" - print(f" {r.model_version_slug}: {state_label}") - break - - pending = sum(1 for r in all_runs if r.state not in self._TERMINAL_RUN_STATES) - print(f" {pending} run(s) still in progress...") - - if wait > 0 and (time.time() - start_time) > wait: - print(f"Timed out waiting for runs after {wait} seconds.") - break - - time.sleep(poll_interval) + self._poll_runs(kaggle, task_slug_obj, models, wait, poll_interval) class TqdmBufferedReader(io.BufferedReader): diff --git a/src/kaggle/test/test_benchmarks_cli.py b/src/kaggle/test/test_benchmarks_cli.py index e30374c8..92e32e8b 100644 --- a/src/kaggle/test/test_benchmarks_cli.py +++ b/src/kaggle/test/test_benchmarks_cli.py @@ -107,6 +107,8 @@ def _setup_create_response(api, task_slug="my-task"): resp = MagicMock() resp.slug.task_slug = task_slug resp.url = f"https://kaggle.com/benchmarks/{task_slug}" + resp.error_message = None + resp.errorMessage = None api._mock_benchmarks.create_benchmark_task.return_value = resp @@ -218,6 +220,18 @@ def test_push_propagates_server_error(self, api, tmp_path): with pytest.raises(HTTPError): _push(api, "my-task", filepath) + def test_push_handles_api_error(self, api, tmp_path): + """Push raises ValueError when response contains error_message.""" + filepath = _write_task_file(tmp_path) + _setup_completed_task(api) + + resp = MagicMock() + resp.error_message = "Some backend error" + api._mock_benchmarks.create_benchmark_task.return_value = resp + + with pytest.raises(ValueError, match="Failed to push task: Some backend error"): + _push(api, "my-task", filepath) + # ============================================================ # Run Journey From d14a77b85edceb6a8dcadc8b1e610f96f66b67e9 Mon Sep 17 00:00:00 2001 From: Li Ma Date: Fri, 10 Apr 2026 18:11:35 +0000 Subject: [PATCH 7/9] update kagglesdk in pyproject.toml and requirements --- pyproject.toml | 2 +- requirements-test.lock | 7 +++++- requirements.lock | 53 +++++++++++++++++++++++++++++++++++++++--- 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2defec8..062c41c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ keywords = ["Kaggle", "API"] requires-python = ">= 3.11" dependencies = [ "bleach", - "kagglesdk >= 0.1.16, < 1.0", # sync with kagglehub + "kagglesdk >= 0.1.18, < 1.0", # sync with kagglehub "python-slugify", "requests", "python-dateutil", diff --git a/requirements-test.lock b/requirements-test.lock index ef288660..7591b959 100644 --- a/requirements-test.lock +++ b/requirements-test.lock @@ -1,4 +1,9 @@ -# Regenerate with: pip-compile requirements-test.in -o requirements-test.lock +# +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: +# +# pip-compile --output-file=requirements-test.lock requirements-test.in +# iniconfig==2.3.0 # via pytest packaging==26.0 diff --git a/requirements.lock b/requirements.lock index e4283c33..be56b94f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -1,16 +1,49 @@ -# Regenerate with: pip-compile pyproject.toml -o requirements.lock +# +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: +# +# pip-compile --output-file=requirements.lock pyproject.toml +# +attrs==26.1.0 + # via + # jsonschema + # referencing bleach==6.3.0 # via kaggle (pyproject.toml) certifi==2026.2.25 # via requests charset-normalizer==3.4.6 # via requests +fastjsonschema==2.21.2 + # via nbformat idna==3.11 # via requests -kagglesdk==0.1.16 +jsonschema==4.26.0 + # via nbformat +jsonschema-specifications==2025.9.1 + # via jsonschema +jupyter-core==5.9.1 + # via nbformat +jupytext==1.19.1 # via kaggle (pyproject.toml) -packaging==26.0 +kagglesdk==0.1.18 # via kaggle (pyproject.toml) +markdown-it-py==4.0.0 + # via + # jupytext + # mdit-py-plugins +mdit-py-plugins==0.5.0 + # via jupytext +mdurl==0.1.2 + # via markdown-it-py +nbformat==5.10.4 + # via jupytext +packaging==26.0 + # via + # jupytext + # kaggle (pyproject.toml) +platformdirs==4.9.6 + # via jupyter-core protobuf==7.34.1 # via # kaggle (pyproject.toml) @@ -19,16 +52,30 @@ python-dateutil==2.9.0.post0 # via kaggle (pyproject.toml) python-slugify==8.0.4 # via kaggle (pyproject.toml) +pyyaml==6.0.3 + # via jupytext +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications requests==2.33.1 # via # kaggle (pyproject.toml) # kagglesdk +rpds-py==0.30.0 + # via + # jsonschema + # referencing six==1.17.0 # via python-dateutil text-unidecode==1.3 # via python-slugify tqdm==4.67.3 # via kaggle (pyproject.toml) +traitlets==5.14.3 + # via + # jupyter-core + # nbformat urllib3==2.6.3 # via # kaggle (pyproject.toml) From 6b0679ca930069167ab1729d50b783b14d9d542a Mon Sep 17 00:00:00 2001 From: Li Ma Date: Fri, 10 Apr 2026 18:17:44 +0000 Subject: [PATCH 8/9] reformat --- src/kaggle/api/kaggle_api_extended.py | 47 +++++++++++++++++++------- src/kaggle/test/test_benchmarks_cli.py | 16 +++------ 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/src/kaggle/api/kaggle_api_extended.py b/src/kaggle/api/kaggle_api_extended.py index 4cf5adfa..14613655 100644 --- a/src/kaggle/api/kaggle_api_extended.py +++ b/src/kaggle/api/kaggle_api_extended.py @@ -730,7 +730,7 @@ def _authenticate_with_legacy_apikey(self) -> bool: return True def _authenticate_with_access_token(self): - (access_token, source) = get_access_token_from_env() + access_token, source = get_access_token_from_env() if not access_token: return False @@ -1910,7 +1910,7 @@ def dataset_metadata_update(self, dataset, path): dataset: The dataset to update. path: The path to the metadata file. """ - (owner_slug, dataset_slug, effective_path) = self.dataset_metadata_prep(dataset, path) + owner_slug, dataset_slug, effective_path = self.dataset_metadata_prep(dataset, path) meta_file = self.get_dataset_metadata_file(effective_path) with open(meta_file, "r") as f: metadata = json.load(f) @@ -1963,7 +1963,7 @@ def dataset_metadata(self, dataset, path): Returns: The path to the downloaded metadata file. """ - (owner_slug, dataset_slug, effective_path) = self.dataset_metadata_prep(dataset, path) + owner_slug, dataset_slug, effective_path = self.dataset_metadata_prep(dataset, path) if not os.path.exists(effective_path): os.makedirs(effective_path) @@ -3442,7 +3442,7 @@ def kernels_output( token = response.next_page_token outfiles = [] - for item in (response.files or []): + for item in response.files or []: if compiled_pattern and not compiled_pattern.search(item.file_name): continue @@ -3482,7 +3482,7 @@ def kernels_output_cli(self, kernel, kernel_opt=None, path=None, force=False, qu file_pattern: Regex pattern to match against filenames. Only files matching the pattern will be downloaded. """ kernel = kernel or kernel_opt - (_, token) = self.kernels_output(kernel, path, file_pattern, force, quiet) + _, token = self.kernels_output(kernel, path, file_pattern, force, quiet) if token: print(f"Next page token: {token}") @@ -4618,7 +4618,7 @@ def files_upload_cli(self, local_paths, inbox_path, no_resume, no_compress): files_to_create = [] with ResumableUploadContext(no_resume) as upload_context: for local_path in local_paths: - (upload_file, file_name) = self.file_upload_cli(local_path, inbox_path, no_compress, upload_context) + upload_file, file_name = self.file_upload_cli(local_path, inbox_path, no_compress, upload_context) if upload_file is None: continue @@ -5407,6 +5407,7 @@ def _paginate(fetch_page, get_items): def _get_task_names_from_file(self, file_content: str) -> List[str]: """Extract task names from a Python file.""" import ast + task_names = [] try: tree = ast.parse(file_content) @@ -5420,13 +5421,22 @@ def _get_task_names_from_file(self, file_content: str) -> List[str]: for decorator in node.decorator_list: func = decorator.func if isinstance(decorator, ast.Call) else decorator - if not ((isinstance(func, ast.Name) and func.id == 'task') or - (isinstance(func, ast.Attribute) and func.attr == 'task')): + if not ( + (isinstance(func, ast.Name) and func.id == "task") + or (isinstance(func, ast.Attribute) and func.attr == "task") + ): continue name = None if isinstance(decorator, ast.Call): - name = next((k.value.value for k in decorator.keywords if k.arg == 'name' and isinstance(k.value, ast.Constant)), None) + name = next( + ( + k.value.value + for k in decorator.keywords + if k.arg == "name" and isinstance(k.value, ast.Constant) + ), + None, + ) task_names.append(name if name else node.name.title().replace("_", " ")) @@ -5459,6 +5469,7 @@ def benchmarks_tasks_push_cli(self, task, file): # Convert .py file with percent delimiters to .ipynb import jupytext + notebook = jupytext.reads(content, fmt="py:percent") # Add kernelspec metadata so papermill can execute it on the server notebook.metadata["kernelspec"] = { @@ -5493,6 +5504,7 @@ def benchmarks_tasks_push_cli(self, task, file): def _select_models_interactively(self, kaggle, page_size=20): """Prompt the user to pick benchmark models from a paginated list.""" + def _fetch_models(page_token): req = ApiListBenchmarkModelsRequest() if page_token: @@ -5541,6 +5553,7 @@ def _fetch_models(page_token): def _poll_runs(self, kaggle, task_slug_obj, models, wait, poll_interval): """Poll run status until all runs are terminal or timeout.""" + def _fetch_runs(page_token): req = ApiListBenchmarkTaskRunsRequest() req.task_slug = task_slug_obj @@ -5558,7 +5571,11 @@ def _fetch_runs(page_token): if all_runs and all(r.state in self._TERMINAL_RUN_STATES for r in all_runs): print("All runs completed:") for r in all_runs: - label = "COMPLETED" if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED else "ERRORED" + label = ( + "COMPLETED" + if r.state == BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED + else "ERRORED" + ) print(f" {r.model_version_slug}: {label}") return @@ -5578,9 +5595,15 @@ def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10 with self.build_kaggle_client() as kaggle: # Verify the task exists and is ready to run task_info = self._get_benchmark_task(task, kaggle) - if task_info.creation_state != BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED: + if ( + task_info.creation_state + != BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED + ): error_msg = f"Task '{task}' is not ready to run (status: {task_info.creation_state})." - if task_info.creation_state == BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_ERRORED: + if ( + task_info.creation_state + == BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_ERRORED + ): error_msg += f" Task Info: {task_info}." error_msg += " Only completed tasks can be run." raise ValueError(error_msg) diff --git a/src/kaggle/test/test_benchmarks_cli.py b/src/kaggle/test/test_benchmarks_cli.py index 92e32e8b..ddf3a0df 100644 --- a/src/kaggle/test/test_benchmarks_cli.py +++ b/src/kaggle/test/test_benchmarks_cli.py @@ -194,9 +194,7 @@ def test_push_creates_task(self, api, tmp_path, capsys, content, task_name): def test_push_creates_new_task_on_404(self, api, tmp_path, capsys): """A 404 from get_benchmark_task means new task — still creates successfully.""" filepath = _write_task_file(tmp_path) - api._mock_benchmarks.get_benchmark_task.side_effect = HTTPError( - response=MagicMock(status_code=404) - ) + api._mock_benchmarks.get_benchmark_task.side_effect = HTTPError(response=MagicMock(status_code=404)) _setup_create_response(api) _push(api, "my-task", filepath) assert "Task 'my-task' pushed." in capsys.readouterr().out @@ -214,9 +212,7 @@ def test_push_rejects_pending_task(self, api, tmp_path, state): def test_push_propagates_server_error(self, api, tmp_path): """Non-404 HTTP errors (e.g. 500) are re-raised, not swallowed.""" filepath = _write_task_file(tmp_path) - api._mock_benchmarks.get_benchmark_task.side_effect = HTTPError( - response=MagicMock(status_code=500) - ) + api._mock_benchmarks.get_benchmark_task.side_effect = HTTPError(response=MagicMock(status_code=500)) with pytest.raises(HTTPError): _push(api, "my-task", filepath) @@ -224,11 +220,11 @@ def test_push_handles_api_error(self, api, tmp_path): """Push raises ValueError when response contains error_message.""" filepath = _write_task_file(tmp_path) _setup_completed_task(api) - + resp = MagicMock() resp.error_message = "Some backend error" api._mock_benchmarks.create_benchmark_task.return_value = resp - + with pytest.raises(ValueError, match="Failed to push task: Some backend error"): _push(api, "my-task", filepath) @@ -268,9 +264,7 @@ def test_run_schedules_models(self, api, capsys, models): def test_run_reports_skipped_with_reason(self, api, capsys): _setup_completed_task(api) - _setup_batch_schedule( - api, [_make_run_result(scheduled=False, skipped_reason="Already running")] - ) + _setup_batch_schedule(api, [_make_run_result(scheduled=False, skipped_reason="Already running")]) api.benchmarks_tasks_run_cli("my-task", ["gemini-pro"]) output = capsys.readouterr().out assert "gemini-pro: Skipped" in output From 3652f491abb3844f415a8a255187d84039dd9ff4 Mon Sep 17 00:00:00 2001 From: Li Ma Date: Fri, 10 Apr 2026 21:37:43 +0000 Subject: [PATCH 9/9] addres comments. --- src/kaggle/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kaggle/cli.py b/src/kaggle/cli.py index 3829ceb0..5a4567c0 100644 --- a/src/kaggle/cli.py +++ b/src/kaggle/cli.py @@ -1201,7 +1201,7 @@ class Help(object): command_models_update = "Update a model" # Benchmarks commands - command_benchmarks_tasks_push = "Register a task from a Python source file" + command_benchmarks_tasks_push = "Create or update a task from a Python source file" command_benchmarks_tasks_run = "Run a task against model(s)" # Files commands