Skip to content

Commit 43744d0

Browse files
Fix DeploymentTemplate: cross-tenant endpoint resolution and update round-trip for immutable fields (#46719)
* Fix DT update: cross-tenant endpoint resolution and round-trip for immutable fields * Add display_name support to DeploymentTemplate entity and schema * Revert "Add display_name support to DeploymentTemplate entity and schema" This reverts commit 6ce4d07. * Add changelog entries for DT bug fixes * Fix pylint line-too-long and black formatting issues
1 parent f426769 commit 43744d0

3 files changed

Lines changed: 48 additions & 27 deletions

File tree

sdk/ml/azure-ai-ml/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
### Bugs Fixed
88

9+
- Fixed cross-tenant registry endpoint resolution for deployment template operations by using the registry discovery API instead of ARM calls.
10+
- Fixed deployment template update failing with immutable field errors by ensuring `allowedInstanceType` and `allowedEnvironmentVariableOverrides` are properly round-tripped during serialization.
11+
912
### Other Changes
1013

1114
## 1.32.0 (unreleased)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/deployment_template.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def __init__( # pylint: disable=too-many-locals
104104
)
105105
self.allowed_instance_types = allowed_instance_types
106106
self.default_instance_type = default_instance_type
107+
self._allowed_environment_variable_overrides = None
107108
self.scoring_port = scoring_port
108109
self.scoring_path = scoring_path
109110
self.model_mount_path = model_mount_path
@@ -368,6 +369,16 @@ def get_value(source, key, default=None):
368369
allowed_instance_types = get_value(properties, "allowedInstanceTypes") or get_value(
369370
obj, "allowed_instance_types"
370371
)
372+
# Also check additional_properties for service fields with mismatched names
373+
if not allowed_instance_types:
374+
additional_props = get_value(obj, "additional_properties", {})
375+
if isinstance(additional_props, dict):
376+
allowed_instance_types = additional_props.get("allowedInstanceType") or additional_props.get(
377+
"allowedInstanceTypes"
378+
)
379+
allowed_environment_variable_overrides = get_value(
380+
properties, "allowedEnvironmentVariableOverrides"
381+
) or get_value(obj, "allowed_environment_variable_overrides")
371382
scoring_port = get_value(properties, "scoringPort") or get_value(obj, "scoring_port")
372383
scoring_path = get_value(properties, "scoringPath") or get_value(obj, "scoring_path")
373384
model_mount_path = get_value(properties, "modelMountPath") or get_value(obj, "model_mount_path")
@@ -399,6 +410,13 @@ def get_value(source, key, default=None):
399410
except (ValueError, SyntaxError):
400411
allowed_instance_types = None
401412

413+
# Parse allowed_environment_variable_overrides if it's a string
414+
if isinstance(allowed_environment_variable_overrides, str):
415+
try:
416+
allowed_environment_variable_overrides = ast.literal_eval(allowed_environment_variable_overrides)
417+
except (ValueError, SyntaxError):
418+
allowed_environment_variable_overrides = None
419+
402420
# Convert request_settings to OnlineRequestSettings object using the built-in conversion method
403421
request_settings_obj = OnlineRequestSettings._from_rest_object(request_settings) if request_settings else None
404422

@@ -451,6 +469,9 @@ def get_value(source, key, default=None):
451469
# updates
452470
template._from_service = True
453471

472+
# Store allowed_environment_variable_overrides as private field for round-trip
473+
template._allowed_environment_variable_overrides = allowed_environment_variable_overrides
474+
454475
# Store additional fields from the REST response that may be needed
455476
template.environment_id = environment_id # type: ignore[attr-defined]
456477
# Alternative name for deployment_template_type
@@ -472,6 +493,7 @@ def get_value(source, key, default=None):
472493
"app_insights_enabled": get_value(obj, "app_insights_enabled"),
473494
"deployment_template_type": deployment_template_type,
474495
"allowed_instance_types": allowed_instance_types,
496+
"allowed_environment_variable_overrides": allowed_environment_variable_overrides,
475497
"scoring_port": scoring_port,
476498
"scoring_path": scoring_path,
477499
"model_mount_path": model_mount_path,
@@ -571,6 +593,10 @@ def _to_rest_object(self) -> dict:
571593
# Handle allowed instance types
572594
if hasattr(self, "allowed_instance_types") and self.allowed_instance_types:
573595
result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment]
596+
result["allowedInstanceType"] = self.allowed_instance_types # type: ignore[assignment]
597+
598+
if hasattr(self, "_allowed_environment_variable_overrides") and self._allowed_environment_variable_overrides:
599+
result["allowedEnvironmentVariableOverrides"] = self._allowed_environment_variable_overrides
574600

