Skip to content

Commit 2e850a0

Browse files
committed
[Refractor] - Migrated remove_vm_identity function
1 parent b342b3b commit 2e850a0

3 files changed

Lines changed: 152 additions & 7 deletions

File tree

src/azure-cli/azure/cli/command_modules/vm/_vm_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import re
99
import importlib
10+
from enum import Enum
1011

1112
from urllib.parse import urlparse
1213

@@ -831,4 +832,11 @@ def resolve_role_id(cli_ctx, role, scope):
831832
err = "More than one role matches the given name '{}'. Please pick an id from '{}'"
832833
raise CLIError(err.format(role, ids))
833834
role_id = role_defs[0].id
834-
return role_id
835+
return role_id
836+
837+
838+
class IdentityType(Enum):
839+
SYSTEM_ASSIGNED = 'SystemAssigned'
840+
USER_ASSIGNED = "UserAssigned"
841+
SYSTEM_ASSIGNED_USER_ASSIGNED = "SystemAssigned, UserAssigned"
842+
NONE = 'None'

src/azure-cli/azure/cli/command_modules/vm/custom.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2545,18 +2545,81 @@ def _remove_identities(cmd, resource_group_name, name, identities, getter, sette
25452545
return result.identity
25462546

25472547

2548+
def _remove_identities_by_aaz(cmd, resource_group_name, name, identities, getter, setter):
2549+
from ._vm_utils import MSI_LOCAL_ID, IdentityType
2550+
2551+
remove_system_assigned_identity = False
2552+
2553+
if MSI_LOCAL_ID in identities:
2554+
remove_system_assigned_identity = True
2555+
identities.remove(MSI_LOCAL_ID)
2556+
2557+
resource = getter(cmd, resource_group_name, name)
2558+
existing_identity = resource.get('identity')
2559+
2560+
if existing_identity is None:
2561+
return None
2562+
2563+
existing_emsis = [x.lower() for x in list((existing_identity.get('userAssignedIdentities') or {}).keys())]
2564+
2565+
if identities:
2566+
emsis_to_remove = [x.lower() for x in identities]
2567+
2568+
non_existing = [emsis for emsis in emsis_to_remove if not emsis in existing_emsis]
2569+
if non_existing:
2570+
raise CLIError("'{}' are not associated with '{}'".format(','.join(non_existing), name))
2571+
2572+
emsis_to_be_retain = [emsis for emsis in existing_emsis if not emsis in emsis_to_remove]
2573+
2574+
if not emsis_to_be_retain: # if all emsis are gone, we need to update the type
2575+
if existing_identity['type'] == IdentityType.USER_ASSIGNED.value:
2576+
existing_identity['type'] = IdentityType.NONE.value
2577+
elif existing_identity['type'] == IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value:
2578+
existing_identity['type'] = IdentityType.SYSTEM_ASSIGNED.value
2579+
2580+
existing_identity['userAssignedIdentities'] = {}
2581+
for emsis in identities:
2582+
existing_identity['userAssignedIdentities'][emsis] = {}
2583+
else:
2584+
existing_identity['userAssignedIdentities'] = None
2585+
2586+
if remove_system_assigned_identity:
2587+
if existing_identity['type'] == IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value or existing_identity['type'] == IdentityType.USER_ASSIGNED.value:
2588+
existing_identity['type'] = IdentityType.USER_ASSIGNED.value
2589+
else:
2590+
existing_identity['type'] = IdentityType.NONE.value
2591+
2592+
result = LongRunningOperation(cmd.cli_ctx)(setter(resource_group_name, name, resource))
2593+
return result.get('identity') or None
2594+
2595+
25482596
def remove_vm_identity(cmd, resource_group_name, vm_name, identities=None):
25492597
def setter(resource_group_name, vm_name, vm):
2550-
client = _compute_client_factory(cmd.cli_ctx)
2551-
VirtualMachineUpdate = cmd.get_models('VirtualMachineUpdate', operation_group='virtual_machines')
2552-
vm_update = VirtualMachineUpdate(identity=vm.identity)
2553-
return client.virtual_machines.begin_update(resource_group_name, vm_name, vm_update)
2598+
command_args = {
2599+
'resource_group': resource_group_name,
2600+
'vm_name': vm_name
2601+
}
2602+
2603+
from ._vm_utils import IdentityType
2604+
if vm.get('identity') and vm.get('identity').get('type') == IdentityType.USER_ASSIGNED.value:
2605+
command_args['mi_user_assigned'] = [key for key in list((vm.get('identity').get('userAssignedIdentities') or {}).keys())] + ['UserAssigned']
2606+
elif vm.get('identity') and vm.get('identity').get('type') == IdentityType.SYSTEM_ASSIGNED.value:
2607+
command_args['mi_user_assigned'] = []
2608+
command_args['mi_system_assigned'] = 'True'
2609+
elif vm.get('identity') and vm.get('identity').get('type') == IdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED.value:
2610+
command_args['mi_user_assigned'] = [key for key in list((vm.get('identity').get('userAssignedIdentities') or {}).keys())]
2611+
command_args['mi_system_assigned'] = 'True'
2612+
else:
2613+
command_args['mi_user_assigned'] = []
2614+
2615+
from .operations.vm import VMIdentityRemove
2616+
return VMIdentityRemove(cli_ctx=cmd.cli_ctx)(command_args=command_args)
25542617

25552618
if identities is None:
25562619
from ._vm_utils import MSI_LOCAL_ID
25572620
identities = [MSI_LOCAL_ID]
25582621

2559-
return _remove_identities(cmd, resource_group_name, vm_name, identities, get_vm, setter)
2622+
return _remove_identities_by_aaz(cmd, resource_group_name, vm_name, identities, get_vm_migrated, setter)
25602623

25612624

25622625
# region VirtualMachines Images

src/azure-cli/azure/cli/command_modules/vm/operations/vm.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
# Licensed under the MIT License. See License.txt in the project root for license information.
44
# --------------------------------------------------------------------------------------------
55
# pylint: disable=no-self-use, line-too-long, protected-access, too-few-public-methods, unused-argument, too-many-statements, too-many-branches, too-many-locals
6+
from typing import override
7+
import json
8+
69
from knack.log import get_logger
710

811
from azure.cli.core.aaz import AAZStrType
9-
from ..aaz.latest.vm import (Show as _VMShow, ListSizes as _VMListSizes,
12+
from ..aaz.latest.vm import (Show as _VMShow, ListSizes as _VMListSizes, Patch as _VMPatch,
1013
Update as _VMUpdate, Capture as _VMCapture, Create as _VMCreate)
14+
from .._vm_utils import IdentityType
1115

1216
logger = get_logger(__name__)
1317

@@ -155,6 +159,76 @@ def _output(self, *args, **kwargs):
155159
return result
156160

157161

162+
class VMIdentityRemove(_VMPatch):
163+
def _output(self, *args, **kwargs):
164+
result = self.deserialize_output(self.ctx.vars.instance, client_flatten=True)
165+
166+
identity = result.get('identity')
167+
if not identity:
168+
return result
169+
170+
if not identity.get('principalId'):
171+
identity['principalId'] = None
172+
173+
if not identity.get('tenantId'):
174+
identity['tenantId'] = None
175+
176+
if not identity.get('userAssignedIdentities'):
177+
identity['userAssignedIdentities'] = None
178+
179+
return result
180+
181+
class VirtualMachinesUpdate(_VMPatch.VirtualMachinesUpdate):
182+
def _format_content(self, content):
183+
if type(content) == str:
184+
content = json.loads(content)
185+
186+
if not content.get('identity'):
187+
content['identity'] = {
188+
'userAssignedIdentities': None,
189+
'type': IdentityType.NONE.value
190+
}
191+
return json.dumps(content)
192+
193+
identities = content.get('identity', {}).get('userAssignedIdentities')
194+
if identities:
195+
if 'UserAssigned' in identities.keys():
196+
identities.pop('UserAssigned')
197+
198+
for key in identities.keys():
199+
identities[key] = None
200+
201+
if len(content.get('identity', {}).get('userAssignedIdentities', {}).keys()) < 1:
202+
content['identity']['userAssignedIdentities'] = None
203+
204+
return json.dumps(content)
205+
206+
def __call__(self, *args, **kwargs):
207+
request = self.make_request()
208+
request.data = self._format_content(request.data)
209+
session = self.client.send_request(request=request, stream=False, **kwargs)
210+
if session.http_response.status_code in [202]:
211+
return self.client.build_lro_polling(
212+
self.ctx.args.no_wait,
213+
session,
214+
self.on_200,
215+
self.on_error,
216+
lro_options={"final-state-via": "azure-async-operation"},
217+
path_format_arguments=self.url_parameters,
218+
)
219+
if session.http_response.status_code in [200]:
220+
return self.client.build_lro_polling(
221+
self.ctx.args.no_wait,
222+
session,
223+
self.on_200,
224+
self.on_error,
225+
lro_options={"final-state-via": "azure-async-operation"},
226+
path_format_arguments=self.url_parameters,
227+
)
228+
229+
return self.on_error(session.http_response)
230+
231+
158232
def convert_show_result_to_snake_case(result):
159233
new_result = {}
160234
if "extendedLocation" in result:

0 commit comments

Comments
 (0)