5 changes: 3 additions & 2 deletions python/codegen/codegen/generated_enum.py
@@ -80,6 +80,7 @@ def get_code(generated: GeneratedEnum) -> str:


def _camel_to_upper_snake(value):
-    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", value)
+    s1 = value.replace("-", "_")  # TODO: regex sanitizer
+    s2 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s1)

-    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).upper()
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s2).upper()
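
To make the new behavior concrete, here is a minimal, self-contained sketch of the updated helper with a couple of illustrative inputs (the sample values are assumptions, not taken from the PR):

import re

def _camel_to_upper_snake(value):
    # Normalize hyphens first so kebab-case names survive the regex passes.
    s1 = value.replace("-", "_")  # TODO: regex sanitizer
    # Insert "_" before a capital that starts a lowercase run: "ModelServing" -> "Model_Serving".
    s2 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s1)
    # Insert "_" between a lowercase/digit and a capital, then upper-case everything.
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s2).upper()

assert _camel_to_upper_snake("ModelServingEndpoint") == "MODEL_SERVING_ENDPOINT"
assert _camel_to_upper_snake("model-serving-endpoint") == "MODEL_SERVING_ENDPOINT"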
52 changes: 52 additions & 0 deletions python/codegen/codegen/main.py
@@ -123,6 +123,43 @@ def _generate_code(
return dataclasses, enums


def _get_cross_namespace_imports(
namespace: str,
dataclasses: dict[str, GeneratedDataclass],
enums: dict[str, GeneratedEnum],
) -> tuple[list[str], dict[str, list[str]]]:
"""
Collect cross-namespace imports (types from other packages referenced by this namespace).
Returns (export_list, import_dict) where import_dict maps package -> list of class names.
"""
cross_namespace_packages = {}
cross_namespace_exports = []

root_package = packages.get_root_package(namespace)

# Collect all referenced packages from dataclasses
for dataclass in dataclasses.values():
for field in dataclass.fields:
if field.type_name.package and field.type_name.package != root_package:
if field.type_name.package not in cross_namespace_packages:
cross_namespace_packages[field.type_name.package] = []
class_name = field.type_name.name.strip('"')
if class_name not in cross_namespace_packages[field.type_name.package]:
cross_namespace_packages[field.type_name.package].append(class_name)
cross_namespace_exports.append(class_name)

# Collect all referenced packages from enums
for enum in enums.values():
if enum.package and enum.package != root_package:
if enum.package not in cross_namespace_packages:
cross_namespace_packages[enum.package] = []
if enum.class_name not in cross_namespace_packages[enum.package]:
cross_namespace_packages[enum.package].append(enum.class_name)
cross_namespace_exports.append(enum.class_name)

return cross_namespace_exports, cross_namespace_packages
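
For intuition, the return shape for a namespace that references a single foreign type would look like this (package and class names are hypothetical, not taken from the PR):

# cross_namespace_exports: flat list of names to add to the namespace's __all__
# cross_namespace_packages: maps source package -> class names to re-export
# e.g. (["Environment"], {"databricks.bundles.compute": ["Environment"]})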


def _write_exports(
namespace: str,
dataclasses: dict[str, GeneratedDataclass],
@@ -141,6 +178,13 @@ def _write_exports(
for _, enum in enums.items():
exports += [enum.class_name, f"{enum.class_name}Param"]

# Add cross-namespace imports to exports
cross_ns_exports, cross_ns_packages = _get_cross_namespace_imports(
namespace, dataclasses, enums
)
exports.extend(cross_ns_exports)

exports = list(set(exports)) # Remove duplicates
exports.sort()

b = CodeBuilder()
@@ -155,6 +199,14 @@
generated_imports.append_dataclass_imports(b, dataclasses, exclude_packages=[])
generated_imports.append_enum_imports(b, enums, exclude_packages=[])

# Add cross-namespace re-exports
for package, class_names in sorted(cross_ns_packages.items()):
for class_name in sorted(set(class_names)):
b.append(f"from {package} import {class_name}").newline()

if cross_ns_packages:
b.newline()

# FIXME should be better generalized
if namespace == "jobs":
_append_resolve_recursive_imports(b)
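Taken together, the two hunks above mean a generated per-namespace __init__ now re-exports foreign types next to the locally generated ones. A hypothetical output file (names assumed, not from the PR) would contain:

__all__ = [
    "Environment",
    "Job",
    "JobParam",
]

from databricks.bundles.compute import Environment
from databricks.bundles.jobs._models.job import Job, JobParam
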
4 changes: 4 additions & 0 deletions python/codegen/codegen/packages.py
@@ -1,3 +1,4 @@
import os
import re
from typing import Optional

@@ -7,6 +8,7 @@
"resources.Pipeline": "pipelines",
"resources.Schema": "schemas",
"resources.Volume": "volumes",
"resources.ModelServingEndpoint": "model_serving_endpoints", # serving
}

RESOURCE_TYPES = list(RESOURCE_NAMESPACE.keys())
@@ -20,6 +22,8 @@
"pipelines",
"resources",
"catalog",
"model_serving_endpoints",
"serving", # this exists within model_serving_endpoints and for some reason needs to be loaded separately
]

RENAMES = {
2 changes: 2 additions & 0 deletions python/databricks/bundles/core/__init__.py
@@ -17,6 +17,7 @@
"load_resources_from_module",
"load_resources_from_modules",
"load_resources_from_package_module",
"model_serving_endpoint_mutator",
"pipeline_mutator",
"schema_mutator",
"variables",
@@ -40,6 +41,7 @@
from databricks.bundles.core._resource_mutator import (
ResourceMutator,
job_mutator,
model_serving_endpoint_mutator,
pipeline_mutator,
schema_mutator,
volume_mutator,
39 changes: 39 additions & 0 deletions python/databricks/bundles/core/_resource_mutator.py
@@ -7,6 +9,9 @@

if TYPE_CHECKING:
from databricks.bundles.jobs._models.job import Job
from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint import (
ModelServingEndpoint,
)
from databricks.bundles.pipelines._models.pipeline import Pipeline
from databricks.bundles.schemas._models.schema import Schema
from databricks.bundles.volumes._models.volume import Volume
@@ -193,3 +196,39 @@ def my_volume_mutator(bundle: Bundle, volume: Volume) -> Volume:
from databricks.bundles.volumes._models.volume import Volume

return ResourceMutator(resource_type=Volume, function=function)


@overload
def model_serving_endpoint_mutator(
function: Callable[[Bundle, "ModelServingEndpoint"], "ModelServingEndpoint"],
) -> ResourceMutator["ModelServingEndpoint"]: ...


@overload
def model_serving_endpoint_mutator(
function: Callable[["ModelServingEndpoint"], "ModelServingEndpoint"],
) -> ResourceMutator["ModelServingEndpoint"]: ...


def model_serving_endpoint_mutator(
function: Callable,
) -> ResourceMutator["ModelServingEndpoint"]:
"""
Decorator for defining a model serving endpoint mutator. Function should return a new instance of the model serving endpoint with the desired changes,
instead of mutating the input model serving endpoint.

Example:

.. code-block:: python

@model_serving_endpoint_mutator
def my_model_serving_endpoint_mutator(bundle: Bundle, model_serving_endpoint: ModelServingEndpoint) -> ModelServingEndpoint:
return replace(model_serving_endpoint, name="my_model_serving_endpoint")

:param function: Function that mutates a model serving endpoint.
"""
from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint import (
ModelServingEndpoint,
)

return ResourceMutator(resource_type=ModelServingEndpoint, function=function)
8 changes: 8 additions & 0 deletions python/databricks/bundles/core/_resource_type.py
@@ -32,6 +32,9 @@ def all(cls) -> tuple["_ResourceType", ...]:
# be imported in databricks.bundles.<resource_type>

from databricks.bundles.jobs._models.job import Job
from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint import (
ModelServingEndpoint,
)
from databricks.bundles.pipelines._models.pipeline import Pipeline
from databricks.bundles.schemas._models.schema import Schema
from databricks.bundles.volumes._models.volume import Volume
@@ -57,4 +60,9 @@ def all(cls) -> tuple["_ResourceType", ...]:
plural_name="schemas",
singular_name="schema",
),
_ResourceType(
resource_type=ModelServingEndpoint,
plural_name="model_serving_endpoints",
singular_name="model_serving_endpoint",
),
)
51 changes: 51 additions & 0 deletions python/databricks/bundles/core/_resources.py
@@ -9,6 +9,10 @@
from databricks.bundles.jobs._models.job import Job, JobParam
from databricks.bundles.pipelines._models.pipeline import Pipeline, PipelineParam
from databricks.bundles.schemas._models.schema import Schema, SchemaParam
from databricks.bundles.model_serving_endpoints._models.model_serving_endpoint import (
ModelServingEndpoint,
ModelServingEndpointParam,
)
from databricks.bundles.volumes._models.volume import Volume, VolumeParam

__all__ = ["Resources"]
@@ -60,6 +64,7 @@ def __init__(self):
self._pipelines = dict[str, "Pipeline"]()
self._schemas = dict[str, "Schema"]()
self._volumes = dict[str, "Volume"]()
self._model_serving_endpoints = dict[str, "ModelServingEndpoint"]()
self._locations = dict[tuple[str, ...], Location]()
self._diagnostics = Diagnostics()

@@ -79,6 +84,10 @@ def schemas(self) -> dict[str, "Schema"]:
def volumes(self) -> dict[str, "Volume"]:
return self._volumes

@property
def model_serving_endpoints(self) -> dict[str, "ModelServingEndpoint"]:
return self._model_serving_endpoints

@property
def diagnostics(self) -> Diagnostics:
"""
@@ -103,6 +112,7 @@ def add_resource(
"""

from databricks.bundles.jobs import Job
from databricks.bundles.model_serving_endpoints import ModelServingEndpoint
from databricks.bundles.pipelines import Pipeline
from databricks.bundles.schemas import Schema
from databricks.bundles.volumes import Volume
Expand All @@ -118,6 +128,10 @@ def add_resource(
self.add_schema(resource_name, resource, location=location)
case Volume():
self.add_volume(resource_name, resource, location=location)
case ModelServingEndpoint():
self.add_model_serving_endpoint(
resource_name, resource, location=location
)
case _:
raise ValueError(f"Unsupported resource type: {type(resource)}")

@@ -249,6 +263,40 @@ def add_volume(

self._volumes[resource_name] = volume

def add_model_serving_endpoint(
self,
resource_name: str,
model_serving_endpoint: "ModelServingEndpointParam",
*,
location: Optional[Location] = None,
) -> None:
"""
Adds a model serving endpoint to the collection of resources. Resource name must be unique across all model serving endpoints.

:param resource_name: unique identifier for the model serving endpoint
:param model_serving_endpoint: the model serving endpoint to add, can be ModelServingEndpoint or dict
:param location: optional location of the model serving endpoint in the source code
"""
from databricks.bundles.model_serving_endpoints import ModelServingEndpoint

model_serving_endpoint = _transform(
ModelServingEndpoint, model_serving_endpoint
)
path = ("resources", "model_serving_endpoints", resource_name)
location = location or Location.from_stack_frame(depth=1)

if self._model_serving_endpoints.get(resource_name):
self.add_diagnostic_error(
msg=f"Duplicate resource name '{resource_name}' for a model serving endpoint. Resource names must be unique.",
location=location,
path=path,
)
else:
if location:
self.add_location(path, location)

self._model_serving_endpoints[resource_name] = model_serving_endpoint

def add_location(self, path: tuple[str, ...], location: Location) -> None:
"""
Associate source code location with a path in the bundle configuration.
@@ -331,6 +379,9 @@ def add_resources(self, other: "Resources") -> None:
for name, volume in other.volumes.items():
self.add_volume(name, volume)

for name, model_serving_endpoint in other.model_serving_endpoints.items():
self.add_model_serving_endpoint(name, model_serving_endpoint)

for path, location in other._locations.items():
self.add_location(path, location)

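A usage sketch for the new Resources API (the endpoint name and config dict below are assumptions for illustration; a typed ModelServingEndpoint instance works as well, per the ModelServingEndpointParam union):

from databricks.bundles.core import Resources

resources = Resources()
resources.add_model_serving_endpoint(
    "my_endpoint",
    {"name": "my-endpoint"},  # assumed minimal config; dicts are coerced via _transform
)
assert "my_endpoint" in resources.model_serving_endpoints
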
10 changes: 10 additions & 0 deletions python/databricks/bundles/jobs/__init__.py
@@ -259,6 +259,10 @@
"TriggerSettings",
"TriggerSettingsDict",
"TriggerSettingsParam",
"VariableOr",
"VariableOrDict",
"VariableOrList",
"VariableOrOptional",
"VolumesStorageInfo",
"VolumesStorageInfoDict",
"VolumesStorageInfoParam",
@@ -277,6 +281,12 @@
]


from databricks.bundles.core import (
VariableOr,
VariableOrDict,
VariableOrList,
VariableOrOptional,
)
from databricks.bundles.jobs._models.adlsgen2_info import (
Adlsgen2Info,
Adlsgen2InfoDict,