-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat(benchmarks): implement Kaggle client (push/run functionality) #960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
46c5a9b
46b090f
bdbef87
398c9c6
2504099
d422b0c
d14a77b
6b0679c
dc4434c
3652f49
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -57,6 +57,15 @@ | |
| from kagglesdk import get_access_token_from_env, KaggleClient, KaggleCredentials, KaggleEnv, KaggleOAuth # type: ignore[attr-defined] | ||
| from kagglesdk.admin.types.inbox_file_service import CreateInboxFileRequest | ||
| from kagglesdk.blobs.types.blob_api_service import ApiStartBlobUploadRequest, ApiStartBlobUploadResponse, ApiBlobType | ||
| from kagglesdk.benchmarks.types.benchmark_enums import BenchmarkTaskRunState, BenchmarkTaskVersionCreationState | ||
| from kagglesdk.benchmarks.types.benchmark_tasks_api_service import ( | ||
| ApiCreateBenchmarkTaskRequest, | ||
| ApiGetBenchmarkTaskRequest, | ||
| ApiListBenchmarkTaskRunsRequest, | ||
| ApiBenchmarkTaskSlug, | ||
| ApiBatchScheduleBenchmarkTaskRunsRequest, | ||
| ) | ||
| from kagglesdk.benchmarks.types.benchmarks_api_service import ApiListBenchmarkModelsRequest | ||
| from kagglesdk.competitions.types.competition_api_service import ( | ||
| ApiListCompetitionsRequest, | ||
| ApiCreateCodeSubmissionRequest, | ||
|
|
@@ -5429,6 +5438,268 @@ def _check_response_version(self, response: Response): | |
def get_response_processor(self):
    """Return the response-processor hook for this client (the version-check callback)."""
    return self._check_response_version
|
|
||
# ---- Benchmarks CLI ----

# Run states in which a benchmark task run will make no further progress;
# polling stops once every run is in one of these states.
_TERMINAL_RUN_STATES = {
    BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED,
    BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_ERRORED,
}

# Task-version creation states treated as "still in progress" — pushing a
# new version is refused while the task is in one of these states.
_PENDING_CREATION_STATES = {
    BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_QUEUED,
    BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_RUNNING,
}
|
|
||
@staticmethod
def _make_task_slug(task: str) -> ApiBenchmarkTaskSlug:
    """Wrap a plain task-name string in an ApiBenchmarkTaskSlug message."""
    slug_obj = ApiBenchmarkTaskSlug()
    slug_obj.task_slug = task
    return slug_obj
|
|
||
| @staticmethod | ||
| def _normalize_model_list(model) -> list: | ||
| """Normalize a model argument (str, list, or None) into a list.""" | ||
| if isinstance(model, list): | ||
| return model | ||
| return [model] if model else [] | ||
|
|
||
| @staticmethod | ||
| def _paginate(fetch_page, get_items): | ||
| """Exhaust a paginated API, returning all items.""" | ||
| items = [] | ||
| page_token = "" | ||
| while True: | ||
| response = fetch_page(page_token) | ||
| items.extend(get_items(response)) | ||
| page_token = response.next_page_token or "" | ||
| if not page_token: | ||
| break | ||
| return items | ||
|
|
||
| def _get_task_names_from_file(self, file_content: str) -> List[str]: | ||
| """Extract task names from a Python file.""" | ||
| import ast | ||
|
|
||
| task_names = [] | ||
| try: | ||
| tree = ast.parse(file_content) | ||
| except SyntaxError: | ||
| return [] | ||
|
|
||
| for node in ast.walk(tree): | ||
| if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): | ||
| continue | ||
|
|
||
| for decorator in node.decorator_list: | ||
andrewmwang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| func = decorator.func if isinstance(decorator, ast.Call) else decorator | ||
|
|
||
| if not ( | ||
| (isinstance(func, ast.Name) and func.id == "task") | ||
| or (isinstance(func, ast.Attribute) and func.attr == "task") | ||
| ): | ||
| continue | ||
|
|
||
| name = None | ||
| if isinstance(decorator, ast.Call): | ||
| name = next( | ||
| ( | ||
| k.value.value | ||
| for k in decorator.keywords | ||
| if k.arg == "name" and isinstance(k.value, ast.Constant) | ||
| ), | ||
| None, | ||
| ) | ||
|
|
||
| task_names.append(name if name else node.name.title().replace("_", " ")) | ||
|
|
||
| return task_names | ||
|
|
||
def _get_benchmark_task(self, task: str, kaggle):
    """Fetch benchmark-task details for ``task`` from the Kaggle API."""
    req = ApiGetBenchmarkTaskRequest()
    req.slug = self._make_task_slug(task)
    tasks_client = kaggle.benchmarks.benchmark_tasks_api_client
    return tasks_client.get_benchmark_task(req)
