Description
Bug Report: ModelBuilder MLflow Path Resolution Fails for Models Logged with MLflow 3.x
PySDK Version
- PySDK V2 (2.x)
- PySDK V3 (3.x)
Describe the bug
When using ModelBuilder with MLflow Model Registry paths (models:/model-name/version), the SDK fails to resolve the correct S3 artifact path for models that were logged with MLflow 3.x.
The root cause is that ModelVersion.source returns different values depending on which MLflow version was used to log the model:
| Model Logged With | ModelVersion.source Returns |
|---|---|
| MLflow 2.x | s3://bucket/1/run_id/artifacts/model (direct S3 path) ✅ |
| MLflow 3.x | models:/m-16bfb69f57fc45869a5f407ccd596fe1 (internal URI) ❌ |
The SDK currently uses .source, which only works for models logged with MLflow 2.x. For models logged with MLflow 3.x, .source returns an internal URI that cannot be used to download artifacts.
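The two shapes of .source shown in the table above can be told apart by URI scheme and path. The sketch below is minimal. It assumes that MLflow 3.x logged-model URIs always take the form models:/m-<id>, an inference drawn from the example in this report rather than from MLflow documentation. describe_source is a hypothetical helper and is not part of the SDK or MLflow.

```python
from urllib.parse import urlparse


def describe_source(source: str) -> str:
    """Classify a ModelVersion.source string by its shape (hypothetical helper).

    Assumption: MLflow 3.x logged-model URIs look like "models:/m-<id>",
    while MLflow 2.x run artifacts are plain "s3://bucket/..." paths.
    """
    parsed = urlparse(source)
    if parsed.scheme == "s3":
        return "s3"  # directly downloadable artifact location
    if parsed.scheme == "models":
        parts = [p for p in parsed.path.split("/") if p]
        if len(parts) == 1 and parts[0].startswith("m-"):
            return "logged-model"  # MLflow 3.x internal logged-model URI
        return "registry"  # e.g. models:/<name>/<version>
    return "other"  # runs:/, file paths, etc.
```

This check is a heuristic. A registered model whose name begins with "m-" and is referenced by alias (models:/m-foo@prod) would be misclassified as a logged-model URI.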
To reproduce
```python
import mlflow
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode

# Prerequisites:
# 1. Have a SageMaker Managed MLflow tracking server
# 2. Log a model using MLflow 3.x (e.g., mlflow.pytorch.log_model())
# 3. Register the model to the MLflow Model Registry

MLFLOW_TRACKING_ARN = "arn:aws:sagemaker:us-east-1:123456789:mlflow-app/app-XXXXX"
mlflow.set_tracking_uri(MLFLOW_TRACKING_ARN)

# Verify the issue - check what .source returns vs get_model_version_download_uri()
client = mlflow.MlflowClient()
model_version_info = client.get_model_version("pytorch-simple-classifier", "1")
print(f"ModelVersion.source: {model_version_info.source}")
# For MLflow 3.x logged models: "models:/m-16bfb69f57fc45869a5f407ccd596fe1" (WRONG)
print(f"get_model_version_download_uri: {client.get_model_version_download_uri('pytorch-simple-classifier', '1')}")
# Always returns: "s3://bucket/path/to/artifacts" (CORRECT)

# Now try to use ModelBuilder - this will fail
schema_builder = SchemaBuilder(
    sample_input=[[0.1, 0.2, 0.3, 0.4]],
    sample_output=[[0.8, 0.2]]
)
model_builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    schema_builder=schema_builder,
    model_metadata={
        "MLFLOW_MODEL_PATH": "models:/pytorch-simple-classifier/1",
        "MLFLOW_TRACKING_ARN": MLFLOW_TRACKING_ARN
    }
)

# This fails because the SDK tries to use the internal "models:/m-xxx" URI
# instead of the actual S3 path
model = model_builder.build()
```
Expected behavior
ModelBuilder.build() should successfully resolve the MLflow registry path to the correct S3 artifact location regardless of which MLflow version was used to log the model.
The SDK should use mlflow_client.get_model_version_download_uri(model_name, model_version) instead of model_version_info.source, because get_model_version_download_uri() returns the correct S3 path regardless of which MLflow version was used to log the model.
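The proposed resolution can be sketched separately from the SDK. This minimal sketch is not the SDK's _get_artifact_path(). The names parse_registry_path, resolve_artifact_path, and DownloadUriClient are hypothetical. The sketch relies only on the client's get_model_version_download_uri(name, version) method, which is the MlflowClient method named above.

```python
from typing import Protocol, Tuple


class DownloadUriClient(Protocol):
    """The single MlflowClient method this sketch depends on."""

    def get_model_version_download_uri(self, name: str, version: str) -> str: ...


def parse_registry_path(model_path: str) -> Tuple[str, str]:
    """Split "models:/<name>/<version>" into (name, version)."""
    prefix = "models:/"
    if not model_path.startswith(prefix):
        raise ValueError(f"not a registry path: {model_path!r}")
    # rpartition keeps everything before the last "/" as the model name
    name, sep, version = model_path[len(prefix):].rpartition("/")
    if not sep or not name or not version:
        raise ValueError(f"expected models:/<name>/<version>, got {model_path!r}")
    return name, version


def resolve_artifact_path(client: DownloadUriClient, model_path: str) -> str:
    """Resolve a registry path to a downloadable URI without reading .source."""
    name, version = parse_registry_path(model_path)
    return client.get_model_version_download_uri(name, version)
```

In real code, the client would be mlflow.MlflowClient() created after mlflow.set_tracking_uri(...). Typing the client as a protocol lets the sketch be tested with a stub. The parser assumes the last path component is a version number, as in the repro, and does not handle alias (@) forms.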
Screenshots or logs
Error when trying to download artifacts from the internal URI:
```
botocore.exceptions.ClientError: An error occurred (403) when calling the HeadObject operation: Forbidden
```
Alternatively, a path-parsing error occurs when the SDK tries to interpret models:/m-xxx as an S3 path.
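A guard that accepts only s3:// URIs before splitting them into a bucket and key prefix can make the path-parsing failure explicit. This is a minimal sketch. split_s3_uri is a hypothetical helper, not SDK code.

```python
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> tuple[str, str]:
    """Split "s3://bucket/key/prefix" into (bucket, key_prefix).

    Raises ValueError for anything that is not an s3:// URI, such as the
    MLflow 3.x logged-model URI "models:/m-<id>".
    """
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    # netloc is the bucket; the path (minus its leading "/") is the key prefix
    return parsed.netloc, parsed.path.lstrip("/")
```

If the SDK applied a check like this, an unresolved internal URI would fail with a clear error message instead of a downstream S3 error.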
System information
- SageMaker Python SDK version: 3.3.1
- Framework name: PyTorch
- Framework version: 2.5
- Python version: 3.11
- CPU or GPU: CPU (ml.m5.xlarge)
- Custom Docker image: N
Additional:
- MLflow version: 3.4.0+ (client-side when logging the model)
- SageMaker Managed MLflow version: 3.4.0
Additional context
Affected Code Location
- File: sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
- Method: _get_artifact_path() (lines ~1393-1399)
Current Implementation (Problematic)
The code currently retrieves the source from ModelVersion.source:
```python
# Current code in _get_artifact_path()
model_version_info = mlflow_client.get_model_version(model_name, model_version)
source = model_version_info.source  # Returns internal URI for MLflow 3.x models
```
Suggested Fix
Use get_model_version_download_uri(), which always returns the correct S3 path:
```python
# Fixed code
source = mlflow_client.get_model_version_download_uri(model_name, model_version)
```
This method works correctly for models logged with any MLflow version:
| Model Logged With | get_model_version_download_uri() Returns |
|---|---|
| MLflow 2.x | s3://bucket/1/run_id/artifacts/model ✅ |
| MLflow 3.x | s3://bucket/3/models/m-xxx/artifacts ✅ |
Workaround
Until this is fixed, users can work around the issue by:
- Using the direct S3 path instead of the registry path
- Pinning MLflow to 2.x when logging models
- Manually calling get_model_version_download_uri() and passing the S3 path directly
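The last workaround can be packaged as a small helper. This is a minimal sketch. It assumes that ModelBuilder accepts a direct S3 path in MLFLOW_MODEL_PATH, as the first workaround implies. metadata_from_registry is a hypothetical name, and client is expected to be an mlflow.MlflowClient.

```python
from typing import Dict


def metadata_from_registry(client, model_name: str, model_version: str,
                           tracking_arn: str) -> Dict[str, str]:
    """Resolve a registered model version to S3 and build ModelBuilder model_metadata."""
    # Per this report, get_model_version_download_uri() returns an S3 path for
    # models logged with either MLflow 2.x or 3.x.
    s3_path = client.get_model_version_download_uri(model_name, model_version)
    if not s3_path.startswith("s3://"):
        raise ValueError(f"expected an S3 path, got {s3_path!r}")
    return {"MLFLOW_MODEL_PATH": s3_path, "MLFLOW_TRACKING_ARN": tracking_arn}


# Usage (requires a reachable tracking server):
# client = mlflow.MlflowClient()
# model_builder = ModelBuilder(
#     mode=Mode.SAGEMAKER_ENDPOINT,
#     schema_builder=schema_builder,
#     model_metadata=metadata_from_registry(
#         client, "pytorch-simple-classifier", "1", MLFLOW_TRACKING_ARN
#     ),
# )
```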