def _create_or_update_with_arm_scope(ml_client, datastore):
    """Submit a datastore create/update whose request body carries ARM scope.

    The request body is produced by ``datastore._to_rest_object()``. For the
    Azure storage datastore entities listed in
    ``_AZURE_STORAGE_DATASTORE_TYPES`` the body's ``properties`` may come back
    without ``subscription_id`` / ``resource_group``; a datastore created
    that way has no ARM scope, which breaks downstream operations such as
    sharing data assets to a registry. Any of those two fields that is unset
    (missing or falsy) is therefore filled in from the client's operation
    scope before the request is sent. Values already present are left alone.

    The call goes straight to the underlying REST operation (with
    ``skip_validation=True``) instead of ``ml_client.datastores.create_or_update``
    so that the patched body is the one that is actually submitted.

    :param ml_client: Client whose ``_operation_scope`` supplies the
        subscription and resource group, and whose ``datastores`` operations
        object supplies the REST operation and workspace name.
    :param datastore: Datastore entity to create or update.
    :return: The datastore entity rebuilt from the service response via
        ``Datastore._from_rest_object``.
    """
    # The helper reaches into SDK internals throughout, so the check is
    # silenced once for the whole function rather than line by line.
    # pylint: disable=protected-access
    rest_body = datastore._to_rest_object()

    scope = ml_client._operation_scope
    workspace_resource_group = scope._resource_group_name

    # Only Azure storage datastores get the backfill; other datastore kinds
    # are sent exactly as the entity serialised them.
    is_storage_backed = isinstance(datastore, _AZURE_STORAGE_DATASTORE_TYPES)
    rest_properties = rest_body.properties if is_storage_backed else None

    if rest_properties is not None:
        workspace_defaults = (
            ('subscription_id', scope.subscription_id),
            ('resource_group', workspace_resource_group),
        )
        for field_name, workspace_value in workspace_defaults:
            # Treat a missing attribute and an empty value the same way:
            # both mean the entity did not specify its own scope.
            if not getattr(rest_properties, field_name, None):
                setattr(rest_properties, field_name, workspace_value)

    datastore_ops = ml_client.datastores
    service_response = datastore_ops._operation.create_or_update(
        name=datastore.name,
        resource_group_name=workspace_resource_group,
        workspace_name=datastore_ops._workspace_name,
        body=rest_body,
        skip_validation=True,
    )
    return Datastore._from_rest_object(service_response)
ml_client.datastores.create_or_update(datastore)._to_dict() # pylint: disable=protected-access + return _create_or_update_with_arm_scope(ml_client, datastore)._to_dict() # pylint: disable=protected-access except Exception as err: # pylint: disable=broad-exception-caught log_and_raise_error(err, debug)