|
|
||
| def _validate_task_in_file(self, task: str, file: str, file_content: str): | ||
| """Validate that the task name is defined in the Python file.""" | ||
| task_names = self._get_task_names_from_file(file_content) | ||
| if not task_names: | ||
| raise ValueError(f"No @task decorators found in file {file}. The file must define at least one task.") | ||
| if task not in task_names: | ||
| raise ValueError(f"Task '{task}' not found in file {file}. Found tasks: {', '.join(task_names)}") | ||
|
|
||
def benchmarks_tasks_push_cli(self, task, file):
    """Push a benchmark task defined in a local .py file to Kaggle.

    ``task`` selects which @task in the file to create (a single file may
    define several tasks, so the caller must name one). Raises ValueError
    for a missing or non-.py file, an unknown task name, a task whose
    previous version is still being created, or a server-reported error.
    """
    if not os.path.isfile(file):
        raise ValueError(f"File {file} does not exist")
    if not file.endswith(".py"):
        raise ValueError(f"File {file} must be a .py file")

    with open(file) as source:
        source_text = source.read()

    self._validate_task_in_file(task, file, source_text)

    # Convert .py file with percent delimiters to .ipynb
    import jupytext

    nb = jupytext.reads(source_text, fmt="py:percent")
    # Add kernelspec metadata so papermill can execute it on the server
    nb.metadata["kernelspec"] = {
        "display_name": "Python 3",
        "language": "python",
        "name": "python3",
    }
    ipynb_text = jupytext.writes(nb, fmt="ipynb")

    with self.build_kaggle_client() as kaggle:
        try:
            existing = self._get_benchmark_task(task, kaggle)
            if existing.creation_state in self._PENDING_CREATION_STATES:
                raise ValueError(f"Task '{task}' is currently being created (pending). Cannot push now.")
        except HTTPError as err:
            # A 404 just means the task does not exist yet; anything else is fatal.
            if err.response.status_code != 404:
                raise

        create_request = ApiCreateBenchmarkTaskRequest()
        create_request.slug = task
        create_request.text = ipynb_text

        result = kaggle.benchmarks.benchmark_tasks_api_client.create_benchmark_task(create_request)
        failure = getattr(result, "error_message", None) or getattr(result, "errorMessage", None)
        if failure:
            raise ValueError(f"Failed to push task: {failure}")
        print(f"Task '{task}' pushed.")

        task_url = result.url
        if task_url.startswith("/"):
            task_url = "https://www.kaggle.com" + task_url
        print(f"Task URL: {task_url}")
