diff --git a/.gitconfig b/python/.gitconfig similarity index 100% rename from .gitconfig rename to python/.gitconfig diff --git a/python/codegen/codegen/generated_enum.py b/python/codegen/codegen/generated_enum.py index f2eecf82e4..c4f6e8ad00 100644 --- a/python/codegen/codegen/generated_enum.py +++ b/python/codegen/codegen/generated_enum.py @@ -80,6 +80,7 @@ def get_code(generated: GeneratedEnum) -> str: def _camel_to_upper_snake(value): - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", value) + s1 = value.replace("-", "_") # TODO: regex sanitzer + s2 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s1) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).upper() + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s2).upper() diff --git a/python/codegen/codegen/main.py b/python/codegen/codegen/main.py index df26810338..57b5eba184 100644 --- a/python/codegen/codegen/main.py +++ b/python/codegen/codegen/main.py @@ -123,6 +123,43 @@ def _generate_code( return dataclasses, enums +def _get_cross_namespace_imports( + namespace: str, + dataclasses: dict[str, GeneratedDataclass], + enums: dict[str, GeneratedEnum], +) -> tuple[list[str], dict[str, list[str]]]: + """ + Collect cross-namespace imports (types from other packages referenced by this namespace). + Returns (export_list, import_dict) where import_dict maps package -> list of class names. + """ + cross_namespace_packages = {} + cross_namespace_exports = [] + + root_package = packages.get_root_package(namespace) + + # Collect all referenced packages from dataclasses + for dataclass in dataclasses.values(): + for field in dataclass.fields: + if field.type_name.package and field.type_name.package != root_package: + if field.type_name.package not in cross_namespace_packages: + cross_namespace_packages[field.type_name.package] = [] + class_name = field.type_name.name.strip('"') + if class_name not in cross_namespace_packages[field.type_name.package]: + cross_namespace_packages[field.type_name.package].append(class_name) + cross_namespace_exports.append(class_name) + + # Collect all referenced packages from enums + for enum in enums.values(): + if enum.package and enum.package != root_package: + if enum.package not in cross_namespace_packages: + cross_namespace_packages[enum.package] = [] + if enum.class_name not in cross_namespace_packages[enum.package]: + cross_namespace_packages[enum.package].append(enum.class_name) + cross_namespace_exports.append(enum.class_name) + + return cross_namespace_exports, cross_namespace_packages + + def _write_exports( namespace: str, dataclasses: dict[str, GeneratedDataclass], @@ -141,6 +178,13 @@ def _write_exports( for _, enum in enums.items(): exports += [enum.class_name, f"{enum.class_name}Param"] + # Add cross-namespace imports to exports + cross_ns_exports, cross_ns_packages = _get_cross_namespace_imports( + namespace, dataclasses, enums + ) + exports.extend(cross_ns_exports) + + exports = list(set(exports)) # Remove duplicates exports.sort() b = CodeBuilder() @@ -155,6 +199,14 @@ def _write_exports( generated_imports.append_dataclass_imports(b, dataclasses, exclude_packages=[]) generated_imports.append_enum_imports(b, enums, exclude_packages=[]) + # Add cross-namespace re-exports + for package, class_names in sorted(cross_ns_packages.items()): + for class_name in sorted(set(class_names)): + b.append(f"from {package} import {class_name}").newline() + + if cross_ns_packages: + b.newline() + # FIXME should be better generalized if namespace == "jobs": _append_resolve_recursive_imports(b) diff --git a/python/codegen/codegen/packages.py b/python/codegen/codegen/packages.py index 48fe8270ab..14a911a8c2 100644 --- a/python/codegen/codegen/packages.py +++ b/python/codegen/codegen/packages.py @@ -1,3 +1,4 @@ +import os import re from typing import Optional @@ -7,6 +8,7 @@ "resources.Pipeline": "pipelines", "resources.Schema": "schemas", "resources.Volume": "volumes", + "resources.ModelServingEndpoint": "model_serving_endpoints", # serving } RESOURCE_TYPES = list(RESOURCE_NAMESPACE.keys()) @@ -20,6 +22,8 @@ "pipelines", "resources", "catalog", + "model_serving_endpoints", + "serving", # this exists within model_serving_endpoints and for some reason needs to be loaded separately ] RENAMES = { diff --git a/python/databricks/bundles/core/__init__.py b/python/databricks/bundles/core/__init__.py index 5c525861ac..55510ab1ec 100644 --- a/python/databricks/bundles/core/__init__.py +++ b/python/databricks/bundles/core/__init__.py @@ -17,6 +17,7 @@ "load_resources_from_module", "load_resources_from_modules", "load_resources_from_package_module", + "model_serving_endpoint_mutator", "pipeline_mutator", "schema_mutator", "variables", @@ -40,6 +41,7 @@ from databricks.bundles.core._resource_mutator import ( ResourceMutator, job_mutator, + model_serving_endpoint_mutator, pipeline_mutator, schema_mutator, volume_mutator, diff --git a/python/databricks/bundles/core/_resource_mutator.py b/python/databricks/bundles/core/_resource_mutator.py index 90e8987216..a0a4d5963b 100644 --- a/python/databricks/bundles/core/_resource_mutator.py +++ b/python/databricks/bundles/core/_resource_mutator.py @@ -7,6 +7,9 @@ if TYPE_CHECKING: from databricks.bundles.jobs._models.job import Job + from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint import ( + ModelServingEndpoint, + ) from databricks.bundles.pipelines._models.pipeline import Pipeline from databricks.bundles.schemas._models.schema import Schema from databricks.bundles.volumes._models.volume import Volume @@ -193,3 +196,39 @@ def my_volume_mutator(bundle: Bundle, volume: Volume) -> Volume: from databricks.bundles.volumes._models.volume import Volume return ResourceMutator(resource_type=Volume, function=function) + + +@overload +def model_serving_endpoint_mutator( + function: Callable[[Bundle, "ModelServingEndpoint"], "ModelServingEndpoint"], +) -> ResourceMutator["ModelServingEndpoint"]: ... + + +@overload +def model_serving_endpoint_mutator( + function: Callable[["ModelServingEndpoint"], "ModelServingEndpoint"], +) -> ResourceMutator["ModelServingEndpoint"]: ... + + +def model_serving_endpoint_mutator( + function: Callable, +) -> ResourceMutator["ModelServingEndpoint"]: + """ + Decorator for defining a model serving endpoint mutator. Function should return a new instance of the model serving endpoint with the desired changes, + instead of mutating the input model serving endpoint. + + Example: + + .. code-block:: python + + @model_serving_endpoint_mutator + def my_model_serving_endpoint_mutator(bundle: Bundle, model_serving_endpoint: ModelServingEndpoint) -> ModelServingEndpoint: + return replace(model_serving_endpoint, name="my_model_serving_endpoint") + + :param function: Function that mutates a model serving endpoint. + """ + from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint import ( + ModelServingEndpoint, + ) + + return ResourceMutator(resource_type=ModelServingEndpoint, function=function) diff --git a/python/databricks/bundles/core/_resource_type.py b/python/databricks/bundles/core/_resource_type.py index 9e9bb1bdf8..cce0e84969 100644 --- a/python/databricks/bundles/core/_resource_type.py +++ b/python/databricks/bundles/core/_resource_type.py @@ -32,6 +32,9 @@ def all(cls) -> tuple["_ResourceType", ...]: # be imported in databricks.bundles. from databricks.bundles.jobs._models.job import Job + from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint import ( + ModelServingEndpoint, + ) from databricks.bundles.pipelines._models.pipeline import Pipeline from databricks.bundles.schemas._models.schema import Schema from databricks.bundles.volumes._models.volume import Volume @@ -57,4 +60,9 @@ def all(cls) -> tuple["_ResourceType", ...]: plural_name="schemas", singular_name="schema", ), + _ResourceType( + resource_type=ModelServingEndpoint, + plural_name="model_serving_endpoints", + singular_name="model_serving_endpoint", + ), ) diff --git a/python/databricks/bundles/core/_resources.py b/python/databricks/bundles/core/_resources.py index 9be121718e..cbca43d5fa 100644 --- a/python/databricks/bundles/core/_resources.py +++ b/python/databricks/bundles/core/_resources.py @@ -9,6 +9,10 @@ from databricks.bundles.jobs._models.job import Job, JobParam from databricks.bundles.pipelines._models.pipeline import Pipeline, PipelineParam from databricks.bundles.schemas._models.schema import Schema, SchemaParam + from databricks.bundles.serving._models.model_serving_endpoint import ( + ModelServingEndpoint, + ModelServingEndpointParam, + ) from databricks.bundles.volumes._models.volume import Volume, VolumeParam __all__ = ["Resources"] @@ -60,6 +64,7 @@ def __init__(self): self._pipelines = dict[str, "Pipeline"]() self._schemas = dict[str, "Schema"]() self._volumes = dict[str, "Volume"]() + self._model_serving_endpoints = dict[str, "ModelServingEndpoint"]() self._locations = dict[tuple[str, ...], Location]() self._diagnostics = Diagnostics() @@ -79,6 +84,10 @@ def schemas(self) -> dict[str, "Schema"]: def volumes(self) -> dict[str, "Volume"]: return self._volumes + @property + def model_serving_endpoints(self) -> dict[str, "ModelServingEndpoint"]: + return self._model_serving_endpoints + @property def diagnostics(self) -> Diagnostics: """ @@ -103,6 +112,7 @@ def add_resource( """ from databricks.bundles.jobs import Job + from databricks.bundles.model_serving_endpoints import ModelServingEndpoint from databricks.bundles.pipelines import Pipeline from databricks.bundles.schemas import Schema from databricks.bundles.volumes import Volume @@ -118,6 +128,10 @@ def add_resource( self.add_schema(resource_name, resource, location=location) case Volume(): self.add_volume(resource_name, resource, location=location) + case ModelServingEndpoint(): + self.add_model_serving_endpoint( + resource_name, resource, location=location + ) case _: raise ValueError(f"Unsupported resource type: {type(resource)}") @@ -249,6 +263,40 @@ def add_volume( self._volumes[resource_name] = volume + def add_model_serving_endpoint( + self, + resource_name: str, + model_serving_endpoint: "ModelServingEndpointParam", + *, + location: Optional[Location] = None, + ) -> None: + """ + Adds a model serving endpoint to the collection of resources. Resource name must be unique across all model serving endpoints. + + :param resource_name: unique identifier for the model serving endpoint + :param model_serving_endpoint: the model serving endpoint to add, can be ModelServingEndpoint or dict + :param location: optional location of the model serving endpoint in the source code + """ + from databricks.bundles.model_serving_endpoints import ModelServingEndpoint + + model_serving_endpoint = _transform( + ModelServingEndpoint, model_serving_endpoint + ) + path = ("resources", "model_serving_endpoints", resource_name) + location = location or Location.from_stack_frame(depth=1) + + if self._model_serving_endpoints.get(resource_name): + self.add_diagnostic_error( + msg=f"Duplicate resource name '{resource_name}' for a model serving endpoint. Resource names must be unique.", + location=location, + path=path, + ) + else: + if location: + self.add_location(path, location) + + self._model_serving_endpoints[resource_name] = model_serving_endpoint + def add_location(self, path: tuple[str, ...], location: Location) -> None: """ Associate source code location with a path in the bundle configuration. @@ -331,6 +379,9 @@ def add_resources(self, other: "Resources") -> None: for name, volume in other.volumes.items(): self.add_volume(name, volume) + for name, model_serving_endpoint in other.model_serving_endpoints.items(): + self.add_model_serving_endpoint(name, model_serving_endpoint) + for path, location in other._locations.items(): self.add_location(path, location) diff --git a/python/databricks/bundles/jobs/__init__.py b/python/databricks/bundles/jobs/__init__.py index a1f36e2399..17f2f2279c 100644 --- a/python/databricks/bundles/jobs/__init__.py +++ b/python/databricks/bundles/jobs/__init__.py @@ -259,6 +259,10 @@ "TriggerSettings", "TriggerSettingsDict", "TriggerSettingsParam", + "VariableOr", + "VariableOrDict", + "VariableOrList", + "VariableOrOptional", "VolumesStorageInfo", "VolumesStorageInfoDict", "VolumesStorageInfoParam", @@ -277,6 +281,12 @@ ] +from databricks.bundles.core import ( + VariableOr, + VariableOrDict, + VariableOrList, + VariableOrOptional, +) from databricks.bundles.jobs._models.adlsgen2_info import ( Adlsgen2Info, Adlsgen2InfoDict, diff --git a/python/databricks/bundles/model_serving_endpoints/__init__.py b/python/databricks/bundles/model_serving_endpoints/__init__.py new file mode 100644 index 0000000000..3ebc51e1b8 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/__init__.py @@ -0,0 +1,330 @@ +__all__ = [ + "Ai21LabsConfig", + "Ai21LabsConfigDict", + "Ai21LabsConfigParam", + "AiGatewayConfig", + "AiGatewayConfigDict", + "AiGatewayConfigParam", + "AiGatewayGuardrailParameters", + "AiGatewayGuardrailParametersDict", + "AiGatewayGuardrailParametersParam", + "AiGatewayGuardrailPiiBehavior", + "AiGatewayGuardrailPiiBehaviorBehavior", + "AiGatewayGuardrailPiiBehaviorBehaviorParam", + "AiGatewayGuardrailPiiBehaviorDict", + "AiGatewayGuardrailPiiBehaviorParam", + "AiGatewayGuardrails", + "AiGatewayGuardrailsDict", + "AiGatewayGuardrailsParam", + "AiGatewayInferenceTableConfig", + "AiGatewayInferenceTableConfigDict", + "AiGatewayInferenceTableConfigParam", + "AiGatewayRateLimit", + "AiGatewayRateLimitDict", + "AiGatewayRateLimitKey", + "AiGatewayRateLimitKeyParam", + "AiGatewayRateLimitParam", + "AiGatewayRateLimitRenewalPeriod", + "AiGatewayRateLimitRenewalPeriodParam", + "AiGatewayUsageTrackingConfig", + "AiGatewayUsageTrackingConfigDict", + "AiGatewayUsageTrackingConfigParam", + "AmazonBedrockConfig", + "AmazonBedrockConfigBedrockProvider", + "AmazonBedrockConfigBedrockProviderParam", + "AmazonBedrockConfigDict", + "AmazonBedrockConfigParam", + "AnthropicConfig", + "AnthropicConfigDict", + "AnthropicConfigParam", + "ApiKeyAuth", + "ApiKeyAuthDict", + "ApiKeyAuthParam", + "AutoCaptureConfigInput", + "AutoCaptureConfigInputDict", + "AutoCaptureConfigInputParam", + "BearerTokenAuth", + "BearerTokenAuthDict", + "BearerTokenAuthParam", + "CohereConfig", + "CohereConfigDict", + "CohereConfigParam", + "CustomProviderConfig", + "CustomProviderConfigDict", + "CustomProviderConfigParam", + "DatabricksModelServingConfig", + "DatabricksModelServingConfigDict", + "DatabricksModelServingConfigParam", + "EmailNotifications", + "EmailNotificationsDict", + "EmailNotificationsParam", + "EndpointCoreConfigInput", + "EndpointCoreConfigInputDict", + "EndpointCoreConfigInputParam", + "EndpointTag", + "EndpointTagDict", + "EndpointTagParam", + "ExternalModel", + "ExternalModelDict", + "ExternalModelParam", + "ExternalModelProvider", + "ExternalModelProviderParam", + "FallbackConfig", + "FallbackConfigDict", + "FallbackConfigParam", + "GoogleCloudVertexAiConfig", + "GoogleCloudVertexAiConfigDict", + "GoogleCloudVertexAiConfigParam", + "Lifecycle", + "LifecycleDict", + "LifecycleParam", + "ModelServingEndpoint", + "ModelServingEndpointDict", + "ModelServingEndpointParam", + "ModelServingEndpointPermission", + "ModelServingEndpointPermissionDict", + "ModelServingEndpointPermissionLevel", + "ModelServingEndpointPermissionLevelParam", + "ModelServingEndpointPermissionParam", + "OpenAiConfig", + "OpenAiConfigDict", + "OpenAiConfigParam", + "PaLmConfig", + "PaLmConfigDict", + "PaLmConfigParam", + "RateLimit", + "RateLimitDict", + "RateLimitKey", + "RateLimitKeyParam", + "RateLimitParam", + "RateLimitRenewalPeriod", + "RateLimitRenewalPeriodParam", + "Route", + "RouteDict", + "RouteParam", + "ServedEntityInput", + "ServedEntityInputDict", + "ServedEntityInputParam", + "ServedModelInput", + "ServedModelInputDict", + "ServedModelInputParam", + "ServedModelInputWorkloadType", + "ServedModelInputWorkloadTypeParam", + "ServingModelWorkloadType", + "ServingModelWorkloadTypeParam", + "TrafficConfig", + "TrafficConfigDict", + "TrafficConfigParam", + "VariableOr", + "VariableOrDict", + "VariableOrList", + "VariableOrOptional", +] + + +from databricks.bundles.core import ( + VariableOr, + VariableOrDict, + VariableOrList, + VariableOrOptional, +) +from databricks.bundles.model_serving_endpoints._models.ai21_labs_config import ( + Ai21LabsConfig, + Ai21LabsConfigDict, + Ai21LabsConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_config import ( + AiGatewayConfig, + AiGatewayConfigDict, + AiGatewayConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_guardrail_parameters import ( + AiGatewayGuardrailParameters, + AiGatewayGuardrailParametersDict, + AiGatewayGuardrailParametersParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_guardrail_pii_behavior import ( + AiGatewayGuardrailPiiBehavior, + AiGatewayGuardrailPiiBehaviorDict, + AiGatewayGuardrailPiiBehaviorParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_guardrail_pii_behavior_behavior import ( + AiGatewayGuardrailPiiBehaviorBehavior, + AiGatewayGuardrailPiiBehaviorBehaviorParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_guardrails import ( + AiGatewayGuardrails, + AiGatewayGuardrailsDict, + AiGatewayGuardrailsParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_inference_table_config import ( + AiGatewayInferenceTableConfig, + AiGatewayInferenceTableConfigDict, + AiGatewayInferenceTableConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_rate_limit import ( + AiGatewayRateLimit, + AiGatewayRateLimitDict, + AiGatewayRateLimitParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_rate_limit_key import ( + AiGatewayRateLimitKey, + AiGatewayRateLimitKeyParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_rate_limit_renewal_period import ( + AiGatewayRateLimitRenewalPeriod, + AiGatewayRateLimitRenewalPeriodParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_usage_tracking_config import ( + AiGatewayUsageTrackingConfig, + AiGatewayUsageTrackingConfigDict, + AiGatewayUsageTrackingConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.amazon_bedrock_config import ( + AmazonBedrockConfig, + AmazonBedrockConfigDict, + AmazonBedrockConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.amazon_bedrock_config_bedrock_provider import ( + AmazonBedrockConfigBedrockProvider, + AmazonBedrockConfigBedrockProviderParam, +) +from databricks.bundles.model_serving_endpoints._models.anthropic_config import ( + AnthropicConfig, + AnthropicConfigDict, + AnthropicConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.api_key_auth import ( + ApiKeyAuth, + ApiKeyAuthDict, + ApiKeyAuthParam, +) +from databricks.bundles.model_serving_endpoints._models.auto_capture_config_input import ( + AutoCaptureConfigInput, + AutoCaptureConfigInputDict, + AutoCaptureConfigInputParam, +) +from databricks.bundles.model_serving_endpoints._models.bearer_token_auth import ( + BearerTokenAuth, + BearerTokenAuthDict, + BearerTokenAuthParam, +) +from databricks.bundles.model_serving_endpoints._models.cohere_config import ( + CohereConfig, + CohereConfigDict, + CohereConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.custom_provider_config import ( + CustomProviderConfig, + CustomProviderConfigDict, + CustomProviderConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.databricks_model_serving_config import ( + DatabricksModelServingConfig, + DatabricksModelServingConfigDict, + DatabricksModelServingConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.email_notifications import ( + EmailNotifications, + EmailNotificationsDict, + EmailNotificationsParam, +) +from databricks.bundles.model_serving_endpoints._models.endpoint_core_config_input import ( + EndpointCoreConfigInput, + EndpointCoreConfigInputDict, + EndpointCoreConfigInputParam, +) +from databricks.bundles.model_serving_endpoints._models.endpoint_tag import ( + EndpointTag, + EndpointTagDict, + EndpointTagParam, +) +from databricks.bundles.model_serving_endpoints._models.external_model import ( + ExternalModel, + ExternalModelDict, + ExternalModelParam, +) +from databricks.bundles.model_serving_endpoints._models.external_model_provider import ( + ExternalModelProvider, + ExternalModelProviderParam, +) +from databricks.bundles.model_serving_endpoints._models.fallback_config import ( + FallbackConfig, + FallbackConfigDict, + FallbackConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.google_cloud_vertex_ai_config import ( + GoogleCloudVertexAiConfig, + GoogleCloudVertexAiConfigDict, + GoogleCloudVertexAiConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.lifecycle import ( + Lifecycle, + LifecycleDict, + LifecycleParam, +) +from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint import ( + ModelServingEndpoint, + ModelServingEndpointDict, + ModelServingEndpointParam, +) +from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint_permission import ( + ModelServingEndpointPermission, + ModelServingEndpointPermissionDict, + ModelServingEndpointPermissionParam, +) +from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint_permission_level import ( + ModelServingEndpointPermissionLevel, + ModelServingEndpointPermissionLevelParam, +) +from databricks.bundles.model_serving_endpoints._models.open_ai_config import ( + OpenAiConfig, + OpenAiConfigDict, + OpenAiConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.pa_lm_config import ( + PaLmConfig, + PaLmConfigDict, + PaLmConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.rate_limit import ( + RateLimit, + RateLimitDict, + RateLimitParam, +) +from databricks.bundles.model_serving_endpoints._models.rate_limit_key import ( + RateLimitKey, + RateLimitKeyParam, +) +from databricks.bundles.model_serving_endpoints._models.rate_limit_renewal_period import ( + RateLimitRenewalPeriod, + RateLimitRenewalPeriodParam, +) +from databricks.bundles.model_serving_endpoints._models.route import ( + Route, + RouteDict, + RouteParam, +) +from databricks.bundles.model_serving_endpoints._models.served_entity_input import ( + ServedEntityInput, + ServedEntityInputDict, + ServedEntityInputParam, +) +from databricks.bundles.model_serving_endpoints._models.served_model_input import ( + ServedModelInput, + ServedModelInputDict, + ServedModelInputParam, +) +from databricks.bundles.model_serving_endpoints._models.served_model_input_workload_type import ( + ServedModelInputWorkloadType, + ServedModelInputWorkloadTypeParam, +) +from databricks.bundles.model_serving_endpoints._models.serving_model_workload_type import ( + ServingModelWorkloadType, + ServingModelWorkloadTypeParam, +) +from databricks.bundles.model_serving_endpoints._models.traffic_config import ( + TrafficConfig, + TrafficConfigDict, + TrafficConfigParam, +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai21_labs_config.py b/python/databricks/bundles/model_serving_endpoints/_models/ai21_labs_config.py new file mode 100644 index 0000000000..3740c8af26 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai21_labs_config.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class Ai21LabsConfig: + """""" + + ai21labs_api_key: VariableOrOptional[str] = None + """ + The Databricks secret key reference for an AI21 Labs API key. If you + prefer to paste your API key directly, see `ai21labs_api_key_plaintext`. + You must provide an API key using one of the following fields: + `ai21labs_api_key` or `ai21labs_api_key_plaintext`. + """ + + ai21labs_api_key_plaintext: VariableOrOptional[str] = None + """ + An AI21 Labs API key provided as a plaintext string. If you prefer to + reference your key using Databricks Secrets, see `ai21labs_api_key`. You + must provide an API key using one of the following fields: + `ai21labs_api_key` or `ai21labs_api_key_plaintext`. + """ + + @classmethod + def from_dict(cls, value: "Ai21LabsConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "Ai21LabsConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class Ai21LabsConfigDict(TypedDict, total=False): + """""" + + ai21labs_api_key: VariableOrOptional[str] + """ + The Databricks secret key reference for an AI21 Labs API key. If you + prefer to paste your API key directly, see `ai21labs_api_key_plaintext`. + You must provide an API key using one of the following fields: + `ai21labs_api_key` or `ai21labs_api_key_plaintext`. + """ + + ai21labs_api_key_plaintext: VariableOrOptional[str] + """ + An AI21 Labs API key provided as a plaintext string. If you prefer to + reference your key using Databricks Secrets, see `ai21labs_api_key`. You + must provide an API key using one of the following fields: + `ai21labs_api_key` or `ai21labs_api_key_plaintext`. + """ + + +Ai21LabsConfigParam = Ai21LabsConfigDict | Ai21LabsConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_config.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_config.py new file mode 100644 index 0000000000..ba04ebf748 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_config.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrList, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.ai_gateway_guardrails import ( + AiGatewayGuardrails, + AiGatewayGuardrailsParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_inference_table_config import ( + AiGatewayInferenceTableConfig, + AiGatewayInferenceTableConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_rate_limit import ( + AiGatewayRateLimit, + AiGatewayRateLimitParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_usage_tracking_config import ( + AiGatewayUsageTrackingConfig, + AiGatewayUsageTrackingConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.fallback_config import ( + FallbackConfig, + FallbackConfigParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AiGatewayConfig: + """""" + + fallback_config: VariableOrOptional[FallbackConfig] = None + """ + Configuration for traffic fallback which auto fallbacks to other served entities if the request to a served + entity fails with certain error codes, to increase availability. + """ + + guardrails: VariableOrOptional[AiGatewayGuardrails] = None + """ + Configuration for AI Guardrails to prevent unwanted data and unsafe data in requests and responses. + """ + + inference_table_config: VariableOrOptional[AiGatewayInferenceTableConfig] = None + """ + Configuration for payload logging using inference tables. + Use these tables to monitor and audit data being sent to and received from model APIs and to improve model quality. + """ + + rate_limits: VariableOrList[AiGatewayRateLimit] = field(default_factory=list) + """ + Configuration for rate limits which can be set to limit endpoint traffic. + """ + + usage_tracking_config: VariableOrOptional[AiGatewayUsageTrackingConfig] = None + """ + Configuration to enable usage tracking using system tables. + These tables allow you to monitor operational usage on endpoints and their associated costs. + """ + + @classmethod + def from_dict(cls, value: "AiGatewayConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AiGatewayConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class AiGatewayConfigDict(TypedDict, total=False): + """""" + + fallback_config: VariableOrOptional[FallbackConfigParam] + """ + Configuration for traffic fallback which auto fallbacks to other served entities if the request to a served + entity fails with certain error codes, to increase availability. + """ + + guardrails: VariableOrOptional[AiGatewayGuardrailsParam] + """ + Configuration for AI Guardrails to prevent unwanted data and unsafe data in requests and responses. + """ + + inference_table_config: VariableOrOptional[AiGatewayInferenceTableConfigParam] + """ + Configuration for payload logging using inference tables. + Use these tables to monitor and audit data being sent to and received from model APIs and to improve model quality. + """ + + rate_limits: VariableOrList[AiGatewayRateLimitParam] + """ + Configuration for rate limits which can be set to limit endpoint traffic. + """ + + usage_tracking_config: VariableOrOptional[AiGatewayUsageTrackingConfigParam] + """ + Configuration to enable usage tracking using system tables. + These tables allow you to monitor operational usage on endpoints and their associated costs. + """ + + +AiGatewayConfigParam = AiGatewayConfigDict | AiGatewayConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_parameters.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_parameters.py new file mode 100644 index 0000000000..a2f450f0f9 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_parameters.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrList, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.ai_gateway_guardrail_pii_behavior import ( + AiGatewayGuardrailPiiBehavior, + AiGatewayGuardrailPiiBehaviorParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AiGatewayGuardrailParameters: + """""" + + invalid_keywords: VariableOrList[str] = field(default_factory=list) + """ + [DEPRECATED] List of invalid keywords. + AI guardrail uses keyword or string matching to decide if the keyword exists in the request or response content. + """ + + pii: VariableOrOptional[AiGatewayGuardrailPiiBehavior] = None + """ + Configuration for guardrail PII filter. + """ + + safety: VariableOrOptional[bool] = None + """ + Indicates whether the safety filter is enabled. + """ + + valid_topics: VariableOrList[str] = field(default_factory=list) + """ + [DEPRECATED] The list of allowed topics. + Given a chat request, this guardrail flags the request if its topic is not in the allowed topics. + """ + + @classmethod + def from_dict(cls, value: "AiGatewayGuardrailParametersDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AiGatewayGuardrailParametersDict": + return _transform_to_json_value(self) # type:ignore + + +class AiGatewayGuardrailParametersDict(TypedDict, total=False): + """""" + + invalid_keywords: VariableOrList[str] + """ + [DEPRECATED] List of invalid keywords. + AI guardrail uses keyword or string matching to decide if the keyword exists in the request or response content. + """ + + pii: VariableOrOptional[AiGatewayGuardrailPiiBehaviorParam] + """ + Configuration for guardrail PII filter. + """ + + safety: VariableOrOptional[bool] + """ + Indicates whether the safety filter is enabled. + """ + + valid_topics: VariableOrList[str] + """ + [DEPRECATED] The list of allowed topics. + Given a chat request, this guardrail flags the request if its topic is not in the allowed topics. + """ + + +AiGatewayGuardrailParametersParam = ( + AiGatewayGuardrailParametersDict | AiGatewayGuardrailParameters +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_pii_behavior.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_pii_behavior.py new file mode 100644 index 0000000000..ede187e426 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_pii_behavior.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.ai_gateway_guardrail_pii_behavior_behavior import ( + AiGatewayGuardrailPiiBehaviorBehavior, + AiGatewayGuardrailPiiBehaviorBehaviorParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AiGatewayGuardrailPiiBehavior: + """""" + + behavior: VariableOrOptional[AiGatewayGuardrailPiiBehaviorBehavior] = None + """ + Configuration for input guardrail filters. + """ + + @classmethod + def from_dict(cls, value: "AiGatewayGuardrailPiiBehaviorDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AiGatewayGuardrailPiiBehaviorDict": + return _transform_to_json_value(self) # type:ignore + + +class AiGatewayGuardrailPiiBehaviorDict(TypedDict, total=False): + """""" + + behavior: VariableOrOptional[AiGatewayGuardrailPiiBehaviorBehaviorParam] + """ + Configuration for input guardrail filters. + """ + + +AiGatewayGuardrailPiiBehaviorParam = ( + AiGatewayGuardrailPiiBehaviorDict | AiGatewayGuardrailPiiBehavior +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_pii_behavior_behavior.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_pii_behavior_behavior.py new file mode 100644 index 0000000000..85ae2ea86b --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrail_pii_behavior_behavior.py @@ -0,0 +1,13 @@ +from enum import Enum +from typing import Literal + + +class AiGatewayGuardrailPiiBehaviorBehavior(Enum): + NONE = "NONE" + BLOCK = "BLOCK" + MASK = "MASK" + + +AiGatewayGuardrailPiiBehaviorBehaviorParam = ( + Literal["NONE", "BLOCK", "MASK"] | AiGatewayGuardrailPiiBehaviorBehavior +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrails.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrails.py new file mode 100644 index 0000000000..b11a088f61 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_guardrails.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.ai_gateway_guardrail_parameters import ( + AiGatewayGuardrailParameters, + AiGatewayGuardrailParametersParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AiGatewayGuardrails: + """""" + + input: VariableOrOptional[AiGatewayGuardrailParameters] = None + """ + Configuration for input guardrail filters. + """ + + output: VariableOrOptional[AiGatewayGuardrailParameters] = None + """ + Configuration for output guardrail filters. + """ + + @classmethod + def from_dict(cls, value: "AiGatewayGuardrailsDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AiGatewayGuardrailsDict": + return _transform_to_json_value(self) # type:ignore + + +class AiGatewayGuardrailsDict(TypedDict, total=False): + """""" + + input: VariableOrOptional[AiGatewayGuardrailParametersParam] + """ + Configuration for input guardrail filters. + """ + + output: VariableOrOptional[AiGatewayGuardrailParametersParam] + """ + Configuration for output guardrail filters. + """ + + +AiGatewayGuardrailsParam = AiGatewayGuardrailsDict | AiGatewayGuardrails diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_inference_table_config.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_inference_table_config.py new file mode 100644 index 0000000000..6e44aacdf6 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_inference_table_config.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AiGatewayInferenceTableConfig: + """""" + + catalog_name: VariableOrOptional[str] = None + """ + The name of the catalog in Unity Catalog. Required when enabling inference tables. + NOTE: On update, you have to disable inference table first in order to change the catalog name. + """ + + enabled: VariableOrOptional[bool] = None + """ + Indicates whether the inference table is enabled. + """ + + schema_name: VariableOrOptional[str] = None + """ + The name of the schema in Unity Catalog. Required when enabling inference tables. + NOTE: On update, you have to disable inference table first in order to change the schema name. + """ + + table_name_prefix: VariableOrOptional[str] = None + """ + The prefix of the table in Unity Catalog. + NOTE: On update, you have to disable inference table first in order to change the prefix name. + """ + + @classmethod + def from_dict(cls, value: "AiGatewayInferenceTableConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AiGatewayInferenceTableConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class AiGatewayInferenceTableConfigDict(TypedDict, total=False): + """""" + + catalog_name: VariableOrOptional[str] + """ + The name of the catalog in Unity Catalog. Required when enabling inference tables. + NOTE: On update, you have to disable inference table first in order to change the catalog name. + """ + + enabled: VariableOrOptional[bool] + """ + Indicates whether the inference table is enabled. + """ + + schema_name: VariableOrOptional[str] + """ + The name of the schema in Unity Catalog. Required when enabling inference tables. + NOTE: On update, you have to disable inference table first in order to change the schema name. + """ + + table_name_prefix: VariableOrOptional[str] + """ + The prefix of the table in Unity Catalog. + NOTE: On update, you have to disable inference table first in order to change the prefix name. + """ + + +AiGatewayInferenceTableConfigParam = ( + AiGatewayInferenceTableConfigDict | AiGatewayInferenceTableConfig +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit.py new file mode 100644 index 0000000000..46f109be09 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.ai_gateway_rate_limit_key import ( + AiGatewayRateLimitKey, + AiGatewayRateLimitKeyParam, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_rate_limit_renewal_period import ( + AiGatewayRateLimitRenewalPeriod, + AiGatewayRateLimitRenewalPeriodParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AiGatewayRateLimit: + """""" + + renewal_period: VariableOr[AiGatewayRateLimitRenewalPeriod] + """ + Renewal period field for a rate limit. Currently, only 'minute' is supported. + """ + + calls: VariableOrOptional[int] = None + """ + Used to specify how many calls are allowed for a key within the renewal_period. + """ + + key: VariableOrOptional[AiGatewayRateLimitKey] = None + """ + Key field for a rate limit. Currently, 'user', 'user_group, 'service_principal', and 'endpoint' are supported, + with 'endpoint' being the default if not specified. + """ + + principal: VariableOrOptional[str] = None + """ + Principal field for a user, user group, or service principal to apply rate limiting to. Accepts a user email, group name, or service principal application ID. + """ + + tokens: VariableOrOptional[int] = None + """ + Used to specify how many tokens are allowed for a key within the renewal_period. + """ + + @classmethod + def from_dict(cls, value: "AiGatewayRateLimitDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AiGatewayRateLimitDict": + return _transform_to_json_value(self) # type:ignore + + +class AiGatewayRateLimitDict(TypedDict, total=False): + """""" + + renewal_period: VariableOr[AiGatewayRateLimitRenewalPeriodParam] + """ + Renewal period field for a rate limit. Currently, only 'minute' is supported. + """ + + calls: VariableOrOptional[int] + """ + Used to specify how many calls are allowed for a key within the renewal_period. + """ + + key: VariableOrOptional[AiGatewayRateLimitKeyParam] + """ + Key field for a rate limit. Currently, 'user', 'user_group, 'service_principal', and 'endpoint' are supported, + with 'endpoint' being the default if not specified. + """ + + principal: VariableOrOptional[str] + """ + Principal field for a user, user group, or service principal to apply rate limiting to. Accepts a user email, group name, or service principal application ID. + """ + + tokens: VariableOrOptional[int] + """ + Used to specify how many tokens are allowed for a key within the renewal_period. + """ + + +AiGatewayRateLimitParam = AiGatewayRateLimitDict | AiGatewayRateLimit diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit_key.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit_key.py new file mode 100644 index 0000000000..6bd7f92388 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit_key.py @@ -0,0 +1,15 @@ +from enum import Enum +from typing import Literal + + +class AiGatewayRateLimitKey(Enum): + USER = "user" + ENDPOINT = "endpoint" + USER_GROUP = "user_group" + SERVICE_PRINCIPAL = "service_principal" + + +AiGatewayRateLimitKeyParam = ( + Literal["user", "endpoint", "user_group", "service_principal"] + | AiGatewayRateLimitKey +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit_renewal_period.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit_renewal_period.py new file mode 100644 index 0000000000..c528cd3c15 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_rate_limit_renewal_period.py @@ -0,0 +1,11 @@ +from enum import Enum +from typing import Literal + + +class AiGatewayRateLimitRenewalPeriod(Enum): + MINUTE = "minute" + + +AiGatewayRateLimitRenewalPeriodParam = ( + Literal["minute"] | AiGatewayRateLimitRenewalPeriod +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_usage_tracking_config.py b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_usage_tracking_config.py new file mode 100644 index 0000000000..a6a46b2565 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/ai_gateway_usage_tracking_config.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AiGatewayUsageTrackingConfig: + """""" + + enabled: VariableOrOptional[bool] = None + """ + Whether to enable usage tracking. + """ + + @classmethod + def from_dict(cls, value: "AiGatewayUsageTrackingConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AiGatewayUsageTrackingConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class AiGatewayUsageTrackingConfigDict(TypedDict, total=False): + """""" + + enabled: VariableOrOptional[bool] + """ + Whether to enable usage tracking. + """ + + +AiGatewayUsageTrackingConfigParam = ( + AiGatewayUsageTrackingConfigDict | AiGatewayUsageTrackingConfig +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/amazon_bedrock_config.py b/python/databricks/bundles/model_serving_endpoints/_models/amazon_bedrock_config.py new file mode 100644 index 0000000000..9286c3f3fd --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/amazon_bedrock_config.py @@ -0,0 +1,146 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.amazon_bedrock_config_bedrock_provider import ( + AmazonBedrockConfigBedrockProvider, + AmazonBedrockConfigBedrockProviderParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AmazonBedrockConfig: + """""" + + aws_region: VariableOr[str] + """ + The AWS region to use. Bedrock has to be enabled there. + """ + + bedrock_provider: VariableOr[AmazonBedrockConfigBedrockProvider] + """ + The underlying provider in Amazon Bedrock. Supported values (case + insensitive) include: Anthropic, Cohere, AI21Labs, Amazon. + """ + + aws_access_key_id: VariableOrOptional[str] = None + """ + The Databricks secret key reference for an AWS access key ID with + permissions to interact with Bedrock services. If you prefer to paste + your API key directly, see `aws_access_key_id_plaintext`. You must provide an API + key using one of the following fields: `aws_access_key_id` or + `aws_access_key_id_plaintext`. + """ + + aws_access_key_id_plaintext: VariableOrOptional[str] = None + """ + An AWS access key ID with permissions to interact with Bedrock services + provided as a plaintext string. If you prefer to reference your key using + Databricks Secrets, see `aws_access_key_id`. You must provide an API key + using one of the following fields: `aws_access_key_id` or + `aws_access_key_id_plaintext`. + """ + + aws_secret_access_key: VariableOrOptional[str] = None + """ + The Databricks secret key reference for an AWS secret access key paired + with the access key ID, with permissions to interact with Bedrock + services. If you prefer to paste your API key directly, see + `aws_secret_access_key_plaintext`. You must provide an API key using one + of the following fields: `aws_secret_access_key` or + `aws_secret_access_key_plaintext`. + """ + + aws_secret_access_key_plaintext: VariableOrOptional[str] = None + """ + An AWS secret access key paired with the access key ID, with permissions + to interact with Bedrock services provided as a plaintext string. If you + prefer to reference your key using Databricks Secrets, see + `aws_secret_access_key`. You must provide an API key using one of the + following fields: `aws_secret_access_key` or + `aws_secret_access_key_plaintext`. + """ + + instance_profile_arn: VariableOrOptional[str] = None + """ + ARN of the instance profile that the external model will use to access AWS resources. + You must authenticate using an instance profile or access keys. + If you prefer to authenticate using access keys, see `aws_access_key_id`, + `aws_access_key_id_plaintext`, `aws_secret_access_key` and `aws_secret_access_key_plaintext`. + """ + + @classmethod + def from_dict(cls, value: "AmazonBedrockConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AmazonBedrockConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class AmazonBedrockConfigDict(TypedDict, total=False): + """""" + + aws_region: VariableOr[str] + """ + The AWS region to use. Bedrock has to be enabled there. + """ + + bedrock_provider: VariableOr[AmazonBedrockConfigBedrockProviderParam] + """ + The underlying provider in Amazon Bedrock. Supported values (case + insensitive) include: Anthropic, Cohere, AI21Labs, Amazon. + """ + + aws_access_key_id: VariableOrOptional[str] + """ + The Databricks secret key reference for an AWS access key ID with + permissions to interact with Bedrock services. If you prefer to paste + your API key directly, see `aws_access_key_id_plaintext`. You must provide an API + key using one of the following fields: `aws_access_key_id` or + `aws_access_key_id_plaintext`. + """ + + aws_access_key_id_plaintext: VariableOrOptional[str] + """ + An AWS access key ID with permissions to interact with Bedrock services + provided as a plaintext string. If you prefer to reference your key using + Databricks Secrets, see `aws_access_key_id`. You must provide an API key + using one of the following fields: `aws_access_key_id` or + `aws_access_key_id_plaintext`. + """ + + aws_secret_access_key: VariableOrOptional[str] + """ + The Databricks secret key reference for an AWS secret access key paired + with the access key ID, with permissions to interact with Bedrock + services. If you prefer to paste your API key directly, see + `aws_secret_access_key_plaintext`. You must provide an API key using one + of the following fields: `aws_secret_access_key` or + `aws_secret_access_key_plaintext`. + """ + + aws_secret_access_key_plaintext: VariableOrOptional[str] + """ + An AWS secret access key paired with the access key ID, with permissions + to interact with Bedrock services provided as a plaintext string. If you + prefer to reference your key using Databricks Secrets, see + `aws_secret_access_key`. You must provide an API key using one of the + following fields: `aws_secret_access_key` or + `aws_secret_access_key_plaintext`. + """ + + instance_profile_arn: VariableOrOptional[str] + """ + ARN of the instance profile that the external model will use to access AWS resources. + You must authenticate using an instance profile or access keys. + If you prefer to authenticate using access keys, see `aws_access_key_id`, + `aws_access_key_id_plaintext`, `aws_secret_access_key` and `aws_secret_access_key_plaintext`. + """ + + +AmazonBedrockConfigParam = AmazonBedrockConfigDict | AmazonBedrockConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/amazon_bedrock_config_bedrock_provider.py b/python/databricks/bundles/model_serving_endpoints/_models/amazon_bedrock_config_bedrock_provider.py new file mode 100644 index 0000000000..465998a992 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/amazon_bedrock_config_bedrock_provider.py @@ -0,0 +1,15 @@ +from enum import Enum +from typing import Literal + + +class AmazonBedrockConfigBedrockProvider(Enum): + ANTHROPIC = "anthropic" + COHERE = "cohere" + AI21LABS = "ai21labs" + AMAZON = "amazon" + + +AmazonBedrockConfigBedrockProviderParam = ( + Literal["anthropic", "cohere", "ai21labs", "amazon"] + | AmazonBedrockConfigBedrockProvider +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/anthropic_config.py b/python/databricks/bundles/model_serving_endpoints/_models/anthropic_config.py new file mode 100644 index 0000000000..f16aa0fef9 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/anthropic_config.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AnthropicConfig: + """""" + + anthropic_api_key: VariableOrOptional[str] = None + """ + The Databricks secret key reference for an Anthropic API key. If you + prefer to paste your API key directly, see `anthropic_api_key_plaintext`. + You must provide an API key using one of the following fields: + `anthropic_api_key` or `anthropic_api_key_plaintext`. + """ + + anthropic_api_key_plaintext: VariableOrOptional[str] = None + """ + The Anthropic API key provided as a plaintext string. If you prefer to + reference your key using Databricks Secrets, see `anthropic_api_key`. You + must provide an API key using one of the following fields: + `anthropic_api_key` or `anthropic_api_key_plaintext`. + """ + + @classmethod + def from_dict(cls, value: "AnthropicConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AnthropicConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class AnthropicConfigDict(TypedDict, total=False): + """""" + + anthropic_api_key: VariableOrOptional[str] + """ + The Databricks secret key reference for an Anthropic API key. If you + prefer to paste your API key directly, see `anthropic_api_key_plaintext`. + You must provide an API key using one of the following fields: + `anthropic_api_key` or `anthropic_api_key_plaintext`. + """ + + anthropic_api_key_plaintext: VariableOrOptional[str] + """ + The Anthropic API key provided as a plaintext string. If you prefer to + reference your key using Databricks Secrets, see `anthropic_api_key`. You + must provide an API key using one of the following fields: + `anthropic_api_key` or `anthropic_api_key_plaintext`. + """ + + +AnthropicConfigParam = AnthropicConfigDict | AnthropicConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/api_key_auth.py b/python/databricks/bundles/model_serving_endpoints/_models/api_key_auth.py new file mode 100644 index 0000000000..4a8d58e72d --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/api_key_auth.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class ApiKeyAuth: + """""" + + key: VariableOr[str] + """ + The name of the API key parameter used for authentication. + """ + + value: VariableOrOptional[str] = None + """ + The Databricks secret key reference for an API Key. + If you prefer to paste your token directly, see `value_plaintext`. + """ + + value_plaintext: VariableOrOptional[str] = None + """ + The API Key provided as a plaintext string. If you prefer to reference your + token using Databricks Secrets, see `value`. + """ + + @classmethod + def from_dict(cls, value: "ApiKeyAuthDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "ApiKeyAuthDict": + return _transform_to_json_value(self) # type:ignore + + +class ApiKeyAuthDict(TypedDict, total=False): + """""" + + key: VariableOr[str] + """ + The name of the API key parameter used for authentication. + """ + + value: VariableOrOptional[str] + """ + The Databricks secret key reference for an API Key. + If you prefer to paste your token directly, see `value_plaintext`. + """ + + value_plaintext: VariableOrOptional[str] + """ + The API Key provided as a plaintext string. If you prefer to reference your + token using Databricks Secrets, see `value`. + """ + + +ApiKeyAuthParam = ApiKeyAuthDict | ApiKeyAuth diff --git a/python/databricks/bundles/model_serving_endpoints/_models/auto_capture_config_input.py b/python/databricks/bundles/model_serving_endpoints/_models/auto_capture_config_input.py new file mode 100644 index 0000000000..ceb9b7eb0d --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/auto_capture_config_input.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class AutoCaptureConfigInput: + """""" + + catalog_name: VariableOrOptional[str] = None + """ + The name of the catalog in Unity Catalog. NOTE: On update, you cannot change the catalog name if the inference table is already enabled. + """ + + enabled: VariableOrOptional[bool] = None + """ + Indicates whether the inference table is enabled. + """ + + schema_name: VariableOrOptional[str] = None + """ + The name of the schema in Unity Catalog. NOTE: On update, you cannot change the schema name if the inference table is already enabled. + """ + + table_name_prefix: VariableOrOptional[str] = None + """ + The prefix of the table in Unity Catalog. NOTE: On update, you cannot change the prefix name if the inference table is already enabled. + """ + + @classmethod + def from_dict(cls, value: "AutoCaptureConfigInputDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "AutoCaptureConfigInputDict": + return _transform_to_json_value(self) # type:ignore + + +class AutoCaptureConfigInputDict(TypedDict, total=False): + """""" + + catalog_name: VariableOrOptional[str] + """ + The name of the catalog in Unity Catalog. NOTE: On update, you cannot change the catalog name if the inference table is already enabled. + """ + + enabled: VariableOrOptional[bool] + """ + Indicates whether the inference table is enabled. + """ + + schema_name: VariableOrOptional[str] + """ + The name of the schema in Unity Catalog. NOTE: On update, you cannot change the schema name if the inference table is already enabled. + """ + + table_name_prefix: VariableOrOptional[str] + """ + The prefix of the table in Unity Catalog. NOTE: On update, you cannot change the prefix name if the inference table is already enabled. + """ + + +AutoCaptureConfigInputParam = AutoCaptureConfigInputDict | AutoCaptureConfigInput diff --git a/python/databricks/bundles/model_serving_endpoints/_models/bearer_token_auth.py b/python/databricks/bundles/model_serving_endpoints/_models/bearer_token_auth.py new file mode 100644 index 0000000000..3be88b4e89 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/bearer_token_auth.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class BearerTokenAuth: + """""" + + token: VariableOrOptional[str] = None + """ + The Databricks secret key reference for a token. + If you prefer to paste your token directly, see `token_plaintext`. + """ + + token_plaintext: VariableOrOptional[str] = None + """ + The token provided as a plaintext string. If you prefer to reference your + token using Databricks Secrets, see `token`. + """ + + @classmethod + def from_dict(cls, value: "BearerTokenAuthDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "BearerTokenAuthDict": + return _transform_to_json_value(self) # type:ignore + + +class BearerTokenAuthDict(TypedDict, total=False): + """""" + + token: VariableOrOptional[str] + """ + The Databricks secret key reference for a token. + If you prefer to paste your token directly, see `token_plaintext`. + """ + + token_plaintext: VariableOrOptional[str] + """ + The token provided as a plaintext string. If you prefer to reference your + token using Databricks Secrets, see `token`. + """ + + +BearerTokenAuthParam = BearerTokenAuthDict | BearerTokenAuth diff --git a/python/databricks/bundles/model_serving_endpoints/_models/cohere_config.py b/python/databricks/bundles/model_serving_endpoints/_models/cohere_config.py new file mode 100644 index 0000000000..52bbd262cd --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/cohere_config.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class CohereConfig: + """""" + + cohere_api_base: VariableOrOptional[str] = None + """ + This is an optional field to provide a customized base URL for the Cohere + API. If left unspecified, the standard Cohere base URL is used. + """ + + cohere_api_key: VariableOrOptional[str] = None + """ + The Databricks secret key reference for a Cohere API key. If you prefer + to paste your API key directly, see `cohere_api_key_plaintext`. You must + provide an API key using one of the following fields: `cohere_api_key` or + `cohere_api_key_plaintext`. + """ + + cohere_api_key_plaintext: VariableOrOptional[str] = None + """ + The Cohere API key provided as a plaintext string. If you prefer to + reference your key using Databricks Secrets, see `cohere_api_key`. You + must provide an API key using one of the following fields: + `cohere_api_key` or `cohere_api_key_plaintext`. + """ + + @classmethod + def from_dict(cls, value: "CohereConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "CohereConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class CohereConfigDict(TypedDict, total=False): + """""" + + cohere_api_base: VariableOrOptional[str] + """ + This is an optional field to provide a customized base URL for the Cohere + API. If left unspecified, the standard Cohere base URL is used. + """ + + cohere_api_key: VariableOrOptional[str] + """ + The Databricks secret key reference for a Cohere API key. If you prefer + to paste your API key directly, see `cohere_api_key_plaintext`. You must + provide an API key using one of the following fields: `cohere_api_key` or + `cohere_api_key_plaintext`. + """ + + cohere_api_key_plaintext: VariableOrOptional[str] + """ + The Cohere API key provided as a plaintext string. If you prefer to + reference your key using Databricks Secrets, see `cohere_api_key`. You + must provide an API key using one of the following fields: + `cohere_api_key` or `cohere_api_key_plaintext`. + """ + + +CohereConfigParam = CohereConfigDict | CohereConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/custom_provider_config.py b/python/databricks/bundles/model_serving_endpoints/_models/custom_provider_config.py new file mode 100644 index 0000000000..6c30e58cdc --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/custom_provider_config.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.api_key_auth import ( + ApiKeyAuth, + ApiKeyAuthParam, +) +from databricks.bundles.model_serving_endpoints._models.bearer_token_auth import ( + BearerTokenAuth, + BearerTokenAuthParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class CustomProviderConfig: + """ + Configs needed to create a custom provider model route. + """ + + custom_provider_url: VariableOr[str] + """ + This is a field to provide the URL of the custom provider API. + """ + + api_key_auth: VariableOrOptional[ApiKeyAuth] = None + """ + This is a field to provide API key authentication for the custom provider API. + You can only specify one authentication method. + """ + + bearer_token_auth: VariableOrOptional[BearerTokenAuth] = None + """ + This is a field to provide bearer token authentication for the custom provider API. + You can only specify one authentication method. + """ + + @classmethod + def from_dict(cls, value: "CustomProviderConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "CustomProviderConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class CustomProviderConfigDict(TypedDict, total=False): + """""" + + custom_provider_url: VariableOr[str] + """ + This is a field to provide the URL of the custom provider API. + """ + + api_key_auth: VariableOrOptional[ApiKeyAuthParam] + """ + This is a field to provide API key authentication for the custom provider API. + You can only specify one authentication method. + """ + + bearer_token_auth: VariableOrOptional[BearerTokenAuthParam] + """ + This is a field to provide bearer token authentication for the custom provider API. + You can only specify one authentication method. + """ + + +CustomProviderConfigParam = CustomProviderConfigDict | CustomProviderConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/databricks_model_serving_config.py b/python/databricks/bundles/model_serving_endpoints/_models/databricks_model_serving_config.py new file mode 100644 index 0000000000..9358896ca3 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/databricks_model_serving_config.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class DatabricksModelServingConfig: + """""" + + databricks_workspace_url: VariableOr[str] + """ + The URL of the Databricks workspace containing the model serving endpoint + pointed to by this external model. + """ + + databricks_api_token: VariableOrOptional[str] = None + """ + The Databricks secret key reference for a Databricks API token that + corresponds to a user or service principal with Can Query access to the + model serving endpoint pointed to by this external model. If you prefer + to paste your API key directly, see `databricks_api_token_plaintext`. You + must provide an API key using one of the following fields: + `databricks_api_token` or `databricks_api_token_plaintext`. + """ + + databricks_api_token_plaintext: VariableOrOptional[str] = None + """ + The Databricks API token that corresponds to a user or service principal + with Can Query access to the model serving endpoint pointed to by this + external model provided as a plaintext string. If you prefer to reference + your key using Databricks Secrets, see `databricks_api_token`. You must + provide an API key using one of the following fields: + `databricks_api_token` or `databricks_api_token_plaintext`. + """ + + @classmethod + def from_dict(cls, value: "DatabricksModelServingConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "DatabricksModelServingConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class DatabricksModelServingConfigDict(TypedDict, total=False): + """""" + + databricks_workspace_url: VariableOr[str] + """ + The URL of the Databricks workspace containing the model serving endpoint + pointed to by this external model. + """ + + databricks_api_token: VariableOrOptional[str] + """ + The Databricks secret key reference for a Databricks API token that + corresponds to a user or service principal with Can Query access to the + model serving endpoint pointed to by this external model. If you prefer + to paste your API key directly, see `databricks_api_token_plaintext`. You + must provide an API key using one of the following fields: + `databricks_api_token` or `databricks_api_token_plaintext`. + """ + + databricks_api_token_plaintext: VariableOrOptional[str] + """ + The Databricks API token that corresponds to a user or service principal + with Can Query access to the model serving endpoint pointed to by this + external model provided as a plaintext string. If you prefer to reference + your key using Databricks Secrets, see `databricks_api_token`. You must + provide an API key using one of the following fields: + `databricks_api_token` or `databricks_api_token_plaintext`. + """ + + +DatabricksModelServingConfigParam = ( + DatabricksModelServingConfigDict | DatabricksModelServingConfig +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/email_notifications.py b/python/databricks/bundles/model_serving_endpoints/_models/email_notifications.py new file mode 100644 index 0000000000..eb5dab77a3 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/email_notifications.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrList + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class EmailNotifications: + """""" + + on_update_failure: VariableOrList[str] = field(default_factory=list) + """ + A list of email addresses to be notified when an endpoint fails to update its configuration or state. + """ + + on_update_success: VariableOrList[str] = field(default_factory=list) + """ + A list of email addresses to be notified when an endpoint successfully updates its configuration or state. + """ + + @classmethod + def from_dict(cls, value: "EmailNotificationsDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "EmailNotificationsDict": + return _transform_to_json_value(self) # type:ignore + + +class EmailNotificationsDict(TypedDict, total=False): + """""" + + on_update_failure: VariableOrList[str] + """ + A list of email addresses to be notified when an endpoint fails to update its configuration or state. + """ + + on_update_success: VariableOrList[str] + """ + A list of email addresses to be notified when an endpoint successfully updates its configuration or state. + """ + + +EmailNotificationsParam = EmailNotificationsDict | EmailNotifications diff --git a/python/databricks/bundles/model_serving_endpoints/_models/endpoint_core_config_input.py b/python/databricks/bundles/model_serving_endpoints/_models/endpoint_core_config_input.py new file mode 100644 index 0000000000..ad3e496553 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/endpoint_core_config_input.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrList, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.auto_capture_config_input import ( + AutoCaptureConfigInput, + AutoCaptureConfigInputParam, +) +from databricks.bundles.model_serving_endpoints._models.served_entity_input import ( + ServedEntityInput, + ServedEntityInputParam, +) +from databricks.bundles.model_serving_endpoints._models.served_model_input import ( + ServedModelInput, + ServedModelInputParam, +) +from databricks.bundles.model_serving_endpoints._models.traffic_config import ( + TrafficConfig, + TrafficConfigParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class EndpointCoreConfigInput: + """""" + + auto_capture_config: VariableOrOptional[AutoCaptureConfigInput] = None + """ + Configuration for Inference Tables which automatically logs requests and responses to Unity Catalog. + Note: this field is deprecated for creating new provisioned throughput endpoints, + or updating existing provisioned throughput endpoints that never have inference table configured; + in these cases please use AI Gateway to manage inference tables. + """ + + served_entities: VariableOrList[ServedEntityInput] = field(default_factory=list) + """ + The list of served entities under the serving endpoint config. + """ + + served_models: VariableOrList[ServedModelInput] = field(default_factory=list) + """ + (Deprecated, use served_entities instead) The list of served models under the serving endpoint config. + """ + + traffic_config: VariableOrOptional[TrafficConfig] = None + """ + The traffic configuration associated with the serving endpoint config. + """ + + @classmethod + def from_dict(cls, value: "EndpointCoreConfigInputDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "EndpointCoreConfigInputDict": + return _transform_to_json_value(self) # type:ignore + + +class EndpointCoreConfigInputDict(TypedDict, total=False): + """""" + + auto_capture_config: VariableOrOptional[AutoCaptureConfigInputParam] + """ + Configuration for Inference Tables which automatically logs requests and responses to Unity Catalog. + Note: this field is deprecated for creating new provisioned throughput endpoints, + or updating existing provisioned throughput endpoints that never have inference table configured; + in these cases please use AI Gateway to manage inference tables. + """ + + served_entities: VariableOrList[ServedEntityInputParam] + """ + The list of served entities under the serving endpoint config. + """ + + served_models: VariableOrList[ServedModelInputParam] + """ + (Deprecated, use served_entities instead) The list of served models under the serving endpoint config. + """ + + traffic_config: VariableOrOptional[TrafficConfigParam] + """ + The traffic configuration associated with the serving endpoint config. + """ + + +EndpointCoreConfigInputParam = EndpointCoreConfigInputDict | EndpointCoreConfigInput diff --git a/python/databricks/bundles/model_serving_endpoints/_models/endpoint_tag.py b/python/databricks/bundles/model_serving_endpoints/_models/endpoint_tag.py new file mode 100644 index 0000000000..d6e7296063 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/endpoint_tag.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class EndpointTag: + """""" + + key: VariableOr[str] + """ + Key field for a serving endpoint tag. + """ + + value: VariableOrOptional[str] = None + """ + Optional value field for a serving endpoint tag. + """ + + @classmethod + def from_dict(cls, value: "EndpointTagDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "EndpointTagDict": + return _transform_to_json_value(self) # type:ignore + + +class EndpointTagDict(TypedDict, total=False): + """""" + + key: VariableOr[str] + """ + Key field for a serving endpoint tag. + """ + + value: VariableOrOptional[str] + """ + Optional value field for a serving endpoint tag. + """ + + +EndpointTagParam = EndpointTagDict | EndpointTag diff --git a/python/databricks/bundles/model_serving_endpoints/_models/external_model.py b/python/databricks/bundles/model_serving_endpoints/_models/external_model.py new file mode 100644 index 0000000000..5826c929cf --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/external_model.py @@ -0,0 +1,192 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.ai21_labs_config import ( + Ai21LabsConfig, + Ai21LabsConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.amazon_bedrock_config import ( + AmazonBedrockConfig, + AmazonBedrockConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.anthropic_config import ( + AnthropicConfig, + AnthropicConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.cohere_config import ( + CohereConfig, + CohereConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.custom_provider_config import ( + CustomProviderConfig, + CustomProviderConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.databricks_model_serving_config import ( + DatabricksModelServingConfig, + DatabricksModelServingConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.external_model_provider import ( + ExternalModelProvider, + ExternalModelProviderParam, +) +from databricks.bundles.model_serving_endpoints._models.google_cloud_vertex_ai_config import ( + GoogleCloudVertexAiConfig, + GoogleCloudVertexAiConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.open_ai_config import ( + OpenAiConfig, + OpenAiConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.pa_lm_config import ( + PaLmConfig, + PaLmConfigParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class ExternalModel: + """""" + + name: VariableOr[str] + """ + The name of the external model. + """ + + provider: VariableOr[ExternalModelProvider] + """ + The name of the provider for the external model. Currently, the supported providers are 'ai21labs', 'anthropic', 'amazon-bedrock', 'cohere', 'databricks-model-serving', 'google-cloud-vertex-ai', 'openai', 'palm', and 'custom'. + """ + + task: VariableOr[str] + """ + The task type of the external model. + """ + + ai21labs_config: VariableOrOptional[Ai21LabsConfig] = None + """ + AI21Labs Config. Only required if the provider is 'ai21labs'. + """ + + amazon_bedrock_config: VariableOrOptional[AmazonBedrockConfig] = None + """ + Amazon Bedrock Config. Only required if the provider is 'amazon-bedrock'. + """ + + anthropic_config: VariableOrOptional[AnthropicConfig] = None + """ + Anthropic Config. Only required if the provider is 'anthropic'. + """ + + cohere_config: VariableOrOptional[CohereConfig] = None + """ + Cohere Config. Only required if the provider is 'cohere'. + """ + + custom_provider_config: VariableOrOptional[CustomProviderConfig] = None + """ + Custom Provider Config. Only required if the provider is 'custom'. + """ + + databricks_model_serving_config: VariableOrOptional[ + DatabricksModelServingConfig + ] = None + """ + Databricks Model Serving Config. Only required if the provider is 'databricks-model-serving'. + """ + + google_cloud_vertex_ai_config: VariableOrOptional[GoogleCloudVertexAiConfig] = None + """ + Google Cloud Vertex AI Config. Only required if the provider is 'google-cloud-vertex-ai'. + """ + + openai_config: VariableOrOptional[OpenAiConfig] = None + """ + OpenAI Config. Only required if the provider is 'openai'. + """ + + palm_config: VariableOrOptional[PaLmConfig] = None + """ + PaLM Config. Only required if the provider is 'palm'. + """ + + @classmethod + def from_dict(cls, value: "ExternalModelDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "ExternalModelDict": + return _transform_to_json_value(self) # type:ignore + + +class ExternalModelDict(TypedDict, total=False): + """""" + + name: VariableOr[str] + """ + The name of the external model. + """ + + provider: VariableOr[ExternalModelProviderParam] + """ + The name of the provider for the external model. Currently, the supported providers are 'ai21labs', 'anthropic', 'amazon-bedrock', 'cohere', 'databricks-model-serving', 'google-cloud-vertex-ai', 'openai', 'palm', and 'custom'. + """ + + task: VariableOr[str] + """ + The task type of the external model. + """ + + ai21labs_config: VariableOrOptional[Ai21LabsConfigParam] + """ + AI21Labs Config. Only required if the provider is 'ai21labs'. + """ + + amazon_bedrock_config: VariableOrOptional[AmazonBedrockConfigParam] + """ + Amazon Bedrock Config. Only required if the provider is 'amazon-bedrock'. + """ + + anthropic_config: VariableOrOptional[AnthropicConfigParam] + """ + Anthropic Config. Only required if the provider is 'anthropic'. + """ + + cohere_config: VariableOrOptional[CohereConfigParam] + """ + Cohere Config. Only required if the provider is 'cohere'. + """ + + custom_provider_config: VariableOrOptional[CustomProviderConfigParam] + """ + Custom Provider Config. Only required if the provider is 'custom'. + """ + + databricks_model_serving_config: VariableOrOptional[ + DatabricksModelServingConfigParam + ] + """ + Databricks Model Serving Config. Only required if the provider is 'databricks-model-serving'. + """ + + google_cloud_vertex_ai_config: VariableOrOptional[GoogleCloudVertexAiConfigParam] + """ + Google Cloud Vertex AI Config. Only required if the provider is 'google-cloud-vertex-ai'. + """ + + openai_config: VariableOrOptional[OpenAiConfigParam] + """ + OpenAI Config. Only required if the provider is 'openai'. + """ + + palm_config: VariableOrOptional[PaLmConfigParam] + """ + PaLM Config. Only required if the provider is 'palm'. + """ + + +ExternalModelParam = ExternalModelDict | ExternalModel diff --git a/python/databricks/bundles/model_serving_endpoints/_models/external_model_provider.py b/python/databricks/bundles/model_serving_endpoints/_models/external_model_provider.py new file mode 100644 index 0000000000..8ec807df30 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/external_model_provider.py @@ -0,0 +1,30 @@ +from enum import Enum +from typing import Literal + + +class ExternalModelProvider(Enum): + AI21LABS = "ai21labs" + ANTHROPIC = "anthropic" + AMAZON_BEDROCK = "amazon-bedrock" + COHERE = "cohere" + DATABRICKS_MODEL_SERVING = "databricks-model-serving" + GOOGLE_CLOUD_VERTEX_AI = "google-cloud-vertex-ai" + OPENAI = "openai" + PALM = "palm" + CUSTOM = "custom" + + +ExternalModelProviderParam = ( + Literal[ + "ai21labs", + "anthropic", + "amazon-bedrock", + "cohere", + "databricks-model-serving", + "google-cloud-vertex-ai", + "openai", + "palm", + "custom", + ] + | ExternalModelProvider +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/fallback_config.py b/python/databricks/bundles/model_serving_endpoints/_models/fallback_config.py new file mode 100644 index 0000000000..e4f547577b --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/fallback_config.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class FallbackConfig: + """""" + + enabled: VariableOr[bool] + """ + Whether to enable traffic fallback. When a served entity in the serving endpoint returns specific error + codes (e.g. 500), the request will automatically be round-robin attempted with other served entities in the same + endpoint, following the order of served entity list, until a successful response is returned. + If all attempts fail, return the last response with the error code. + """ + + @classmethod + def from_dict(cls, value: "FallbackConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "FallbackConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class FallbackConfigDict(TypedDict, total=False): + """""" + + enabled: VariableOr[bool] + """ + Whether to enable traffic fallback. When a served entity in the serving endpoint returns specific error + codes (e.g. 500), the request will automatically be round-robin attempted with other served entities in the same + endpoint, following the order of served entity list, until a successful response is returned. + If all attempts fail, return the last response with the error code. + """ + + +FallbackConfigParam = FallbackConfigDict | FallbackConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/google_cloud_vertex_ai_config.py b/python/databricks/bundles/model_serving_endpoints/_models/google_cloud_vertex_ai_config.py new file mode 100644 index 0000000000..7b4667ac43 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/google_cloud_vertex_ai_config.py @@ -0,0 +1,108 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class GoogleCloudVertexAiConfig: + """""" + + project_id: VariableOr[str] + """ + This is the Google Cloud project id that the service account is + associated with. + """ + + region: VariableOr[str] + """ + This is the region for the Google Cloud Vertex AI Service. See [supported + regions] for more details. Some models are only available in specific + regions. + + [supported regions]: https://cloud.google.com/vertex-ai/docs/general/locations + """ + + private_key: VariableOrOptional[str] = None + """ + The Databricks secret key reference for a private key for the service + account which has access to the Google Cloud Vertex AI Service. See [Best + practices for managing service account keys]. If you prefer to paste your + API key directly, see `private_key_plaintext`. You must provide an API + key using one of the following fields: `private_key` or + `private_key_plaintext` + + [Best practices for managing service account keys]: https://cloud.google.com/iam/docs/best-practices-for-managing-service-account-keys + """ + + private_key_plaintext: VariableOrOptional[str] = None + """ + The private key for the service account which has access to the Google + Cloud Vertex AI Service provided as a plaintext secret. See [Best + practices for managing service account keys]. If you prefer to reference + your key using Databricks Secrets, see `private_key`. You must provide an + API key using one of the following fields: `private_key` or + `private_key_plaintext`. + + [Best practices for managing service account keys]: https://cloud.google.com/iam/docs/best-practices-for-managing-service-account-keys + """ + + @classmethod + def from_dict(cls, value: "GoogleCloudVertexAiConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "GoogleCloudVertexAiConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class GoogleCloudVertexAiConfigDict(TypedDict, total=False): + """""" + + project_id: VariableOr[str] + """ + This is the Google Cloud project id that the service account is + associated with. + """ + + region: VariableOr[str] + """ + This is the region for the Google Cloud Vertex AI Service. See [supported + regions] for more details. Some models are only available in specific + regions. + + [supported regions]: https://cloud.google.com/vertex-ai/docs/general/locations + """ + + private_key: VariableOrOptional[str] + """ + The Databricks secret key reference for a private key for the service + account which has access to the Google Cloud Vertex AI Service. See [Best + practices for managing service account keys]. If you prefer to paste your + API key directly, see `private_key_plaintext`. You must provide an API + key using one of the following fields: `private_key` or + `private_key_plaintext` + + [Best practices for managing service account keys]: https://cloud.google.com/iam/docs/best-practices-for-managing-service-account-keys + """ + + private_key_plaintext: VariableOrOptional[str] + """ + The private key for the service account which has access to the Google + Cloud Vertex AI Service provided as a plaintext secret. See [Best + practices for managing service account keys]. If you prefer to reference + your key using Databricks Secrets, see `private_key`. You must provide an + API key using one of the following fields: `private_key` or + `private_key_plaintext`. + + [Best practices for managing service account keys]: https://cloud.google.com/iam/docs/best-practices-for-managing-service-account-keys + """ + + +GoogleCloudVertexAiConfigParam = ( + GoogleCloudVertexAiConfigDict | GoogleCloudVertexAiConfig +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/lifecycle.py b/python/databricks/bundles/model_serving_endpoints/_models/lifecycle.py new file mode 100644 index 0000000000..c934967f37 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/lifecycle.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class Lifecycle: + """""" + + prevent_destroy: VariableOrOptional[bool] = None + """ + Lifecycle setting to prevent the resource from being destroyed. + """ + + @classmethod + def from_dict(cls, value: "LifecycleDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "LifecycleDict": + return _transform_to_json_value(self) # type:ignore + + +class LifecycleDict(TypedDict, total=False): + """""" + + prevent_destroy: VariableOrOptional[bool] + """ + Lifecycle setting to prevent the resource from being destroyed. + """ + + +LifecycleParam = LifecycleDict | Lifecycle diff --git a/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint.py b/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint.py new file mode 100644 index 0000000000..787c5b663f --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint.py @@ -0,0 +1,163 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._resource import Resource +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import ( + VariableOr, + VariableOrList, + VariableOrOptional, +) +from databricks.bundles.model_serving_endpoints._models.ai_gateway_config import ( + AiGatewayConfig, + AiGatewayConfigParam, +) +from databricks.bundles.model_serving_endpoints._models.email_notifications import ( + EmailNotifications, + EmailNotificationsParam, +) +from databricks.bundles.model_serving_endpoints._models.endpoint_core_config_input import ( + EndpointCoreConfigInput, + EndpointCoreConfigInputParam, +) +from databricks.bundles.model_serving_endpoints._models.endpoint_tag import ( + EndpointTag, + EndpointTagParam, +) +from databricks.bundles.model_serving_endpoints._models.lifecycle import ( + Lifecycle, + LifecycleParam, +) +from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint_permission import ( + ModelServingEndpointPermission, + ModelServingEndpointPermissionParam, +) +from databricks.bundles.model_serving_endpoints._models.rate_limit import ( + RateLimit, + RateLimitParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class ModelServingEndpoint(Resource): + """""" + + name: VariableOr[str] + """ + The name of the serving endpoint. This field is required and must be unique across a Databricks workspace. + An endpoint name can consist of alphanumeric characters, dashes, and underscores. + """ + + ai_gateway: VariableOrOptional[AiGatewayConfig] = None + """ + The AI Gateway configuration for the serving endpoint. NOTE: External model, provisioned throughput, and pay-per-token endpoints are fully supported; agent endpoints currently only support inference tables. + """ + + budget_policy_id: VariableOrOptional[str] = None + """ + The budget policy to be applied to the serving endpoint. + """ + + config: VariableOrOptional[EndpointCoreConfigInput] = None + """ + The core config of the serving endpoint. + """ + + description: VariableOrOptional[str] = None + + email_notifications: VariableOrOptional[EmailNotifications] = None + """ + Email notification settings. + """ + + lifecycle: VariableOrOptional[Lifecycle] = None + """ + Lifecycle is a struct that contains the lifecycle settings for a resource. It controls the behavior of the resource when it is deployed or destroyed. + """ + + permissions: VariableOrList[ModelServingEndpointPermission] = field( + default_factory=list + ) + + rate_limits: VariableOrList[RateLimit] = field(default_factory=list) + """ + [DEPRECATED] Rate limits to be applied to the serving endpoint. NOTE: this field is deprecated, please use AI Gateway to manage rate limits. + """ + + route_optimized: VariableOrOptional[bool] = None + """ + Enable route optimization for the serving endpoint. + """ + + tags: VariableOrList[EndpointTag] = field(default_factory=list) + """ + Tags to be attached to the serving endpoint and automatically propagated to billing logs. + """ + + @classmethod + def from_dict(cls, value: "ModelServingEndpointDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "ModelServingEndpointDict": + return _transform_to_json_value(self) # type:ignore + + +class ModelServingEndpointDict(TypedDict, total=False): + """""" + + name: VariableOr[str] + """ + The name of the serving endpoint. This field is required and must be unique across a Databricks workspace. + An endpoint name can consist of alphanumeric characters, dashes, and underscores. + """ + + ai_gateway: VariableOrOptional[AiGatewayConfigParam] + """ + The AI Gateway configuration for the serving endpoint. NOTE: External model, provisioned throughput, and pay-per-token endpoints are fully supported; agent endpoints currently only support inference tables. + """ + + budget_policy_id: VariableOrOptional[str] + """ + The budget policy to be applied to the serving endpoint. + """ + + config: VariableOrOptional[EndpointCoreConfigInputParam] + """ + The core config of the serving endpoint. + """ + + description: VariableOrOptional[str] + + email_notifications: VariableOrOptional[EmailNotificationsParam] + """ + Email notification settings. + """ + + lifecycle: VariableOrOptional[LifecycleParam] + """ + Lifecycle is a struct that contains the lifecycle settings for a resource. It controls the behavior of the resource when it is deployed or destroyed. + """ + + permissions: VariableOrList[ModelServingEndpointPermissionParam] + + rate_limits: VariableOrList[RateLimitParam] + """ + [DEPRECATED] Rate limits to be applied to the serving endpoint. NOTE: this field is deprecated, please use AI Gateway to manage rate limits. + """ + + route_optimized: VariableOrOptional[bool] + """ + Enable route optimization for the serving endpoint. + """ + + tags: VariableOrList[EndpointTagParam] + """ + Tags to be attached to the serving endpoint and automatically propagated to billing logs. + """ + + +ModelServingEndpointParam = ModelServingEndpointDict | ModelServingEndpoint diff --git a/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint_permission.py b/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint_permission.py new file mode 100644 index 0000000000..8cf1cf895b --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint_permission.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint_permission_level import ( + ModelServingEndpointPermissionLevel, + ModelServingEndpointPermissionLevelParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class ModelServingEndpointPermission: + """""" + + level: VariableOr[ModelServingEndpointPermissionLevel] + + group_name: VariableOrOptional[str] = None + + service_principal_name: VariableOrOptional[str] = None + + user_name: VariableOrOptional[str] = None + + @classmethod + def from_dict(cls, value: "ModelServingEndpointPermissionDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "ModelServingEndpointPermissionDict": + return _transform_to_json_value(self) # type:ignore + + +class ModelServingEndpointPermissionDict(TypedDict, total=False): + """""" + + level: VariableOr[ModelServingEndpointPermissionLevelParam] + + group_name: VariableOrOptional[str] + + service_principal_name: VariableOrOptional[str] + + user_name: VariableOrOptional[str] + + +ModelServingEndpointPermissionParam = ( + ModelServingEndpointPermissionDict | ModelServingEndpointPermission +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint_permission_level.py b/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint_permission_level.py new file mode 100644 index 0000000000..c57f560374 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/model_serving_endpoint_permission_level.py @@ -0,0 +1,13 @@ +from enum import Enum +from typing import Literal + + +class ModelServingEndpointPermissionLevel(Enum): + CAN_MANAGE = "CAN_MANAGE" + CAN_QUERY = "CAN_QUERY" + CAN_VIEW = "CAN_VIEW" + + +ModelServingEndpointPermissionLevelParam = ( + Literal["CAN_MANAGE", "CAN_QUERY", "CAN_VIEW"] | ModelServingEndpointPermissionLevel +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/open_ai_config.py b/python/databricks/bundles/model_serving_endpoints/_models/open_ai_config.py new file mode 100644 index 0000000000..f34e4f2774 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/open_ai_config.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class OpenAiConfig: + """ + Configs needed to create an OpenAI model route. + """ + + microsoft_entra_client_id: VariableOrOptional[str] = None + """ + This field is only required for Azure AD OpenAI and is the Microsoft + Entra Client ID. + """ + + microsoft_entra_client_secret: VariableOrOptional[str] = None + """ + The Databricks secret key reference for a client secret used for + Microsoft Entra ID authentication. If you prefer to paste your client + secret directly, see `microsoft_entra_client_secret_plaintext`. You must + provide an API key using one of the following fields: + `microsoft_entra_client_secret` or + `microsoft_entra_client_secret_plaintext`. + """ + + microsoft_entra_client_secret_plaintext: VariableOrOptional[str] = None + """ + The client secret used for Microsoft Entra ID authentication provided as + a plaintext string. If you prefer to reference your key using Databricks + Secrets, see `microsoft_entra_client_secret`. You must provide an API key + using one of the following fields: `microsoft_entra_client_secret` or + `microsoft_entra_client_secret_plaintext`. + """ + + microsoft_entra_tenant_id: VariableOrOptional[str] = None + """ + This field is only required for Azure AD OpenAI and is the Microsoft + Entra Tenant ID. + """ + + openai_api_base: VariableOrOptional[str] = None + """ + This is a field to provide a customized base URl for the OpenAI API. For + Azure OpenAI, this field is required, and is the base URL for the Azure + OpenAI API service provided by Azure. For other OpenAI API types, this + field is optional, and if left unspecified, the standard OpenAI base URL + is used. + """ + + openai_api_key: VariableOrOptional[str] = None + """ + The Databricks secret key reference for an OpenAI API key using the + OpenAI or Azure service. If you prefer to paste your API key directly, + see `openai_api_key_plaintext`. You must provide an API key using one of + the following fields: `openai_api_key` or `openai_api_key_plaintext`. + """ + + openai_api_key_plaintext: VariableOrOptional[str] = None + """ + The OpenAI API key using the OpenAI or Azure service provided as a + plaintext string. If you prefer to reference your key using Databricks + Secrets, see `openai_api_key`. You must provide an API key using one of + the following fields: `openai_api_key` or `openai_api_key_plaintext`. + """ + + openai_api_type: VariableOrOptional[str] = None + """ + This is an optional field to specify the type of OpenAI API to use. For + Azure OpenAI, this field is required, and adjust this parameter to + represent the preferred security access validation protocol. For access + token validation, use azure. For authentication using Azure Active + Directory (Azure AD) use, azuread. + """ + + openai_api_version: VariableOrOptional[str] = None + """ + This is an optional field to specify the OpenAI API version. For Azure + OpenAI, this field is required, and is the version of the Azure OpenAI + service to utilize, specified by a date. + """ + + openai_deployment_name: VariableOrOptional[str] = None + """ + This field is only required for Azure OpenAI and is the name of the + deployment resource for the Azure OpenAI service. + """ + + openai_organization: VariableOrOptional[str] = None + """ + This is an optional field to specify the organization in OpenAI or Azure + OpenAI. + """ + + @classmethod + def from_dict(cls, value: "OpenAiConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "OpenAiConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class OpenAiConfigDict(TypedDict, total=False): + """""" + + microsoft_entra_client_id: VariableOrOptional[str] + """ + This field is only required for Azure AD OpenAI and is the Microsoft + Entra Client ID. + """ + + microsoft_entra_client_secret: VariableOrOptional[str] + """ + The Databricks secret key reference for a client secret used for + Microsoft Entra ID authentication. If you prefer to paste your client + secret directly, see `microsoft_entra_client_secret_plaintext`. You must + provide an API key using one of the following fields: + `microsoft_entra_client_secret` or + `microsoft_entra_client_secret_plaintext`. + """ + + microsoft_entra_client_secret_plaintext: VariableOrOptional[str] + """ + The client secret used for Microsoft Entra ID authentication provided as + a plaintext string. If you prefer to reference your key using Databricks + Secrets, see `microsoft_entra_client_secret`. You must provide an API key + using one of the following fields: `microsoft_entra_client_secret` or + `microsoft_entra_client_secret_plaintext`. + """ + + microsoft_entra_tenant_id: VariableOrOptional[str] + """ + This field is only required for Azure AD OpenAI and is the Microsoft + Entra Tenant ID. + """ + + openai_api_base: VariableOrOptional[str] + """ + This is a field to provide a customized base URl for the OpenAI API. For + Azure OpenAI, this field is required, and is the base URL for the Azure + OpenAI API service provided by Azure. For other OpenAI API types, this + field is optional, and if left unspecified, the standard OpenAI base URL + is used. + """ + + openai_api_key: VariableOrOptional[str] + """ + The Databricks secret key reference for an OpenAI API key using the + OpenAI or Azure service. If you prefer to paste your API key directly, + see `openai_api_key_plaintext`. You must provide an API key using one of + the following fields: `openai_api_key` or `openai_api_key_plaintext`. + """ + + openai_api_key_plaintext: VariableOrOptional[str] + """ + The OpenAI API key using the OpenAI or Azure service provided as a + plaintext string. If you prefer to reference your key using Databricks + Secrets, see `openai_api_key`. You must provide an API key using one of + the following fields: `openai_api_key` or `openai_api_key_plaintext`. + """ + + openai_api_type: VariableOrOptional[str] + """ + This is an optional field to specify the type of OpenAI API to use. For + Azure OpenAI, this field is required, and adjust this parameter to + represent the preferred security access validation protocol. For access + token validation, use azure. For authentication using Azure Active + Directory (Azure AD) use, azuread. + """ + + openai_api_version: VariableOrOptional[str] + """ + This is an optional field to specify the OpenAI API version. For Azure + OpenAI, this field is required, and is the version of the Azure OpenAI + service to utilize, specified by a date. + """ + + openai_deployment_name: VariableOrOptional[str] + """ + This field is only required for Azure OpenAI and is the name of the + deployment resource for the Azure OpenAI service. + """ + + openai_organization: VariableOrOptional[str] + """ + This is an optional field to specify the organization in OpenAI or Azure + OpenAI. + """ + + +OpenAiConfigParam = OpenAiConfigDict | OpenAiConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/pa_lm_config.py b/python/databricks/bundles/model_serving_endpoints/_models/pa_lm_config.py new file mode 100644 index 0000000000..0781d2cf0f --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/pa_lm_config.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class PaLmConfig: + """""" + + palm_api_key: VariableOrOptional[str] = None + """ + The Databricks secret key reference for a PaLM API key. If you prefer to + paste your API key directly, see `palm_api_key_plaintext`. You must + provide an API key using one of the following fields: `palm_api_key` or + `palm_api_key_plaintext`. + """ + + palm_api_key_plaintext: VariableOrOptional[str] = None + """ + The PaLM API key provided as a plaintext string. If you prefer to + reference your key using Databricks Secrets, see `palm_api_key`. You must + provide an API key using one of the following fields: `palm_api_key` or + `palm_api_key_plaintext`. + """ + + @classmethod + def from_dict(cls, value: "PaLmConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "PaLmConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class PaLmConfigDict(TypedDict, total=False): + """""" + + palm_api_key: VariableOrOptional[str] + """ + The Databricks secret key reference for a PaLM API key. If you prefer to + paste your API key directly, see `palm_api_key_plaintext`. You must + provide an API key using one of the following fields: `palm_api_key` or + `palm_api_key_plaintext`. + """ + + palm_api_key_plaintext: VariableOrOptional[str] + """ + The PaLM API key provided as a plaintext string. If you prefer to + reference your key using Databricks Secrets, see `palm_api_key`. You must + provide an API key using one of the following fields: `palm_api_key` or + `palm_api_key_plaintext`. + """ + + +PaLmConfigParam = PaLmConfigDict | PaLmConfig diff --git a/python/databricks/bundles/model_serving_endpoints/_models/rate_limit.py b/python/databricks/bundles/model_serving_endpoints/_models/rate_limit.py new file mode 100644 index 0000000000..626be81971 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/rate_limit.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.rate_limit_key import ( + RateLimitKey, + RateLimitKeyParam, +) +from databricks.bundles.model_serving_endpoints._models.rate_limit_renewal_period import ( + RateLimitRenewalPeriod, + RateLimitRenewalPeriodParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class RateLimit: + """ + [DEPRECATED] + """ + + calls: VariableOr[int] + """ + Used to specify how many calls are allowed for a key within the renewal_period. + """ + + renewal_period: VariableOr[RateLimitRenewalPeriod] + """ + Renewal period field for a serving endpoint rate limit. Currently, only 'minute' is supported. + """ + + key: VariableOrOptional[RateLimitKey] = None + """ + Key field for a serving endpoint rate limit. Currently, only 'user' and 'endpoint' are supported, with 'endpoint' being the default if not specified. + """ + + @classmethod + def from_dict(cls, value: "RateLimitDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "RateLimitDict": + return _transform_to_json_value(self) # type:ignore + + +class RateLimitDict(TypedDict, total=False): + """""" + + calls: VariableOr[int] + """ + Used to specify how many calls are allowed for a key within the renewal_period. + """ + + renewal_period: VariableOr[RateLimitRenewalPeriodParam] + """ + Renewal period field for a serving endpoint rate limit. Currently, only 'minute' is supported. + """ + + key: VariableOrOptional[RateLimitKeyParam] + """ + Key field for a serving endpoint rate limit. Currently, only 'user' and 'endpoint' are supported, with 'endpoint' being the default if not specified. + """ + + +RateLimitParam = RateLimitDict | RateLimit diff --git a/python/databricks/bundles/model_serving_endpoints/_models/rate_limit_key.py b/python/databricks/bundles/model_serving_endpoints/_models/rate_limit_key.py new file mode 100644 index 0000000000..e1a0c85346 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/rate_limit_key.py @@ -0,0 +1,14 @@ +from enum import Enum +from typing import Literal + + +class RateLimitKey(Enum): + """ + [DEPRECATED] + """ + + USER = "user" + ENDPOINT = "endpoint" + + +RateLimitKeyParam = Literal["user", "endpoint"] | RateLimitKey diff --git a/python/databricks/bundles/model_serving_endpoints/_models/rate_limit_renewal_period.py b/python/databricks/bundles/model_serving_endpoints/_models/rate_limit_renewal_period.py new file mode 100644 index 0000000000..919d0e656e --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/rate_limit_renewal_period.py @@ -0,0 +1,13 @@ +from enum import Enum +from typing import Literal + + +class RateLimitRenewalPeriod(Enum): + """ + [DEPRECATED] + """ + + MINUTE = "minute" + + +RateLimitRenewalPeriodParam = Literal["minute"] | RateLimitRenewalPeriod diff --git a/python/databricks/bundles/model_serving_endpoints/_models/route.py b/python/databricks/bundles/model_serving_endpoints/_models/route.py new file mode 100644 index 0000000000..05a7b8a99d --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/route.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOr, VariableOrOptional + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class Route: + """""" + + traffic_percentage: VariableOr[int] + """ + The percentage of endpoint traffic to send to this route. It must be an integer between 0 and 100 inclusive. + """ + + served_entity_name: VariableOrOptional[str] = None + + served_model_name: VariableOrOptional[str] = None + """ + The name of the served model this route configures traffic for. + """ + + @classmethod + def from_dict(cls, value: "RouteDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "RouteDict": + return _transform_to_json_value(self) # type:ignore + + +class RouteDict(TypedDict, total=False): + """""" + + traffic_percentage: VariableOr[int] + """ + The percentage of endpoint traffic to send to this route. It must be an integer between 0 and 100 inclusive. + """ + + served_entity_name: VariableOrOptional[str] + + served_model_name: VariableOrOptional[str] + """ + The name of the served model this route configures traffic for. + """ + + +RouteParam = RouteDict | Route diff --git a/python/databricks/bundles/model_serving_endpoints/_models/served_entity_input.py b/python/databricks/bundles/model_serving_endpoints/_models/served_entity_input.py new file mode 100644 index 0000000000..17070e9503 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/served_entity_input.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrDict, VariableOrOptional +from databricks.bundles.model_serving_endpoints._models.external_model import ( + ExternalModel, + ExternalModelParam, +) +from databricks.bundles.model_serving_endpoints._models.serving_model_workload_type import ( + ServingModelWorkloadType, + ServingModelWorkloadTypeParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class ServedEntityInput: + """""" + + entity_name: VariableOrOptional[str] = None + """ + The name of the entity to be served. The entity may be a model in the Databricks Model Registry, a model in the Unity Catalog (UC), or a function of type FEATURE_SPEC in the UC. If it is a UC object, the full name of the object should be given in the form of **catalog_name.schema_name.model_name**. + """ + + entity_version: VariableOrOptional[str] = None + + environment_vars: VariableOrDict[str] = field(default_factory=dict) + """ + An object containing a set of optional, user-specified environment variable key-value pairs used for serving this entity. Note: this is an experimental feature and subject to change. Example entity environment variables that refer to Databricks secrets: `{"OPENAI_API_KEY": "{{secrets/my_scope/my_key}}", "DATABRICKS_TOKEN": "{{secrets/my_scope2/my_key2}}"}` + """ + + external_model: VariableOrOptional[ExternalModel] = None + """ + The external model to be served. NOTE: Only one of external_model and (entity_name, entity_version, workload_size, workload_type, and scale_to_zero_enabled) can be specified with the latter set being used for custom model serving for a Databricks registered model. For an existing endpoint with external_model, it cannot be updated to an endpoint without external_model. If the endpoint is created without external_model, users cannot update it to add external_model later. The task type of all external models within an endpoint must be the same. + """ + + instance_profile_arn: VariableOrOptional[str] = None + """ + ARN of the instance profile that the served entity uses to access AWS resources. + """ + + max_provisioned_concurrency: VariableOrOptional[int] = None + """ + The maximum provisioned concurrency that the endpoint can scale up to. Do not use if workload_size is specified. + """ + + max_provisioned_throughput: VariableOrOptional[int] = None + """ + The maximum tokens per second that the endpoint can scale up to. + """ + + min_provisioned_concurrency: VariableOrOptional[int] = None + """ + The minimum provisioned concurrency that the endpoint can scale down to. Do not use if workload_size is specified. + """ + + min_provisioned_throughput: VariableOrOptional[int] = None + """ + The minimum tokens per second that the endpoint can scale down to. + """ + + name: VariableOrOptional[str] = None + """ + The name of a served entity. It must be unique across an endpoint. A served entity name can consist of alphanumeric characters, dashes, and underscores. If not specified for an external model, this field defaults to external_model.name, with '.' and ':' replaced with '-', and if not specified for other entities, it defaults to entity_name-entity_version. + """ + + provisioned_model_units: VariableOrOptional[int] = None + """ + The number of model units provisioned. + """ + + scale_to_zero_enabled: VariableOrOptional[bool] = None + """ + Whether the compute resources for the served entity should scale down to zero. + """ + + workload_size: VariableOrOptional[str] = None + """ + The workload size of the served entity. The workload size corresponds to a range of provisioned concurrency that the compute autoscales between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are "Small" (4 - 4 provisioned concurrency), "Medium" (8 - 16 provisioned concurrency), and "Large" (16 - 64 provisioned concurrency). Additional custom workload sizes can also be used when available in the workspace. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size is 0. Do not use if min_provisioned_concurrency and max_provisioned_concurrency are specified. + """ + + workload_type: VariableOrOptional[ServingModelWorkloadType] = None + """ + The workload type of the served entity. The workload type selects which type of compute to use in the endpoint. The default value for this parameter is "CPU". For deep learning workloads, GPU acceleration is available by selecting workload types like GPU_SMALL and others. See the available [GPU types](https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html#gpu-workload-types). + """ + + @classmethod + def from_dict(cls, value: "ServedEntityInputDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "ServedEntityInputDict": + return _transform_to_json_value(self) # type:ignore + + +class ServedEntityInputDict(TypedDict, total=False): + """""" + + entity_name: VariableOrOptional[str] + """ + The name of the entity to be served. The entity may be a model in the Databricks Model Registry, a model in the Unity Catalog (UC), or a function of type FEATURE_SPEC in the UC. If it is a UC object, the full name of the object should be given in the form of **catalog_name.schema_name.model_name**. + """ + + entity_version: VariableOrOptional[str] + + environment_vars: VariableOrDict[str] + """ + An object containing a set of optional, user-specified environment variable key-value pairs used for serving this entity. Note: this is an experimental feature and subject to change. Example entity environment variables that refer to Databricks secrets: `{"OPENAI_API_KEY": "{{secrets/my_scope/my_key}}", "DATABRICKS_TOKEN": "{{secrets/my_scope2/my_key2}}"}` + """ + + external_model: VariableOrOptional[ExternalModelParam] + """ + The external model to be served. NOTE: Only one of external_model and (entity_name, entity_version, workload_size, workload_type, and scale_to_zero_enabled) can be specified with the latter set being used for custom model serving for a Databricks registered model. For an existing endpoint with external_model, it cannot be updated to an endpoint without external_model. If the endpoint is created without external_model, users cannot update it to add external_model later. The task type of all external models within an endpoint must be the same. + """ + + instance_profile_arn: VariableOrOptional[str] + """ + ARN of the instance profile that the served entity uses to access AWS resources. + """ + + max_provisioned_concurrency: VariableOrOptional[int] + """ + The maximum provisioned concurrency that the endpoint can scale up to. Do not use if workload_size is specified. + """ + + max_provisioned_throughput: VariableOrOptional[int] + """ + The maximum tokens per second that the endpoint can scale up to. + """ + + min_provisioned_concurrency: VariableOrOptional[int] + """ + The minimum provisioned concurrency that the endpoint can scale down to. Do not use if workload_size is specified. + """ + + min_provisioned_throughput: VariableOrOptional[int] + """ + The minimum tokens per second that the endpoint can scale down to. + """ + + name: VariableOrOptional[str] + """ + The name of a served entity. It must be unique across an endpoint. A served entity name can consist of alphanumeric characters, dashes, and underscores. If not specified for an external model, this field defaults to external_model.name, with '.' and ':' replaced with '-', and if not specified for other entities, it defaults to entity_name-entity_version. + """ + + provisioned_model_units: VariableOrOptional[int] + """ + The number of model units provisioned. + """ + + scale_to_zero_enabled: VariableOrOptional[bool] + """ + Whether the compute resources for the served entity should scale down to zero. + """ + + workload_size: VariableOrOptional[str] + """ + The workload size of the served entity. The workload size corresponds to a range of provisioned concurrency that the compute autoscales between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are "Small" (4 - 4 provisioned concurrency), "Medium" (8 - 16 provisioned concurrency), and "Large" (16 - 64 provisioned concurrency). Additional custom workload sizes can also be used when available in the workspace. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size is 0. Do not use if min_provisioned_concurrency and max_provisioned_concurrency are specified. + """ + + workload_type: VariableOrOptional[ServingModelWorkloadTypeParam] + """ + The workload type of the served entity. The workload type selects which type of compute to use in the endpoint. The default value for this parameter is "CPU". For deep learning workloads, GPU acceleration is available by selecting workload types like GPU_SMALL and others. See the available [GPU types](https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html#gpu-workload-types). + """ + + +ServedEntityInputParam = ServedEntityInputDict | ServedEntityInput diff --git a/python/databricks/bundles/model_serving_endpoints/_models/served_model_input.py b/python/databricks/bundles/model_serving_endpoints/_models/served_model_input.py new file mode 100644 index 0000000000..7693bbc53e --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/served_model_input.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import ( + VariableOr, + VariableOrDict, + VariableOrOptional, +) +from databricks.bundles.model_serving_endpoints._models.served_model_input_workload_type import ( + ServedModelInputWorkloadType, + ServedModelInputWorkloadTypeParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class ServedModelInput: + """""" + + model_name: VariableOr[str] + + model_version: VariableOr[str] + + scale_to_zero_enabled: VariableOr[bool] + """ + Whether the compute resources for the served entity should scale down to zero. + """ + + environment_vars: VariableOrDict[str] = field(default_factory=dict) + """ + An object containing a set of optional, user-specified environment variable key-value pairs used for serving this entity. Note: this is an experimental feature and subject to change. Example entity environment variables that refer to Databricks secrets: `{"OPENAI_API_KEY": "{{secrets/my_scope/my_key}}", "DATABRICKS_TOKEN": "{{secrets/my_scope2/my_key2}}"}` + """ + + instance_profile_arn: VariableOrOptional[str] = None + """ + ARN of the instance profile that the served entity uses to access AWS resources. + """ + + max_provisioned_concurrency: VariableOrOptional[int] = None + """ + The maximum provisioned concurrency that the endpoint can scale up to. Do not use if workload_size is specified. + """ + + max_provisioned_throughput: VariableOrOptional[int] = None + """ + The maximum tokens per second that the endpoint can scale up to. + """ + + min_provisioned_concurrency: VariableOrOptional[int] = None + """ + The minimum provisioned concurrency that the endpoint can scale down to. Do not use if workload_size is specified. + """ + + min_provisioned_throughput: VariableOrOptional[int] = None + """ + The minimum tokens per second that the endpoint can scale down to. + """ + + name: VariableOrOptional[str] = None + """ + The name of a served entity. It must be unique across an endpoint. A served entity name can consist of alphanumeric characters, dashes, and underscores. If not specified for an external model, this field defaults to external_model.name, with '.' and ':' replaced with '-', and if not specified for other entities, it defaults to entity_name-entity_version. + """ + + provisioned_model_units: VariableOrOptional[int] = None + """ + The number of model units provisioned. + """ + + workload_size: VariableOrOptional[str] = None + """ + The workload size of the served entity. The workload size corresponds to a range of provisioned concurrency that the compute autoscales between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are "Small" (4 - 4 provisioned concurrency), "Medium" (8 - 16 provisioned concurrency), and "Large" (16 - 64 provisioned concurrency). Additional custom workload sizes can also be used when available in the workspace. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size is 0. Do not use if min_provisioned_concurrency and max_provisioned_concurrency are specified. + """ + + workload_type: VariableOrOptional[ServedModelInputWorkloadType] = None + """ + The workload type of the served entity. The workload type selects which type of compute to use in the endpoint. The default value for this parameter is "CPU". For deep learning workloads, GPU acceleration is available by selecting workload types like GPU_SMALL and others. See the available [GPU types](https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html#gpu-workload-types). + """ + + @classmethod + def from_dict(cls, value: "ServedModelInputDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "ServedModelInputDict": + return _transform_to_json_value(self) # type:ignore + + +class ServedModelInputDict(TypedDict, total=False): + """""" + + model_name: VariableOr[str] + + model_version: VariableOr[str] + + scale_to_zero_enabled: VariableOr[bool] + """ + Whether the compute resources for the served entity should scale down to zero. + """ + + environment_vars: VariableOrDict[str] + """ + An object containing a set of optional, user-specified environment variable key-value pairs used for serving this entity. Note: this is an experimental feature and subject to change. Example entity environment variables that refer to Databricks secrets: `{"OPENAI_API_KEY": "{{secrets/my_scope/my_key}}", "DATABRICKS_TOKEN": "{{secrets/my_scope2/my_key2}}"}` + """ + + instance_profile_arn: VariableOrOptional[str] + """ + ARN of the instance profile that the served entity uses to access AWS resources. + """ + + max_provisioned_concurrency: VariableOrOptional[int] + """ + The maximum provisioned concurrency that the endpoint can scale up to. Do not use if workload_size is specified. + """ + + max_provisioned_throughput: VariableOrOptional[int] + """ + The maximum tokens per second that the endpoint can scale up to. + """ + + min_provisioned_concurrency: VariableOrOptional[int] + """ + The minimum provisioned concurrency that the endpoint can scale down to. Do not use if workload_size is specified. + """ + + min_provisioned_throughput: VariableOrOptional[int] + """ + The minimum tokens per second that the endpoint can scale down to. + """ + + name: VariableOrOptional[str] + """ + The name of a served entity. It must be unique across an endpoint. A served entity name can consist of alphanumeric characters, dashes, and underscores. If not specified for an external model, this field defaults to external_model.name, with '.' and ':' replaced with '-', and if not specified for other entities, it defaults to entity_name-entity_version. + """ + + provisioned_model_units: VariableOrOptional[int] + """ + The number of model units provisioned. + """ + + workload_size: VariableOrOptional[str] + """ + The workload size of the served entity. The workload size corresponds to a range of provisioned concurrency that the compute autoscales between. A single unit of provisioned concurrency can process one request at a time. Valid workload sizes are "Small" (4 - 4 provisioned concurrency), "Medium" (8 - 16 provisioned concurrency), and "Large" (16 - 64 provisioned concurrency). Additional custom workload sizes can also be used when available in the workspace. If scale-to-zero is enabled, the lower bound of the provisioned concurrency for each workload size is 0. Do not use if min_provisioned_concurrency and max_provisioned_concurrency are specified. + """ + + workload_type: VariableOrOptional[ServedModelInputWorkloadTypeParam] + """ + The workload type of the served entity. The workload type selects which type of compute to use in the endpoint. The default value for this parameter is "CPU". For deep learning workloads, GPU acceleration is available by selecting workload types like GPU_SMALL and others. See the available [GPU types](https://docs.databricks.com/en/machine-learning/model-serving/create-manage-serving-endpoints.html#gpu-workload-types). + """ + + +ServedModelInputParam = ServedModelInputDict | ServedModelInput diff --git a/python/databricks/bundles/model_serving_endpoints/_models/served_model_input_workload_type.py b/python/databricks/bundles/model_serving_endpoints/_models/served_model_input_workload_type.py new file mode 100644 index 0000000000..df71b32cc6 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/served_model_input_workload_type.py @@ -0,0 +1,20 @@ +from enum import Enum +from typing import Literal + + +class ServedModelInputWorkloadType(Enum): + """ + Please keep this in sync with with workload types in InferenceEndpointEntities.scala + """ + + CPU = "CPU" + GPU_MEDIUM = "GPU_MEDIUM" + GPU_SMALL = "GPU_SMALL" + GPU_LARGE = "GPU_LARGE" + MULTIGPU_MEDIUM = "MULTIGPU_MEDIUM" + + +ServedModelInputWorkloadTypeParam = ( + Literal["CPU", "GPU_MEDIUM", "GPU_SMALL", "GPU_LARGE", "MULTIGPU_MEDIUM"] + | ServedModelInputWorkloadType +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/serving_model_workload_type.py b/python/databricks/bundles/model_serving_endpoints/_models/serving_model_workload_type.py new file mode 100644 index 0000000000..ebb024e130 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/serving_model_workload_type.py @@ -0,0 +1,20 @@ +from enum import Enum +from typing import Literal + + +class ServingModelWorkloadType(Enum): + """ + Please keep this in sync with with workload types in InferenceEndpointEntities.scala + """ + + CPU = "CPU" + GPU_MEDIUM = "GPU_MEDIUM" + GPU_SMALL = "GPU_SMALL" + GPU_LARGE = "GPU_LARGE" + MULTIGPU_MEDIUM = "MULTIGPU_MEDIUM" + + +ServingModelWorkloadTypeParam = ( + Literal["CPU", "GPU_MEDIUM", "GPU_SMALL", "GPU_LARGE", "MULTIGPU_MEDIUM"] + | ServingModelWorkloadType +) diff --git a/python/databricks/bundles/model_serving_endpoints/_models/traffic_config.py b/python/databricks/bundles/model_serving_endpoints/_models/traffic_config.py new file mode 100644 index 0000000000..34c0e23e14 --- /dev/null +++ b/python/databricks/bundles/model_serving_endpoints/_models/traffic_config.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypedDict + +from databricks.bundles.core._transform import _transform +from databricks.bundles.core._transform_to_json import _transform_to_json_value +from databricks.bundles.core._variable import VariableOrList +from databricks.bundles.model_serving_endpoints._models.route import ( + Route, + RouteParam, +) + +if TYPE_CHECKING: + from typing_extensions import Self + + +@dataclass(kw_only=True) +class TrafficConfig: + """""" + + routes: VariableOrList[Route] = field(default_factory=list) + """ + The list of routes that define traffic to each served entity. + """ + + @classmethod + def from_dict(cls, value: "TrafficConfigDict") -> "Self": + return _transform(cls, value) + + def as_dict(self) -> "TrafficConfigDict": + return _transform_to_json_value(self) # type:ignore + + +class TrafficConfigDict(TypedDict, total=False): + """""" + + routes: VariableOrList[RouteParam] + """ + The list of routes that define traffic to each served entity. + """ + + +TrafficConfigParam = TrafficConfigDict | TrafficConfig diff --git a/python/databricks/bundles/pipelines/__init__.py b/python/databricks/bundles/pipelines/__init__.py index cadfc3e87b..23f8a71cce 100644 --- a/python/databricks/bundles/pipelines/__init__.py +++ b/python/databricks/bundles/pipelines/__init__.py @@ -141,6 +141,10 @@ "TableSpecificConfigParam", "TableSpecificConfigScdType", "TableSpecificConfigScdTypeParam", + "VariableOr", + "VariableOrDict", + "VariableOrList", + "VariableOrOptional", "VolumesStorageInfo", "VolumesStorageInfoDict", "VolumesStorageInfoParam", @@ -150,6 +154,12 @@ ] +from databricks.bundles.core import ( + VariableOr, + VariableOrDict, + VariableOrList, + VariableOrOptional, +) from databricks.bundles.pipelines._models.adlsgen2_info import ( Adlsgen2Info, Adlsgen2InfoDict, diff --git a/python/databricks/bundles/schemas/__init__.py b/python/databricks/bundles/schemas/__init__.py index d4d0fa33a3..c5629f8304 100644 --- a/python/databricks/bundles/schemas/__init__.py +++ b/python/databricks/bundles/schemas/__init__.py @@ -10,9 +10,19 @@ "SchemaGrantPrivilege", "SchemaGrantPrivilegeParam", "SchemaParam", + "VariableOr", + "VariableOrDict", + "VariableOrList", + "VariableOrOptional", ] +from databricks.bundles.core import ( + VariableOr, + VariableOrDict, + VariableOrList, + VariableOrOptional, +) from databricks.bundles.schemas._models.lifecycle import ( Lifecycle, LifecycleDict, diff --git a/python/databricks/bundles/volumes/__init__.py b/python/databricks/bundles/volumes/__init__.py index 065713bf6c..02ba5f0717 100644 --- a/python/databricks/bundles/volumes/__init__.py +++ b/python/databricks/bundles/volumes/__init__.py @@ -2,6 +2,9 @@ "Lifecycle", "LifecycleDict", "LifecycleParam", + "VariableOr", + "VariableOrList", + "VariableOrOptional", "Volume", "VolumeDict", "VolumeGrant", @@ -15,6 +18,7 @@ ] +from databricks.bundles.core import VariableOr, VariableOrList, VariableOrOptional from databricks.bundles.volumes._models.lifecycle import ( Lifecycle, LifecycleDict,