575601
return result
576602

@@ -637,6 +663,8 @@ def _to_dict(self) -> Dict:
637663
# Add instance configuration
638664
if hasattr(self, "allowed_instance_types") and self.allowed_instance_types:
639665
result["allowedInstanceTypes"] = self.allowed_instance_types # type: ignore[assignment]
666+
if hasattr(self, "_allowed_environment_variable_overrides") and self._allowed_environment_variable_overrides:
667+
result["allowedEnvironmentVariableOverrides"] = self._allowed_environment_variable_overrides
640668
if self.default_instance_type:
641669
result["defaultInstanceType"] = self.default_instance_type
642670
elif self.instance_type:

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_deployment_template_operations.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,21 @@ def __init__(
4545
def _get_registry_endpoint(self) -> str:
4646
"""Dynamically determine the registry endpoint based on registry region.
4747
48+
Uses the registry discovery API (which does not require ARM access to the
49+
registry's subscription) to resolve the primary region, then constructs the
50+
appropriate dataplane endpoint.
51+
4852
:return: The API endpoint URL for the registry
4953
:rtype: str
5054
"""
5155
try:
52-
# Import here to avoid circular dependencies
53-
from azure.ai.ml._restclient.v2022_10_01_preview import (
54-
AzureMachineLearningWorkspaces as ServiceClient102022,
56+
from azure.ai.ml._azure_environments import (
57+
_get_default_cloud_name,
58+
_get_registry_discovery_endpoint_from_metadata,
59+
)
60+
from azure.ai.ml._restclient.registry_discovery import (
61+
RegistryDiscoveryClient as ServiceClientRegistryDiscovery,
5562
)
56-
from azure.ai.ml.operations import RegistryOperations
5763

5864
# Try to get credential from service client or operation config
5965
credential = None
@@ -63,31 +69,15 @@ def _get_registry_endpoint(self) -> str:
6369
credential = self._operation_config.credential
6470

6571
if credential and self._operation_scope.registry_name:
66-
# Get registry information to determine the region
67-
registry_operations = RegistryOperations(
68-
operation_scope=self._operation_scope,
69-
service_client=ServiceClient102022(
70-
credential=credential,
71-
subscription_id=self._operation_scope.subscription_id,
72-
resource_group_name=self._operation_scope.resource_group_name,
73-
),
74-
all_operations=None, # type: ignore[arg-type]
75-
credentials=credential,
72+
# Use registry discovery API to get the primary region
73+
discovery_base_url = _get_registry_discovery_endpoint_from_metadata(_get_default_cloud_name())
74+
discovery_client = ServiceClientRegistryDiscovery(credential=credential, base_url=discovery_base_url)
75+
response = discovery_client.registry_management_non_workspace.get_registry_management_non_workspace(
76+
self._operation_scope.registry_name
7677
)
7778

78-
registry = registry_operations.get(self._operation_scope.registry_name)
79-
80-
# Extract region from registry location or replication locations
81-
region = None
82-
if registry.location:
83-
region = registry.location
84-
elif registry.replication_locations and len(registry.replication_locations) > 0:
85-
region = registry.replication_locations[0].location
86-
87-
if region:
88-
# Format the endpoint using the detected region
89-
# return f"https://int.experiments.azureml-test.net"
90-
return f"https://{region}.api.azureml.ms"
79+
if response.primary_region:
80+
return f"https://{response.primary_region}.api.azureml.ms"
9181

9282
except Exception as e:
9383
module_logger.debug("Could not determine registry region dynamically: %s. Using default.", e)

0 commit comments

Comments
 (0)