|
|
||
def _select_models_interactively(self, kaggle, page_size=20):
    """Prompt the user to pick benchmark models from a paginated list.

    Fetches every available benchmark model, shows them ``page_size`` at a
    time, and lets the user page with 'n'/'p', choose 'all', or enter
    comma-separated 1-based model numbers. Returns the chosen model
    version slugs. Raises ValueError when no models are available or the
    selection is not a valid set of indices.
    """

    def _fetch_models(page_token):
        req = ApiListBenchmarkModelsRequest()
        if page_token:
            req.page_token = page_token
        return kaggle.benchmarks.benchmarks_api_client.list_benchmark_models(req)

    available = self._paginate(_fetch_models, lambda r: r.benchmark_models)
    if not available:
        raise ValueError("No benchmark models available. Cannot schedule runs.")

    total = len(available)
    total_pages = -(-total // page_size)  # ceiling division
    current_page = 0

    print(f"No model specified. {total} model(s) available:")
    while True:
        start = current_page * page_size
        for i, m in enumerate(available[start : start + page_size], start=start + 1):
            print(f" {i}. {m.version.slug} ({m.display_name})")

        nav_hints = []
        if total_pages > 1:
            print(f" [Page {current_page + 1}/{total_pages}]")
            if current_page < total_pages - 1:
                nav_hints.append("'n'=next")
            if current_page > 0:
                nav_hints.append("'p'=prev")

        prompt_parts = ["Enter model numbers (comma-separated)", "'all'"]
        if nav_hints:
            prompt_parts.extend(nav_hints)
        selection = input(", ".join(prompt_parts) + ": ").strip().lower()

        if selection == "n" and current_page < total_pages - 1:
            current_page += 1
        elif selection == "p" and current_page > 0:
            current_page -= 1
        elif selection == "all":
            return [m.version.slug for m in available]
        else:
            try:
                indices = [int(s) for s in selection.split(",")]
            except ValueError:
                raise ValueError(f"Invalid selection: {selection}")
            # Validate the 1-based range explicitly: Python's negative
            # indexing would otherwise silently map 0 or negative entries
            # onto models counted from the end of the list.
            if any(i < 1 or i > total for i in indices):
                raise ValueError(f"Invalid selection: {selection}")
            return [available[i - 1].version.slug for i in indices]
|
|
||
def _poll_runs(self, kaggle, task_slug_obj, models, wait, poll_interval):
    """Poll run status until every run is terminal or ``wait`` seconds elapse.

    A ``wait`` of 0 or less polls indefinitely; ``poll_interval`` is the
    sleep between status checks, in seconds.
    """

    def _fetch_runs(page_token):
        req = ApiListBenchmarkTaskRunsRequest()
        req.task_slug = task_slug_obj
        if models:
            req.model_version_slugs = models
        if page_token:
            req.page_token = page_token
        return kaggle.benchmarks.benchmark_tasks_api_client.list_benchmark_task_runs(req)

    print("Waiting for run(s) to complete...")
    started = time.time()
    while True:
        runs = self._paginate(_fetch_runs, lambda r: r.runs)

        finished = bool(runs) and all(r.state in self._TERMINAL_RUN_STATES for r in runs)
        if finished:
            print("All runs completed:")
            completed_state = BenchmarkTaskRunState.BENCHMARK_TASK_RUN_STATE_COMPLETED
            for r in runs:
                label = "COMPLETED" if r.state == completed_state else "ERRORED"
                print(f" {r.model_version_slug}: {label}")
            return

        pending = sum(1 for r in runs if r.state not in self._TERMINAL_RUN_STATES)
        print(f" {pending} run(s) still in progress...")

        if wait > 0 and (time.time() - started) > wait:
            print(f"Timed out waiting for runs after {wait} seconds.")
            return

        time.sleep(poll_interval)
|
|
||
def benchmarks_tasks_run_cli(self, task, model=None, wait=None, poll_interval=10):
    """Schedule benchmark run(s) for ``task`` against one or more models.

    ``model`` may be a slug string, a list of slugs, or None (interactive
    picker). When ``wait`` is not None, run status is polled afterwards
    (``wait`` seconds timeout; 0 or less waits indefinitely). Raises
    ValueError when the task's creation state is not COMPLETED.
    """
    models = self._normalize_model_list(model)
    task_slug_obj = self._make_task_slug(task)

    with self.build_kaggle_client() as kaggle:
        # Verify the task exists and is ready to run
        task_info = self._get_benchmark_task(task, kaggle)
        completed = BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_COMPLETED
        errored = BenchmarkTaskVersionCreationState.BENCHMARK_TASK_VERSION_CREATION_STATE_ERRORED
        if task_info.creation_state != completed:
            error_msg = f"Task '{task}' is not ready to run (status: {task_info.creation_state})."
            if task_info.creation_state == errored:
                # Include the full task info so the user can see the failure details.
                error_msg += f" Task Info: {task_info}."
            error_msg += " Only completed tasks can be run."
            raise ValueError(error_msg)

        if not models:
            models = self._select_models_interactively(kaggle)
            print(f"Selected models: {models}")

        request = ApiBatchScheduleBenchmarkTaskRunsRequest()
        request.task_slugs = [task_slug_obj]
        request.model_version_slugs = models

        response = kaggle.benchmarks.benchmark_tasks_api_client.batch_schedule_benchmark_task_runs(request)
        print(f"Submitted run(s) for task '{task}'.")
        # NOTE(review): assumes response.results is ordered to match `models` — confirm API contract.
        for model_slug, res in zip(models, response.results):
            if res.run_scheduled:
                print(f" {model_slug}: Scheduled")
            else:
                print(f" {model_slug}: Skipped ({res.run_skipped_reason})")

        if wait is not None:
            self._poll_runs(kaggle, task_slug_obj, models, wait, poll_interval)
|
|
||
|
|
||
| class TqdmBufferedReader(io.BufferedReader): | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.