Skip to content
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

### Bugs Fixed

- Fixed `BatchEndpoint` defaults serialization regression where `deployment_name` was sent to the service as snake_case instead of camelCase (`deploymentName`), causing `begin_create_or_update` to fail with "Could not find member 'deployment_name' on object of type 'BatchEndpointDefaults'". `BatchEndpoint.defaults` is now consistently exposed as a snake_case dict to users and converted to the correct wire format on serialization.
- Fixed cross-tenant registry endpoint resolution for deployment template operations by using the registry discovery API instead of ARM calls.
- Fixed deployment template update failing with immutable field errors by ensuring `allowedInstanceType` and `allowedEnvironmentVariableOverrides` are properly round-tripped during serialization.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from marshmallow import fields, post_load

from azure.ai.ml._restclient.arm_ml_service.models import BatchEndpointDefaults
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta

module_logger = logging.getLogger(__name__)
Expand All @@ -25,4 +24,4 @@ class BatchEndpointsDefaultsSchema(metaclass=PatchedSchemaMeta):

@post_load
def make(self, data: Any, **kwargs: Any) -> Any:
return BatchEndpointDefaults(**data)
return dict(data)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import IO, Any, AnyStr, Dict, Optional, Union

from azure.ai.ml._restclient.arm_ml_service.models import BatchEndpoint as BatchEndpointData
from azure.ai.ml._restclient.arm_ml_service.models import BatchEndpointDefaults as RestBatchEndpointDefaults
from azure.ai.ml._restclient.arm_ml_service.models import BatchEndpointProperties as RestBatchEndpoint
from azure.ai.ml._schema._endpoint import BatchEndpointSchema
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
Expand Down Expand Up @@ -80,16 +81,29 @@ def __init__(

def _to_rest_batch_endpoint(self, location: str) -> BatchEndpointData:
validate_endpoint_or_deployment_name(self.name)
defaults: Optional[RestBatchEndpointDefaults] = None
if isinstance(self.defaults, RestBatchEndpointDefaults):
defaults = self.defaults
elif isinstance(self.defaults, dict) and self.defaults:
defaults = RestBatchEndpointDefaults(**self.defaults)
Comment thread
ayushhgarg-work marked this conversation as resolved.
Outdated
batch_endpoint = RestBatchEndpoint(
description=self.description,
auth_mode=snake_to_camel(self.auth_mode),
properties=self.properties,
defaults=self.defaults,
defaults=defaults,
)
return BatchEndpointData(location=location, tags=self.tags, properties=batch_endpoint)

@classmethod
def _from_rest_object(cls, obj: BatchEndpointData) -> "BatchEndpoint":
defaults: Optional[Dict[str, str]] = None
rest_defaults = obj.properties.defaults
if isinstance(rest_defaults, RestBatchEndpointDefaults):
if rest_defaults.deployment_name is not None:
defaults = {"deployment_name": rest_defaults.deployment_name}
elif isinstance(rest_defaults, dict):
defaults = dict(rest_defaults) if rest_defaults else None
Comment thread
ayushhgarg-work marked this conversation as resolved.

return BatchEndpoint(
id=obj.id,
name=obj.name,
Expand All @@ -98,7 +112,7 @@ def _from_rest_object(cls, obj: BatchEndpointData) -> "BatchEndpoint":
auth_mode=camel_to_snake(obj.properties.auth_mode),
description=obj.properties.description,
location=obj.location,
defaults=obj.properties.defaults,
defaults=defaults,
provisioning_state=obj.properties.provisioning_state,
scoring_uri=obj.properties.scoring_uri,
openapi_uri=obj.properties.swagger_uri,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,41 @@ def test_batch_endpoint_with_deployment_no_defaults(self) -> None:

assert endpoint.defaults is None

def test_to_rest_batch_endpoint_serializes_defaults_to_camel_case(self) -> None:
endpoint = BatchEndpoint(
name="my-batch-endpoint",
auth_mode="aad_token",
defaults={"deployment_name": "my-deployment"},
)

rest_batch_endpoint = endpoint._to_rest_batch_endpoint("eastus")
assert endpoint.defaults == {"deployment_name": "my-deployment"}
rest_defaults = rest_batch_endpoint.properties.defaults
assert rest_defaults is not None
assert rest_defaults.deployment_name == "my-deployment"
serialized = rest_defaults.as_dict()
assert serialized == {"deploymentName": "my-deployment"}
assert "deployment_name" not in serialized

def test_to_rest_batch_endpoint_with_no_defaults_passes_none(self) -> None:
endpoint = BatchEndpoint(
name="my-batch-endpoint",
auth_mode="aad_token",
)

rest_batch_endpoint = endpoint._to_rest_batch_endpoint("eastus")

assert rest_batch_endpoint.properties.defaults is None

def test_from_rest_object_defaults_returned_as_snake_case_dict(self) -> None:
with open(TestBatchEndpointYAML.BATCH_ENDPOINT_REST, "r") as f:
batch_endpoint_rest = _deserialize(BatchEndpointData, json.load(f))

batch_endpoint = BatchEndpoint._from_rest_object(batch_endpoint_rest)

assert batch_endpoint.defaults == {"deployment_name": "hello-world-1"}
assert batch_endpoint.defaults["deployment_name"] == "hello-world-1"


class TestKubernetesOnlineEndopint:
K8S_ONLINE_ENDPOINT = "tests/test_configs/endpoints/online/online_endpoint_create_k8s.yml"
Expand Down
Loading