diff --git a/experimental/databricks-ml-training/SKILL.md b/experimental/databricks-ml-training/SKILL.md new file mode 100644 index 0000000..9fd85fe --- /dev/null +++ b/experimental/databricks-ml-training/SKILL.md @@ -0,0 +1,257 @@ +--- +name: databricks-ml-training +description: "Classical ML and custom-agent model training, MLflow tracking, and Unity Catalog model registration on Databricks. Use when the user asks to: train models (with MLflow, sklearn, XGBoost, LightGBM, PyTorch, custom pyfunc, etc.); run hyperparameter tuning with Optuna; register models to Unity Catalog and promote versions with `@prod` / `@challenger` aliases; load a registered model for batch scoring via `mlflow.pyfunc.spark_udf`; run inferences as batch, build custom MLflow PyFunc models (Models from Code); author a custom MLflow `ResponsesAgent` (LangGraph, OpenAI-compatible chat) with UC Function or Vector Search tools. NOT for: managing existing serving endpoints (use databricks-model-serving); no-code Knowledge Assistants or Supervisor Agents (use databricks-agent-bricks); MLflow evaluation / scorers (use databricks-mlflow-evaluation)." +compatibility: Requires databricks CLI (>= v0.294.0) +metadata: + version: "0.1.0" +parent: databricks-core +--- + +# ML Training on Databricks + +**FIRST**: Use the parent `databricks-core` skill for CLI basics, authentication, and profile selection. + +Train with MLflow → register to Unity Catalog → consume the **same artifact** as either a batch Spark UDF over Delta or (when low-latency is required) a real-time serving endpoint. + +> **Always train on Databricks** (serverless job or notebook), never in the local Python process the agent is running in. Local training has no access to the silver tables, no MLflow tracking server, no UC registry path, and dies if the chat session drops — submit `databricks jobs submit --no-wait` (see "Train + deploy as a serverless job" below). Only fall back to local execution if the user explicitly asks for it. 
+ +If you need to deploy a real-time model serving endpoint **after** the model is registered (creating endpoints, traffic config, version-swapping, querying, Foundation Model API endpoints), see [databricks-model-serving](../../skills/databricks-model-serving/SKILL.md). + +| Consumption | When | How | +|---|---|---| +| **Batch UDF** | Dashboards, daily/hourly scores, predictions read by Genie/Dashboards or an app (often synced to a Lakebase table) | `mlflow.pyfunc.spark_udf(...)` → `INSERT INTO gold_predictions` | +| **Real-time endpoint** | Score on a user action (fraud at authorization, rec at page load) — sub-100ms | `mlflow.deployments.get_deploy_client()` (classical) / `agents.deploy()` (agents). Endpoint lifecycle: see [databricks-model-serving](../../skills/databricks-model-serving/SKILL.md). | + +## Default canonical flow + +``` +silver_ + silver_ + ▼ + notebook (as a serverless job): + ├── train with mlflow.autolog (XGBoost / sklearn / etc.) + ├── mlflow.register_model → UC: {catalog}.{schema}.{model} + ├── set_registered_model_alias(name, "prod", version) + └── spark_udf(@prod) over latest features → MERGE into gold_predictions + ▼ +gold__predictions ◄── dashboards, apps, Genie read this +``` + +One notebook, one artifact. Re-running = retraining. Gold is where truth lives — read paths never call the model directly. Keep label-window logic (`failure occurred within 7 days`) in the notebook during dev; once stable, promote to a silver materialized view in SDP. + +--- + +## Train and register (the 90% case) + +`mlflow.autolog()` captures params, metrics, code, and the model artifact for every run; with `registered_model_name=...`, each logged model is also registered to UC as a new auto-incremented version. Wrap training with **Optuna** so each trial is a child run, and promote the version that came from the best trial. + +**Always `mlflow.set_registry_uri("databricks-uc")`** — without it, models land in the deprecated workspace registry. 
**The experiment's parent folder must exist** — `set_experiment` does NOT auto-create it (fails with `NOT_FOUND: Parent directory does not exist`). Pre-create it once with `databricks workspace mkdirs` before the job runs. + +```bash +# Once per project — create the parent folder for the MLflow experiment. +databricks workspace mkdirs /Users/me@example.com/turbine_project +``` + +Use the Databricks notebook source format (`# Databricks notebook source` header, `# COMMAND ----------` separators, `# MAGIC %md`/`%sql` magics for markdown/SQL cells): + +```python +# Databricks notebook source +# MAGIC %md +# MAGIC # Turbine failure prediction +# MAGIC +# MAGIC Train an XGBoost classifier on engineered turbine telemetry features. +# MAGIC ## Data exploration + +# COMMAND ---------- + +# (basic data exploration — class balance, schema sanity, etc.) + +# COMMAND ---------- +# MAGIC %md +# MAGIC ## Training the model + +# COMMAND ---------- + +import mlflow, mlflow.xgboost, optuna +from mlflow.tracking import MlflowClient +from xgboost import XGBClassifier +from sklearn.metrics import roc_auc_score + +mlflow.set_registry_uri("databricks-uc") +mlflow.set_experiment("/Users/me@example.com/turbine_project/mlflow_experiment") + +CATALOG, SCHEMA, NAME = "ai_demo_gen", "wind_farm", "turbine_failure" +FULL_NAME = f"{CATALOG}.{SCHEMA}.{NAME}" + +mlflow.xgboost.autolog(log_input_examples=True, registered_model_name=FULL_NAME) + +# For imbalanced labels: stratify the split, set scale_pos_weight = neg/pos. 
+def objective(trial): + params = { + "n_estimators": trial.suggest_int("n_estimators", 100, 400), + "max_depth": trial.suggest_int("max_depth", 3, 10), + "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3, log=True), + } + with mlflow.start_run(nested=True) as run: + m = XGBClassifier(**params).fit(X_train, y_train) + # Remember which run produced this trial so the best one can be promoted later. + trial.set_user_attr("run_id", run.info.run_id) + return roc_auc_score(y_test, m.predict_proba(X_test)[:, 1]) + +with mlflow.start_run(run_name="hpo"): + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=20) + +# COMMAND ---------- +# MAGIC %md +# MAGIC ## Promote to @prod alias + +# COMMAND ---------- +# Stages are deprecated — UC uses movable aliases. +# Autolog registers one version per trial, so promote the version logged by the best trial's run. +client = MlflowClient(registry_uri="databricks-uc") +best_run_id = study.best_trial.user_attrs["run_id"] +best = next(v for v in client.search_model_versions(f"name='{FULL_NAME}'") + if v.run_id == best_run_id) +client.set_registered_model_alias(FULL_NAME, "prod", best.version) +``` + +**Framework autolog**: `mlflow.{sklearn,xgboost,lightgbm,pytorch,tensorflow,spark}.autolog()`. + +**Aliases, not stages**: UC dropped `Staging`/`Production`. Use movable `@prod`/`@challenger`; load with `models:/{full_name}@prod`. Promoting a new version is one `set_registered_model_alias` call. + +--- + +## Consume: batch scoring over Delta + +The cheap, default path. Load the registered model as a Spark UDF and score a Delta table; write predictions to a gold table that downstream consumers read. + +```python +# COMMAND ---------- +# MAGIC %md +# MAGIC ## Score and save to a gold predictions table + +# COMMAND ---------- +import mlflow +from pyspark.sql import functions as F # needed for F.current_timestamp() in the scoring cell + +# env_manager rules: +# "local" → same runtime as training (same notebook/job). Fastest, default in dev/demo. +# "virtualenv"→ different runtime than training; rebuilds the model's env. +# "uv" → same as virtualenv but faster (MLflow ≥ 2.22). 
+predict = mlflow.pyfunc.spark_udf( + spark, + model_uri=f"models:/{FULL_NAME}@prod", + env_manager="local", +) + +features = spark.table(f"{CATALOG}.{SCHEMA}.silver_turbine_features_latest") +scored = features.withColumn("risk_score", predict(*[features[c] for c in feature_cols])) + +# Overwrite-per-run pattern for "latest score per entity": +scored.select("turbine_id", "risk_score", F.current_timestamp().alias("scored_at")) \ + .write.mode("overwrite").saveAsTable(f"{CATALOG}.{SCHEMA}.gold_turbine_predictions") +``` + +For incremental scoring with history, MERGE into the predictions table instead of overwrite. + +--- + +## Real-time serving (when required) + +After registering a model to UC, deploy it behind a Model Serving endpoint. The dev-side call is `mlflow.deployments.get_deploy_client("databricks").create_endpoint(...)` for classical ML or `agents.deploy(...)` for `ResponsesAgent`s. First deploy is ~5 min for classical ML. + +For endpoint create / update / version-swap, traffic config, AI Gateway, querying, the `state.ready` + `state.config_update` two-field readiness check, and Foundation Model API endpoints, see **[databricks-model-serving](../../skills/databricks-model-serving/SKILL.md)**. + +--- + +## Train + deploy as a serverless job + +Training notebooks run a few minutes (Optuna + UC register; endpoint warmup adds 5–15 min if you also deploy). Submit as a serverless one-time run so the CLI doesn't block. The notebook ends with `dbutils.notebook.exit(json.dumps({...}))` so the structured result (`model_version`, `val_auc`, `endpoint_name`) reaches `.notebook_output.result`. + +```bash +# 1. Upload the training notebook +databricks workspace import /Workspace/Users/me@example.com/turbine_project/train \ + --file ./train_notebook.py --format SOURCE --language PYTHON --overwrite + +# 2. 
Submit as serverless one-time run (returns {"run_id": N} immediately with --no-wait) +RUN_ID=$(databricks jobs submit --no-wait --json '{ + "run_name": "turbine-train-and-deploy", + "tasks": [{ + "task_key": "train", + "notebook_task": {"notebook_path": "/Workspace/Users/me@example.com/turbine_project/train"}, + "environment_key": "ml_env" + }], + "environments": [{ + "environment_key": "ml_env", + "spec": { + "client": "4", + "dependencies": ["mlflow==2.22.0", "xgboost==2.1.3", "optuna==4.1.0", "scikit-learn==1.5.2"] + } + }] +}' | jq -r .run_id) + +# 3. Poll until a terminal life_cycle_state. +for _ in $(seq 60); do + STATE=$(databricks jobs get-run "$RUN_ID" | jq -r '.state.life_cycle_state // "UNKNOWN"') + echo "$(date +%H:%M:%S) $STATE" + [[ "$STATE" =~ ^(TERMINATED|SKIPPED|INTERNAL_ERROR)$ ]] && break + sleep 30 +done +[[ "$STATE" =~ ^(TERMINATED|SKIPPED|INTERNAL_ERROR)$ ]] || { databricks jobs cancel-run "$RUN_ID"; exit 1; } + +# life_cycle_state TERMINATED only means "the run ended" — check result_state. +RESULT=$(databricks jobs get-run "$RUN_ID" | jq -r '.state.result_state // "UNKNOWN"') +echo "result_state=$RESULT" +[[ "$RESULT" == "SUCCESS" ]] || { echo "Run did not succeed"; exit 1; } + +# 4. Pull structured output via the TASK run_id (NOT the submit run_id). +TASK_RUN_ID=$(databricks jobs get-run "$RUN_ID" | jq -r '.tasks[0].run_id') +databricks jobs get-run-output "$TASK_RUN_ID" | jq '.notebook_output.result' +# → '{"model_version":"3","val_auc":0.91,"rows_scored":124,"endpoint":"turbine-risk-endpoint"}' +``` + +For the four `jobs submit` traps (`spec.client: "4"` requirement, TASK-vs-submit run_id, `print()` unreliable, tags rejected) and full debugging flow, see **[databricks-jobs](../../skills/databricks-jobs/SKILL.md#one-time-runs-jobs-submit--async-pattern-for-notebooks)**. + +--- + +## Custom pyfunc + +When sklearn/XGBoost autolog isn't enough — custom preprocessing, multiple sub-models, external API calls, ensemble logic. 
See **[references/custom-pyfunc.md](references/custom-pyfunc.md)** for a full worked example. Two non-obvious things: + +- **`python_model="path/to/file.py"`** (file path, not class instance) + `mlflow.models.set_model(MyModel())` at the end of that file. This is the "Models from Code" pattern — the file is logged verbatim, no pickling of the class. +- **`mlflow.models.predict(model_uri=..., input_data=..., env_manager="uv")`** before deploying. Catches missing deps before the endpoint does. + +--- + +## Custom GenAI agents + +Hand-rolled `ResponsesAgent` (LangGraph + UC Function tools + Vector Search retrieval) — see **[references/genai-agents.md](references/genai-agents.md)**. + +Prefer no-code authoring via [databricks-agent-bricks](../databricks-agent-bricks/SKILL.md) (Knowledge Assistants, Supervisor Agents) unless the user explicitly needs a custom LangGraph agent. + +--- + +## Gotchas (the ones that cost time) + +| Trap | Fix | +|---|---| +| Model lands in workspace registry, not UC | `mlflow.set_registry_uri("databricks-uc")` *before* logging | +| Endpoint returns PERMISSION_DENIED at first query | Pass `resources=[...]` to `log_model` (covers UC functions, VS indexes, other endpoints, Lakebase) — see [references/genai-agents.md#resources-that-need-passthrough-auth](references/genai-agents.md#resources-that-need-passthrough-auth) for the full list | +| Used `transition_model_version_stage` | Stages are deprecated in UC. Use `client.set_registered_model_alias(name, "prod", version)` | +| `spark_udf` rebuilds a virtualenv on every call | Pass `env_manager="local"` when training+scoring share a runtime | +| `pip_requirements` mismatch crashes endpoint at load | Pin exact versions; or pull live with `f"mlflow=={get_distribution('mlflow').version}"` | +| `agents.deploy()` produced a weirdly-named endpoint | Pass `endpoint_name=...` explicitly. 
Auto-derived name is `agents_--` | + +Endpoint-lifecycle gotchas (readiness two-state, version-swap, Serving-UI SP filter) live in [databricks-model-serving](../../skills/databricks-model-serving/SKILL.md). + +--- + +## Reference files + +| File | Contents | +|---|---| +| [references/custom-pyfunc.md](references/custom-pyfunc.md) | Single end-to-end custom pyfunc example: artifacts, signature, code_paths, log → register → deploy → query. | +| [references/genai-agents.md](references/genai-agents.md) | Custom LangGraph `ResponsesAgent` with UC Function + Vector Search tools. `create_text_output_item` gotcha and the `resources=[...]` passthrough-auth list. For no-code agents prefer **databricks-agent-bricks**. | + +## Related skills + +- **[databricks-model-serving](../../skills/databricks-model-serving/SKILL.md)** — serving-endpoint lifecycle (create, query, update-config, version-swap, AI Gateway, Foundation Model API endpoints). +- **[databricks-agent-bricks](../databricks-agent-bricks/SKILL.md)** — no-code Knowledge Assistants and Supervisor Agents. Prefer this over hand-rolling agents. +- **[databricks-mlflow-evaluation](../databricks-mlflow-evaluation/SKILL.md)** — evaluate model/agent quality before promoting `@prod`. +- **[databricks-vector-search](../databricks-vector-search/SKILL.md)** — vector indexes used as retrieval tools in agents. +- **[databricks-jobs](../../skills/databricks-jobs/SKILL.md)** — async deploy pattern (`--no-wait`, TASK run_id trap). +- **[databricks-unity-catalog](../databricks-unity-catalog/SKILL.md)** — UC governs the registered model: permissions, lineage, audit. 
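A footnote to the `pip_requirements` row in the Gotchas table: `get_distribution` comes from setuptools' `pkg_resources`, which is deprecated. A minimal sketch of the same live-pinning idea using the standard library's `importlib.metadata` might look like the following. The helper name `pinned_requirements` is hypothetical and not part of MLflow, and the pass-through behavior for missing packages is an assumption made for illustration.

```python
from importlib.metadata import PackageNotFoundError, version


def pinned_requirements(packages):
    """Return pip requirement strings pinned to the versions installed right now.

    Hypothetical helper: packages that are not installed are passed through
    unpinned so the caller can decide how to handle them.
    """
    pinned = []
    for name in packages:
        try:
            pinned.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            pinned.append(name)
    return pinned


if __name__ == "__main__":
    # The result could be passed as pip_requirements=... when logging the model.
    print(pinned_requirements(["mlflow", "xgboost", "scikit-learn"]))
```

Running this in the training environment and passing the result to `log_model` keeps the serving environment on the same versions the model was trained with.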
diff --git a/experimental/databricks-ml-training/agents/openai.yaml b/experimental/databricks-ml-training/agents/openai.yaml new file mode 100644 index 0000000..f875f53 --- /dev/null +++ b/experimental/databricks-ml-training/agents/openai.yaml @@ -0,0 +1,7 @@ +interface: + display_name: "Databricks ML Training" + short_description: "Train and register ML models on Databricks with MLflow" + icon_small: "./assets/databricks.svg" + icon_large: "./assets/databricks.png" + brand_color: "#FF3621" + default_prompt: "Use $databricks-ml-training for training and registering ML models on Databricks." diff --git a/experimental/databricks-ml-training/assets/databricks.png b/experimental/databricks-ml-training/assets/databricks.png new file mode 100644 index 0000000..263fe98 Binary files /dev/null and b/experimental/databricks-ml-training/assets/databricks.png differ diff --git a/experimental/databricks-ml-training/assets/databricks.svg b/experimental/databricks-ml-training/assets/databricks.svg new file mode 100644 index 0000000..9d19110 --- /dev/null +++ b/experimental/databricks-ml-training/assets/databricks.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/experimental/databricks-ml-training/references/custom-pyfunc.md b/experimental/databricks-ml-training/references/custom-pyfunc.md new file mode 100644 index 0000000..034c35e --- /dev/null +++ b/experimental/databricks-ml-training/references/custom-pyfunc.md @@ -0,0 +1,106 @@ +# Custom pyfunc model + +When sklearn / XGBoost autolog isn't enough: custom preprocessing not captured by a sklearn pipeline, multiple sub-models behind one endpoint, external API calls during inference, business-logic-heavy post-processing. + +Same UC registry + serving story as classical ML — only the *logging* step changes. + +## End-to-end example: file-based pyfunc with preprocessing + sub-model + +Project layout: + +``` +my_model/ +├── model.py # PythonModel + mlflow.models.set_model(...) 
+├── log_model.py # Logs + registers to UC +└── artifacts/ + ├── preprocessor.pkl + └── booster.json +``` + +```python +# model.py — logged verbatim via python_model="model.py" (Models from Code). +# DO NOT pickle a class instance; use this file-path pattern instead. +import pickle +import pandas as pd +import mlflow +from mlflow.pyfunc import PythonModel +from xgboost import Booster, DMatrix + +class TurbineRiskModel(PythonModel): + def load_context(self, context): + with open(context.artifacts["preprocessor"], "rb") as f: + self.pre = pickle.load(f) + self.booster = Booster() + self.booster.load_model(context.artifacts["booster"]) + + def predict(self, context, model_input: pd.DataFrame, params=None) -> pd.DataFrame: + X = self.pre.transform(model_input) + # Booster.predict expects a DMatrix, not a raw array or DataFrame. + proba = self.booster.predict(DMatrix(X)) + return pd.DataFrame({ + "risk_score": proba, + "risk_level": ["HIGH" if p > 0.7 else "MEDIUM" if p > 0.4 else "LOW" for p in proba], + }) + +mlflow.models.set_model(TurbineRiskModel()) +``` + +```python +# log_model.py +import mlflow +import pandas as pd +from mlflow.models import infer_signature +from mlflow.tracking import MlflowClient + +mlflow.set_registry_uri("databricks-uc") +mlflow.set_experiment("/Users/me@example.com/turbine_risk") + +CATALOG, SCHEMA, NAME = "ai_demo_gen", "wind_farm", "turbine_risk" +FULL_NAME = f"{CATALOG}.{SCHEMA}.{NAME}" + +sample_input = pd.DataFrame({"vib_rms": [0.4], "rpm_mean": [18.2], "bearing_temp_max": [71.3]}) +sample_output = pd.DataFrame({"risk_score": [0.0], "risk_level": ["LOW"]}) + +with mlflow.start_run(): + info = mlflow.pyfunc.log_model( + name="model", + python_model="model.py", # file path, not an instance + artifacts={ + "preprocessor": "artifacts/preprocessor.pkl", + "booster": "artifacts/booster.json", + }, + signature=infer_signature(sample_input, sample_output), + input_example=sample_input, + # Pin exact versions — endpoint rebuilds the env from these: + pip_requirements=["mlflow==2.22.0", "xgboost==2.1.3", "scikit-learn==1.5.2", "pandas"], + # Extra 
modules to ship with the model (e.g. shared util libs): + # code_paths=["src/utils.py"], + registered_model_name=FULL_NAME, + ) + +# Pre-deploy validation — rebuilds the env locally and runs predict(). +# Catches missing deps / signature drift BEFORE the endpoint does. +mlflow.models.predict( + model_uri=info.model_uri, + input_data=sample_input, + env_manager="uv", # MLflow ≥ 2.22; falls back to "virtualenv" otherwise +) + +# Promote to @prod +client = MlflowClient(registry_uri="databricks-uc") +v = max(client.search_model_versions(f"name='{FULL_NAME}'"), key=lambda x: int(x.version)).version +client.set_registered_model_alias(FULL_NAME, "prod", v) +``` + +**Why `python_model="model.py"`**: file logged verbatim, no class pickling — avoids Python-version unpickle crashes between training and serving runtimes. Pair with `code_paths=[...]` to ship companion modules; `mlflow.models.set_model(instance)` at end of file is the contract (exactly one call). + +## Consume + +Same two paths as autologged classical ML — see [SKILL.md § batch scoring](../SKILL.md#consume-batch-scoring-over-delta). + +- **Batch**: `mlflow.pyfunc.spark_udf(spark, model_uri=f"models:/{FULL_NAME}@prod", env_manager="local")` over a Delta table. +- **Real-time**: `client.create_endpoint(...)` for the dev-side call; endpoint lifecycle in [databricks-model-serving](../../../skills/databricks-model-serving/SKILL.md). Query returns a DataFrame-shaped JSON since `predict` returns a DataFrame. 
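The `dataframe_records` request body and the `predictions` response for the real-time path can also be handled from Python. This is a minimal sketch, assuming the payload and response shapes shown in this section's CLI query; `build_query_payload` and `risk_levels` are hypothetical helper names, not part of MLflow or the Databricks SDK.

```python
import json


def build_query_payload(rows):
    """Wrap feature rows (a list of dicts) in the dataframe_records envelope."""
    return json.dumps({"dataframe_records": list(rows)})


def risk_levels(response_body):
    """Extract the risk_level values from a {"predictions": [...]} response body."""
    return [p["risk_level"] for p in json.loads(response_body)["predictions"]]


if __name__ == "__main__":
    payload = build_query_payload(
        [{"vib_rms": 0.6, "rpm_mean": 19.0, "bearing_temp_max": 78.0}]
    )
    print(payload)  # pass this string to `serving-endpoints query ... --json`
```

Because `predict` returns a DataFrame, each element of `predictions` is one output row, so the parsing side only has to walk a list of dicts.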
+ +```bash +databricks serving-endpoints query turbine-risk-endpoint --json '{ + "dataframe_records": [{"vib_rms": 0.6, "rpm_mean": 19.0, "bearing_temp_max": 78.0}] +}' +# → {"predictions": [{"risk_score": 0.82, "risk_level": "HIGH"}]} +``` diff --git a/experimental/databricks-ml-training/references/genai-agents.md b/experimental/databricks-ml-training/references/genai-agents.md new file mode 100644 index 0000000..589cd7a --- /dev/null +++ b/experimental/databricks-ml-training/references/genai-agents.md @@ -0,0 +1,251 @@ +# Custom GenAI agents with MLflow ResponsesAgent + +Edge case. **For most demos, use [databricks-agent-bricks](../../databricks-agent-bricks/SKILL.md)** — pre-built Knowledge Assistants and Supervisor Agents wire up Genie + KAs + tools without any agent code. Hand-roll a `ResponsesAgent` only when you need a custom orchestration the supervisor can't express (custom routing logic, multi-step plans, agent that calls another agent over HTTP). + +## What ResponsesAgent is + +MLflow 3's standardized agent interface. OpenAI-compatible request/response (`{input: [{role, content}]}` → `{output: [...]}`). Supports streaming. Logs with `python_model="agent.py"` (file-based) and deploys via `databricks.agents.deploy()` to a serving endpoint with built-in tracing and eval hooks. + +## Full example: LangGraph agent with UC Function + Vector Search tools + +Project layout: + +``` +my_agent/ +├── agent.py # LangGraphAgent + tools + mlflow.models.set_model(...) 
+├── log_model.py # Logs with resources= for auto-auth, registers to UC +└── deploy_agent.py # Submitted as a job because deploy takes ~15 min +``` + +```python +# agent.py +import mlflow +from mlflow.pyfunc import ResponsesAgent +from mlflow.types.responses import ( + ResponsesAgentRequest, ResponsesAgentResponse, ResponsesAgentStreamEvent, + output_to_responses_items_stream, to_chat_completions_input, +) +from databricks_langchain import ( + ChatDatabricks, UCFunctionToolkit, VectorSearchRetrieverTool, +) +from langchain_core.messages import AIMessage +from langchain_core.runnables import RunnableLambda +from langgraph.graph import END, StateGraph +from langgraph.graph.message import add_messages +from langgraph.prebuilt.tool_node import ToolNode +from typing import Annotated, Generator, Sequence, TypedDict + +LLM_ENDPOINT = "databricks-claude-sonnet-4-6" # resolve at runtime — see training-and-serving.md +VS_INDEX = "ai_demo_gen.wind_farm.docs_index" +UC_FUNCTIONS = ["ai_demo_gen.wind_farm.lookup_turbine_history"] +SYSTEM_PROMPT = ( + "You are a turbine ops assistant. Use lookup_turbine_history for hardware " + "history queries, the docs retriever for procedure questions." +) + +class State(TypedDict): + messages: Annotated[Sequence, add_messages] + +class TurbineAgent(ResponsesAgent): + def __init__(self): + self.llm = ChatDatabricks(endpoint=LLM_ENDPOINT, temperature=0.1) + # Tools — UC functions and Vector Search both come from databricks_langchain. 
+ self.tools = list(UCFunctionToolkit(function_names=UC_FUNCTIONS).tools) + self.vs_tool = VectorSearchRetrieverTool( + index_name=VS_INDEX, num_results=5, + columns=["content", "doc_uri", "title"], + ) + self.tools.append(self.vs_tool) + self.llm_with_tools = self.llm.bind_tools(self.tools) + + def _graph(self): + def call_model(state): + msgs = [{"role": "system", "content": SYSTEM_PROMPT}] + state["messages"] + return {"messages": [self.llm_with_tools.invoke(msgs)]} + def should_continue(state): + last = state["messages"][-1] + return "tools" if isinstance(last, AIMessage) and last.tool_calls else "end" + + g = StateGraph(State) + g.add_node("agent", RunnableLambda(call_model)) + g.add_node("tools", ToolNode(self.tools)) + g.set_entry_point("agent") + g.add_conditional_edges("agent", should_continue, {"tools": "tools", "end": END}) + g.add_edge("tools", "agent") + return g.compile() + + def predict_stream(self, req: ResponsesAgentRequest) -> Generator[ResponsesAgentStreamEvent, None, None]: + msgs = to_chat_completions_input([m.model_dump() for m in req.input]) + for kind, payload in self._graph().stream({"messages": msgs}, stream_mode=["updates"]): + if kind != "updates": continue + for node in payload.values(): + if node.get("messages"): + yield from output_to_responses_items_stream(node["messages"]) + + def predict(self, req: ResponsesAgentRequest) -> ResponsesAgentResponse: + items = [ev.item for ev in self.predict_stream(req) + if ev.type == "response.output_item.done"] + return ResponsesAgentResponse(output=items) + +mlflow.langchain.autolog() +mlflow.models.set_model(TurbineAgent()) +``` + +### CRITICAL: output items must use helper methods + +The supervisor will silently drop your output if you return raw dicts: + +```python +# WRONG — raw dicts silently fail +return ResponsesAgentResponse(output=[{"role": "assistant", "content": "..."}]) + +# CORRECT +return ResponsesAgentResponse(output=[ + self.create_text_output_item(text="...", id="msg_1"), +]) +``` 
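One way to catch the raw-dict mistake before deploying is a small pre-flight check over the output list. The sketch below assumes, as in the OpenAI Responses format, that every well-formed output item carries a `type` field (for example `message` or `function_call`), which a hand-written `{"role", "content"}` dict lacks. The function name `untyped_output_items` is hypothetical, and the `typed` example is only an illustrative shape, not the exact dict the MLflow helper returns.

```python
def untyped_output_items(output):
    """Return the indices of plain-dict output items that lack a "type" field."""
    return [
        i for i, item in enumerate(output)
        if isinstance(item, dict) and "type" not in item
    ]


# A hand-rolled dict like the WRONG example: no "type" key.
raw = [{"role": "assistant", "content": "..."}]

# Illustrative Responses-style item (assumed shape, for demonstration only).
typed = [{
    "type": "message",
    "id": "msg_1",
    "role": "assistant",
    "content": [{"type": "output_text", "text": "..."}],
}]

if __name__ == "__main__":
    print(untyped_output_items(raw))    # indices of suspicious items
    print(untyped_output_items(typed))  # empty when every item is typed
```

A check like this is only a lint; building items with the helper methods remains the safer route.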
+ +Three helpers on `ResponsesAgent`: +- `self.create_text_output_item(text, id)` — text response. +- `self.create_function_call_item(id, call_id, name, arguments)` — tool call. +- `self.create_function_call_output_item(call_id, output)` — tool result. + +LangGraph's `output_to_responses_items_stream` (used above) emits these correctly, so the helpers are mainly relevant when hand-building events. + +## Log + register + +The non-obvious bit: `resources=[...]` is mandatory for auto-passthrough auth. Without it the deployed endpoint has no creds for the LLM, the UC functions, or the Vector Search index — every query returns `PERMISSION_DENIED` and the error doesn't explain why. + +```python +# log_model.py +import mlflow +from mlflow.models.resources import ( + DatabricksServingEndpoint, DatabricksFunction, DatabricksVectorSearchIndex, +) +from mlflow.tracking import MlflowClient +from agent import LLM_ENDPOINT, VS_INDEX, UC_FUNCTIONS + +mlflow.set_registry_uri("databricks-uc") +mlflow.set_experiment("/Users/me@example.com/turbine_agent") + +FULL_NAME = "ai_demo_gen.wind_farm.turbine_agent" + +resources = [ + DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT), + DatabricksVectorSearchIndex(index_name=VS_INDEX), + *[DatabricksFunction(function_name=f) for f in UC_FUNCTIONS], +] + +with mlflow.start_run(): + info = mlflow.pyfunc.log_model( + name="agent", + python_model="agent.py", # file path; agent.py calls set_model() + resources=resources, # auto-auth — DO NOT skip + input_example={"input": [{"role": "user", "content": "What's the maintenance history for turbine WTG-12?"}]}, + pip_requirements=[ + "mlflow==2.22.0", + "databricks-langchain", + "langgraph==0.3.4", + "databricks-agents", + "pydantic>=2", + ], + registered_model_name=FULL_NAME, + ) + +# Pre-deploy validation — rebuild the env, run a request, surface failures early. 
+mlflow.models.predict( + model_uri=info.model_uri, + input_data={"input": [{"role": "user", "content": "ping"}]}, + env_manager="uv", +) + +client = MlflowClient(registry_uri="databricks-uc") +v = max(client.search_model_versions(f"name='{FULL_NAME}'"), key=lambda x: int(x.version)).version +client.set_registered_model_alias(FULL_NAME, "prod", v) +``` + +### Resources that need passthrough auth + +| Resource | Import (`mlflow.models.resources`) | +|---|---| +| Foundation Model API / custom serving endpoint | `DatabricksServingEndpoint(endpoint_name=...)` | +| UC SQL/Python function | `DatabricksFunction(function_name=...)` | +| Vector Search index | `DatabricksVectorSearchIndex(index_name=...)` | +| Lakebase Postgres | `DatabricksLakebase(database_instance_name=...)` | + +Anything the agent calls that isn't covered here will hit auth errors at the endpoint. + +## Deploy (async job, ~15 min) + +`databricks.agents.deploy()` blocks for ~15 minutes — don't run it inline from the CLI. Submit as a serverless job so the chat session doesn't hold the connection. + +**Before submitting, check whether a deploy is already in flight or already done.** Re-submitting on top of a running deploy wastes ~15 min of serverless and can race for the same endpoint name. + +```bash +# 1. Is a deploy_agent run already active for this model? Match on run_name. +databricks jobs list-runs --active-only --output json \ + | jq --arg name "deploy_${MODEL_NAME}" '.runs[]? | select(.run_name == $name) | {run_id, state}' + +# 2. Does the target endpoint already exist? If READY on the right version, skip the redeploy. +databricks serving-endpoints get 2>/dev/null \ + | jq '{ready: .state.ready, served: [.config.served_models[] | {name, model_version}]}' +``` + +If either check returns a hit, follow the existing run with `jobs get-run ` instead of submitting a new one. 
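The two pre-flight checks can be folded into a single decision once their JSON output is parsed. This is a minimal sketch under stated assumptions: the inputs are the parsed bodies of the two CLI calls above, using the same fields the `jq` filters read (`runs[].run_name`, `state.ready`, `config.served_models[].model_version`). The function name `should_submit_deploy` and its return convention are hypothetical.

```python
def should_submit_deploy(active_runs, endpoint, run_name, model_version):
    """Decide whether a new deploy job is needed.

    active_runs: parsed body of `jobs list-runs --active-only --output json`
    endpoint:    parsed body of `serving-endpoints get`, or None if it does not exist
    Returns (submit: bool, reason: str).
    """
    # Check 1: a run with the same name is already in flight.
    for run in active_runs.get("runs", []):
        if run.get("run_name") == run_name:
            return False, f"deploy already running (run_id={run.get('run_id')})"

    # Check 2: the endpoint exists, is READY, and already serves this version.
    if endpoint is not None:
        served = (endpoint.get("config") or {}).get("served_models", [])
        versions = {str(m.get("model_version")) for m in served}
        ready = (endpoint.get("state") or {}).get("ready") == "READY"
        if ready and str(model_version) in versions:
            return False, "endpoint already READY on this version"

    return True, "no active deploy and endpoint not on target version"
```

Returning the reason alongside the boolean makes it easy to log why a submit was skipped.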
+ +```python +# deploy_agent.py +import json, sys +from databricks import agents + +model_name = sys.argv[1] +version = sys.argv[2] +endpoint_name = sys.argv[3] if len(sys.argv) > 3 else None + +# Always pass endpoint_name explicitly — auto-derived names are +# `agents_--` with dots → dashes, which is unpredictable. +kwargs = {"tags": {"aidevkit_project": "ai-dev-kit"}} +if endpoint_name: + kwargs["endpoint_name"] = endpoint_name + +deployment = agents.deploy(model_name, version, **kwargs) + +# Land structured output via dbutils.notebook.exit — print() unreliable on serverless. +dbutils.notebook.exit(json.dumps({ + "endpoint_name": deployment.endpoint_name, + "query_endpoint": deployment.query_endpoint, +})) +``` + +Submit via the same `jobs submit --no-wait` pattern shown in [SKILL.md § Train + deploy as a serverless job](../SKILL.md#train--deploy-as-a-serverless-job) — same script, just `deploy_agent.py` as the notebook. + +## Query + +A `ResponsesAgent` endpoint takes the Responses-style `input` field (the same shape as the `input_example` logged above), not chat-completions `messages`. + +```bash +databricks serving-endpoints query turbine-agent-endpoint --json '{ + "input": [{"role": "user", "content": "What is the maintenance history for WTG-12?"}] +}' +``` + +OpenAI-compatible client also works. Point `base_url` at the workspace's `/serving-endpoints` root and pass the endpoint name as `model`: + +```python +from openai import OpenAI +client = OpenAI( + base_url=f"{WORKSPACE_URL}/serving-endpoints", + api_key=DATABRICKS_TOKEN, +) +client.responses.create( + model="turbine-agent-endpoint", + input=[{"role": "user", "content": "..."}], +) +``` + +## Iteration + +`databricks workspace import-dir ./my_agent ... --overwrite` then re-run `log_model.py`. `agents.deploy()` with a new version **updates the existing endpoint in place** — no need to recreate. Re-deploy only when changing endpoint config (workload size, route splits). + +## Packages + +DBR 16.1+ has `mlflow` 3.x, `langchain`, `pydantic`, `databricks-sdk` pre-installed. Typically only need `%pip install -q databricks-langchain langgraph databricks-agents`. 
diff --git a/manifest.json b/manifest.json index e6925dc..436f9ac 100644 --- a/manifest.json +++ b/manifest.json @@ -226,6 +226,19 @@ "repo_dir": "experimental", "version": "0.0.1" }, + "databricks-ml-training": { + "description": "Classical ML and custom-agent model training, MLflow tracking, and Unity Catalog model registration on Databricks. Use when the user asks to: train models (with MLflow, sklearn, XGBoost, LightGBM, PyTorch, custom pyfunc, etc.); run hyperparameter tuning with Optuna; register models to Unity Catalog and promote versions with `@prod` / `@challenger` aliases; load a registered model for batch scoring via `mlflow.pyfunc.spark_udf`; run inferences as batch, build custom MLflow PyFunc models (Models from Code); author a custom MLflow `ResponsesAgent` (LangGraph, OpenAI-compatible chat) with UC Function or Vector Search tools. NOT for: managing existing serving endpoints (use databricks-model-serving); no-code Knowledge Assistants or Supervisor Agents (use databricks-agent-bricks); MLflow evaluation / scorers (use databricks-mlflow-evaluation).", + "files": [ + "SKILL.md", + "agents/openai.yaml", + "assets/databricks.png", + "assets/databricks.svg", + "references/custom-pyfunc.md", + "references/genai-agents.md" + ], + "repo_dir": "experimental", + "version": "0.1.0" + }, "databricks-mlflow-evaluation": { "description": "MLflow 3 GenAI agent evaluation. 
Use when writing mlflow.genai.evaluate() code, creating @scorer functions, using built-in scorers (Guidelines, Correctness, Safety, RetrievalGroundedness), building eval datasets from traces, setting up trace ingestion and production monitoring, aligning judges with MemAlign from domain expert feedback, or running optimize_prompts() with GEPA for automated prompt improvement.",
       "files": [
@@ -258,7 +271,7 @@
         "references/off-platform-streaming.md"
       ],
       "repo_dir": "skills",
-      "version": "0.1.0"
+      "version": "0.4.0"
     },
     "databricks-pipelines": {
       "description": "Databricks Spark Declarative Pipelines (SDP) for ETL and streaming",
diff --git a/skills/databricks-model-serving/SKILL.md b/skills/databricks-model-serving/SKILL.md
index da751ed..2c23e78 100644
--- a/skills/databricks-model-serving/SKILL.md
+++ b/skills/databricks-model-serving/SKILL.md
@@ -1,9 +1,9 @@
 ---
 name: databricks-model-serving
-description: "Manage Databricks Model Serving endpoints via CLI. Use when asked to create, configure, query, or manage model serving endpoints for LLM inference, custom models, or external models."
+description: "Databricks Model Serving endpoint lifecycle and ops. Use when asked to: create, query, update, scale, or delete serving endpoints via CLI or the MLflow Deployments client; configure traffic routing for A/B / canary deployments; do zero-downtime version swaps; manage AI Gateway rate limits and usage tracking; discover Foundation Model API endpoints at runtime; integrate an endpoint into a Databricks App. NOT for: training models, MLflow autologging, UC model registration, custom PyFunc authoring, or hand-rolled ResponsesAgent code (use databricks-ml-training); no-code Knowledge Assistants or Supervisor Agents (use databricks-agent-bricks); MLflow evaluation / scorers (use databricks-mlflow-evaluation)."
 compatibility: Requires databricks CLI (>= v0.294.0)
 metadata:
-  version: "0.1.0"
+  version: "0.4.0"
 parent: databricks-core
 ---

@@ -17,7 +17,7 @@ Model Serving provides managed endpoints for serving LLMs, custom ML models, and

 | Type | When to Use | Key Detail |
 |------|-------------|------------|
-| Pay-per-token | Foundation Model APIs (Llama, DBRX, etc.) | Uses `system.ai.*` catalog models, simplest setup |
+| Pay-per-token | Foundation Model APIs (Llama, GPT-5, Claude, Gemini, etc.) | Uses `system.ai.*` catalog models, pre-provisioned in every workspace. Discover at runtime — see [Foundation Model API endpoints](#foundation-model-api-endpoints) below. |
 | Provisioned throughput | Dedicated GPU capacity | Guaranteed throughput, higher cost |
 | Custom model | Your own MLflow models or containers | Deploy any model with an MLflow signature |

@@ -74,7 +74,7 @@ databricks serving-endpoints create \
 }' --profile
 ```

-- Discover available Foundation Models: check the `system.ai` catalog in Unity Catalog, or use `databricks serving-endpoints list --profile ` to see available endpoints. Use `databricks serving-endpoints get-open-api --profile ` to inspect the endpoint's API schema.
+- Discover available Foundation Models: see [Foundation Model API endpoints](#foundation-model-api-endpoints) below for the runtime-list snippet and default-picking rules. You can also check the `system.ai` catalog in Unity Catalog, or run `databricks serving-endpoints list --profile ` to see what's deployed in the workspace. Use `databricks serving-endpoints get-open-api --profile ` to inspect a specific endpoint's API schema.
 - Long-running operation; the CLI waits for completion by default.
Use `--no-wait` to return immediately, then poll:
 ```bash
 databricks serving-endpoints get --profile
@@ -82,29 +82,65 @@ databricks serving-endpoints create \
 ```
 - For provisioned throughput or custom model endpoints, run `databricks serving-endpoints create -h` to discover the required JSON fields for your endpoint type.

+### MLflow Deployments client (Python alternative)
+
+`mlflow.deployments.get_deploy_client("databricks").create_endpoint(name=..., config={...})` takes the same JSON shape as the CLI. Two gotchas:
+
+- **`tags=` is a top-level kwarg**, NOT a field inside `config`. Same `[{key, value}]` shape as `serving-endpoints patch --add-tags`.
+- **`traffic_config.routes[].served_model_name` = `"-"`** (e.g. `"turbine_failure-3"`). The API auto-derives this from the entity, but you reference the exact string in `traffic_config`; get the format wrong and the route silently doesn't match.
+
+### Zero-downtime version swap
+
+To roll an endpoint to a new model version: repoint the alias **and** call `update_endpoint` with the new `served_entities` + matching `traffic_config`. Missing either half is the common bug: alias-only doesn't update the endpoint; `update_endpoint`-only leaves the alias pointing at the old version.
+
+```python
+from mlflow import MlflowClient
+from mlflow.deployments import get_deploy_client
+
+# Two different clients: the UC registry owns aliases, the deployments client owns endpoints.
+# FULL_NAME, NAME, ENDPOINT_NAME, and new_version are supplied by the caller.
+registry = MlflowClient(registry_uri="databricks-uc")
+deploy_client = get_deploy_client("databricks")
+
+registry.set_registered_model_alias(FULL_NAME, "prod", new_version)
+deploy_client.update_endpoint(endpoint=ENDPOINT_NAME, config={
+    "served_entities": [{"entity_name": FULL_NAME, "entity_version": new_version,
+                         "workload_size": "Small", "scale_to_zero_enabled": True}],
+    "traffic_config": {"routes": [
+        {"served_model_name": f"{NAME}-{new_version}", "traffic_percentage": 100}
+    ]},
+})
+```
+
+The CLI equivalent is `databricks serving-endpoints update-config --json '...'`. Either way, poll both `state.ready` and `state.config_update` afterward (see Endpoint Readiness below).
+
 ### Endpoint Readiness

-After `create` or `update-config`, the endpoint provisions compute and loads the model.
**Do not query the endpoint until it is ready.**
+After `create` or `update-config`, the endpoint provisions compute and loads the model. **Do not query the endpoint until it is ready.** Two state fields matter and they mean different things:
+
+- `state.ready`: `READY` once the endpoint has any working config. Stays `READY` during a version swap.
+- `state.config_update`: `NOT_UPDATING` once the *current* config update finishes; `IN_PROGRESS` during a version swap.

-Poll for readiness:
+A loop watching only `state.ready` will report "ready" in the middle of a version swap, while the old version is still serving. **Poll both:**

 ```bash
-databricks serving-endpoints get --profile -o json
-# Ready when: state.ready == "READY" AND state.config_update == "NOT_UPDATING"
+databricks serving-endpoints get --profile -o json \
+  | jq '{ready: .state.ready, config_update: .state.config_update}'
+# Fully ready when ready == "READY" AND config_update == "NOT_UPDATING"
 ```

-Provisioning may take several minutes. Provisioned throughput endpoints take the longest (GPU allocation). Queries to endpoints that are not yet `READY` return 404 or 503 errors.
+Provisioning may take several minutes. Provisioned throughput endpoints take the longest (GPU allocation). Queries to endpoints that are not yet `READY` return 404 or 503.

 ## Query an Endpoint

+Chat / agent endpoints use the `messages` array:
+
 ```bash
 databricks serving-endpoints query \
-  --json '{"messages": [{"role": "user", "content": "Hello, how are you?"}]}' \
-  --profile
+  --json '{"messages": [{"role": "user", "content": "Hello"}]}' --profile
 ```

-- Use `--stream` for streaming responses.
-- For non-chat endpoints (embeddings, custom models): use `get-open-api ` first to discover the request/response schema, then construct the appropriate JSON payload.
+Classical-ML endpoints use `dataframe_records` (one record per row):
+
+```bash
+databricks serving-endpoints query \
+  --json '{"dataframe_records": [{"vibration": 0.42, "rpm": 18.3, "temp_c": 71.2}]}'
+```
+
+- Use `--stream` for streaming responses on chat endpoints.
+- For embeddings or other custom schemas: use `get-open-api ` first to discover the request/response shape.

 ## Get Endpoint Schema (OpenAPI)

@@ -177,6 +213,30 @@ env:

 Then add a tRPC route to call it from your app. For the full app integration pattern, use the **`databricks-apps`** skill and read the [Model Serving Guide](../databricks-apps/references/appkit/model-serving.md).

+### Develop & deploy new models
+
+This skill is ops-focused (manage existing endpoints). For the dev-side flow (training, MLflow tracking, UC registration, custom PyFunc authoring, and hand-rolled `ResponsesAgent` code), see **[databricks-ml-training](../../experimental/databricks-ml-training/SKILL.md)** (experimental).
+
+## Foundation Model API endpoints
+
+Pay-per-token, pre-provisioned in every workspace. New models land regularly and a static skill list goes stale fast, so **always list at runtime instead of hard-coding names**. Filter by the `databricks-` name prefix AND by the served entity being in `system.ai.*` (other endpoints like `databricks-app-template-serving` share the prefix but aren't FM API endpoints).
+
+```bash
+# FM API endpoints in this workspace, sorted by task (chat / embeddings / etc.)
+# -o json ensures jq receives JSON rather than the CLI's text rendering.
+databricks serving-endpoints list -o json \
+  | jq -r '.[]
+      | select(.name | startswith("databricks-"))
+      | select((.config.served_entities[0].entity_name // "") | startswith("system.ai."))
+      | "\(.task)\t\(.name)"' \
+  | sort
+```
+
+**Defaults when the user doesn't specify**: pick the highest-numbered Claude Sonnet for agents, the highest-numbered `-codex-max` for code, and `databricks-gte-large-en` for embeddings. Resolve the actual names from the live list above.
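
The same two-condition filter can be applied in Python when `jq` is not available. This is a minimal sketch, assuming the endpoint list has already been parsed from JSON into a list of dicts carrying the fields the `jq` filter reads (`name`, `task`, `config.served_entities[0].entity_name`). The `fm_api_endpoints` helper and the sample records, including their entity names, are hypothetical.

```python
from typing import Any


def fm_api_endpoints(endpoints: list[dict[str, Any]]) -> list[tuple[str, str]]:
    """Return sorted (task, name) pairs for Foundation Model API endpoints."""
    selected = []
    for ep in endpoints:
        name = ep.get("name") or ""
        entities = (ep.get("config") or {}).get("served_entities") or []
        entity_name = (entities[0].get("entity_name") or "") if entities else ""
        # Both conditions must hold: the name prefix AND a system.ai.* served entity.
        if name.startswith("databricks-") and entity_name.startswith("system.ai."):
            selected.append((ep.get("task") or "", name))
    return sorted(selected)


# Hypothetical sample records shaped like the parsed list output.
sample = [
    {"name": "databricks-gte-large-en", "task": "llm/v1/embeddings",
     "config": {"served_entities": [{"entity_name": "system.ai.gte_large_en"}]}},
    {"name": "databricks-app-template-serving", "task": "llm/v1/chat",
     "config": {"served_entities": [{"entity_name": "main.default.template"}]}},
    {"name": "turbine-agent-endpoint", "task": "agent/v1/responses",
     "config": {"served_entities": [{"entity_name": "main.ml.turbine_agent"}]}},
    {"name": "databricks-no-config", "task": "llm/v1/chat"},
]

for task, name in fm_api_endpoints(sample):
    print(f"{task}\t{name}")
```

Only `databricks-gte-large-en` survives here: the template endpoint shares the prefix but serves a non-`system.ai` entity, and the record without `config` is treated like the `// ""` fallback in the `jq` version.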
+
+## Off-platform streaming
+
+For apps deployed **outside** Databricks Apps (Vercel, AWS, standalone Node.js) hitting Databricks AI Gateway with Vercel AI SDK v6, see [references/off-platform-streaming.md](references/off-platform-streaming.md). For AppKit-based apps, use the `databricks-apps` skill's built-in serving plugin instead.
+
 ## Troubleshooting

 | Error | Solution |
@@ -187,3 +247,4 @@ Then add a tRPC route to call it from your app. For the full app integration pat
 | `RESOURCE_DOES_NOT_EXIST` | Verify endpoint name with `list` |
 | Query returns 404 | Endpoint may still be provisioning; check `state.ready` via `get` |
 | `RATE_LIMIT_EXCEEDED` (429) | AI Gateway rate limit; check `put-ai-gateway` config or retry after backoff |
+| Endpoint missing from the Serving UI after deploy | UI filter defaults to "Owned by me". Deploy jobs run as a service principal, so the endpoint is hidden until you switch to "All". `databricks serving-endpoints list` always shows it. |
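
The two-field readiness rule from Endpoint Readiness can be captured in a small polling helper. This is a minimal sketch, assuming the endpoint state is available as a dict with `ready` and `config_update` keys (the same two fields that section polls). `is_fully_ready`, `wait_until_ready`, and the injected `fetch_state` callable are hypothetical names, not a Databricks API.

```python
import time
from typing import Callable


def is_fully_ready(state: dict) -> bool:
    """True only when the endpoint is READY and no config update is in flight."""
    return state.get("ready") == "READY" and state.get("config_update") == "NOT_UPDATING"


def wait_until_ready(
    fetch_state: Callable[[], dict],
    timeout_s: float = 1800.0,
    interval_s: float = 15.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll fetch_state() until both fields settle; return False on timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        if is_fully_ready(fetch_state()):
            return True
        if time.monotonic() >= deadline:
            return False
        sleep(interval_s)


# Simulated version swap: `ready` stays READY while the update is in progress.
states = iter([
    {"ready": "READY", "config_update": "IN_PROGRESS"},
    {"ready": "READY", "config_update": "IN_PROGRESS"},
    {"ready": "READY", "config_update": "NOT_UPDATING"},
])
print(wait_until_ready(lambda: next(states), sleep=lambda _: None))  # True
```

In practice `fetch_state` would wrap whatever returns the endpoint's `state` object, for example the parsed JSON from `databricks serving-endpoints get`. Injecting it keeps the polling logic testable without a workspace.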