diff --git a/src/aks-preview/HISTORY.rst b/src/aks-preview/HISTORY.rst index 836e4937216..92e45ec3257 100644 --- a/src/aks-preview/HISTORY.rst +++ b/src/aks-preview/HISTORY.rst @@ -12,7 +12,10 @@ To release a new version, please select a new version number (usually plus 1 to Pending +++++++ * Fix `match_condition` kwarg leaking to HTTP transport by overriding `put_mc` and `add_agentpool` to pass `if_match` / `if_none_match` directly to the vendored SDK. This change fixes the compatibility issue as azure-cli/acs module adopts TypeSpec emitted SDKs while azure-cli-extensions/aks-preview still uses the autorest emitted SDK. -+ `az aks list-vm-skus`: New command to list available VM SKUs for AKS clusters in a given region. +* `az aks list-vm-skus`: New command to list available VM SKUs for AKS clusters in a given region. +* Add managed GPU enablement option to node pool property in `az aks nodepool add` and `az aks nodepool update`. + + 19.0.0b27 +++++++ diff --git a/src/aks-preview/azcli_aks_live_test/configs/ext_matrix_default.json b/src/aks-preview/azcli_aks_live_test/configs/ext_matrix_default.json index cbec9591737..e3e54bb1540 100644 --- a/src/aks-preview/azcli_aks_live_test/configs/ext_matrix_default.json +++ b/src/aks-preview/azcli_aks_live_test/configs/ext_matrix_default.json @@ -22,7 +22,9 @@ ], "gpu, no quota": [ "test_aks_nodepool_add_with_gpu_instance_profile", - "test_aks_gpu_driver_type" + "test_aks_gpu_driver_type", + "test_aks_nodepool_add_with_enable_managed_gpu", + "test_aks_nodepool_update_with_enable_managed_gpu" ], "pod ip allocation mode static block, missing feature registration": [ "test_aks_create_with_pod_ip_allocation_mode_static_block" diff --git a/src/aks-preview/azext_aks_preview/_consts.py b/src/aks-preview/azext_aks_preview/_consts.py index 9947552d15a..10c430a9ce6 100644 --- a/src/aks-preview/azext_aks_preview/_consts.py +++ b/src/aks-preview/azext_aks_preview/_consts.py @@ -76,6 +76,10 @@ CONST_GPU_DRIVER_INSTALL = "Install" CONST_GPU_DRIVER_NONE = "None" +# gpu management mode +CONST_GPU_MANAGEMENT_MODE_MANAGED = "Managed" +CONST_GPU_MANAGEMENT_MODE_UNMANAGED = "Unmanaged" + # consts for ManagedCluster # load balancer sku CONST_LOAD_BALANCER_SKU_BASIC = "basic" diff --git a/src/aks-preview/azext_aks_preview/_help.py b/src/aks-preview/azext_aks_preview/_help.py index 17581041e38..dc751bec483 100644 --- a/src/aks-preview/azext_aks_preview/_help.py +++ b/src/aks-preview/azext_aks_preview/_help.py @@ -2203,6 +2203,9 @@ - name: --enable-artifact-streaming type: bool short-summary: Enable artifact streaming for VirtualMachineScaleSets managed by a node pool, to speed up the cold-start of containers on a node through on-demand image loading. To use this feature, container images must also enable artifact streaming on ACR. If not specified, the default is false. + - name: --enable-managed-gpu + type: bool + short-summary: Enable the Managed GPU experience, which installs additional components like DCGM metrics for monitoring on top of the GPU driver. For more details, visit aka.ms/aks/managed-gpu. - name: --skip-gpu-driver-install type: bool short-summary: To skip GPU driver auto installation by AKS on a nodepool using GPU vm size if customers want to manage GPU driver installation by their own. If not specified, the default is false. @@ -2419,6 +2422,9 @@ - name: --enable-artifact-streaming type: bool short-summary: Enable artifact streaming for VirtualMachineScaleSets managed by a node pool, to speed up the cold-start of containers on a node through on-demand image loading. To use this feature, container images must also enable artifact streaming on ACR. If not specified, the default is false. + - name: --enable-managed-gpu + type: bool + short-summary: Enable the Managed GPU experience, which installs additional components like DCGM metrics for monitoring on top of the GPU driver. For more details, visit aka.ms/aks/managed-gpu. - name: --os-sku type: string short-summary: The os-sku of the agent node pool. diff --git a/src/aks-preview/azext_aks_preview/_params.py b/src/aks-preview/azext_aks_preview/_params.py index c51e8bcfdb2..b7a41f70400 100644 --- a/src/aks-preview/azext_aks_preview/_params.py +++ b/src/aks-preview/azext_aks_preview/_params.py @@ -2031,6 +2031,12 @@ def load_arguments(self, _): validator=validate_artifact_streaming, is_preview=True, ) + c.argument( + "enable_managed_gpu", + action="store_true", + is_preview=True, + help="Enable the Managed GPU experience.", + ) c.argument( "node_public_ip_tags", arg_type=tags_type, @@ -2140,6 +2146,12 @@ def load_arguments(self, _): validator=validate_artifact_streaming, is_preview=True, ) + c.argument( + "enable_managed_gpu", + action="store_true", + is_preview=True, + help="Enable the Managed GPU experience.", + ) c.argument( "os_sku", arg_type=get_enum_type(node_os_skus_update), diff --git a/src/aks-preview/azext_aks_preview/agentpool_decorator.py b/src/aks-preview/azext_aks_preview/agentpool_decorator.py index d3d522afa73..6446ac61edc 100644 --- a/src/aks-preview/azext_aks_preview/agentpool_decorator.py +++ b/src/aks-preview/azext_aks_preview/agentpool_decorator.py @@ -44,7 +44,10 @@ CONST_DEFAULT_WINDOWS_VMS_VM_SIZE, CONST_MANAGED_CLUSTER_SKU_NAME_AUTOMATIC, CONST_SSH_ACCESS_LOCALUSER, + CONST_GPU_DRIVER_INSTALL, CONST_GPU_DRIVER_NONE, + CONST_GPU_MANAGEMENT_MODE_MANAGED, + CONST_GPU_MANAGEMENT_MODE_UNMANAGED, CONST_NODEPOOL_MODE_MANAGEDSYSTEM, CONST_NODEPOOL_MODE_MACHINES, ) @@ -587,6 +590,27 @@ def get_enable_artifact_streaming(self) -> bool: enable_artifact_streaming = self.agentpool.artifact_streaming_profile.enabled return enable_artifact_streaming + def get_enable_managed_gpu(self) -> Union[bool, None]: + """Obtain the value of enable_managed_gpu. + :return: bool + """ + + # read the original value passed by the command + enable_managed_gpu = self.raw_param.get("enable_managed_gpu") + + # In create mode, try to read the property value corresponding to the parameter from the `agentpool` object + if self.decorator_mode == DecoratorMode.CREATE: + if ( + self.agentpool and + self.agentpool.gpu_profile is not None and + self.agentpool.gpu_profile.nvidia is not None and + self.agentpool.gpu_profile.nvidia.management_mode is not None + ): + enable_managed_gpu = ( + self.agentpool.gpu_profile.nvidia.management_mode == CONST_GPU_MANAGEMENT_MODE_MANAGED + ) + return enable_managed_gpu + def get_pod_ip_allocation_mode(self: bool = False) -> Union[str, None]: """Get the value of pod_ip_allocation_mode. :return: str or None @@ -1276,6 +1300,21 @@ def set_up_artifact_streaming(self, agentpool: AgentPool) -> AgentPool: agentpool.artifact_streaming_profile.enabled = True return agentpool + def set_up_managed_gpu(self, agentpool: AgentPool) -> AgentPool: + """Set up managed GPU property for the AgentPool object.""" + self._ensure_agentpool(agentpool) + + enable_managed_gpu = self.context.get_enable_managed_gpu() + + if enable_managed_gpu: + if agentpool.gpu_profile is None: + agentpool.gpu_profile = self.models.GPUProfile() # pylint: disable=no-member + if agentpool.gpu_profile.nvidia is None: + agentpool.gpu_profile.nvidia = self.models.NvidiaGPUProfile() # pylint: disable=no-member + agentpool.gpu_profile.nvidia.management_mode = CONST_GPU_MANAGEMENT_MODE_MANAGED + agentpool.gpu_profile.driver = CONST_GPU_DRIVER_INSTALL + return agentpool + def set_up_ssh_access(self, agentpool: AgentPool) -> AgentPool: self._ensure_agentpool(agentpool) @@ -1510,6 +1549,8 @@ def construct_agentpool_profile_preview(self) -> AgentPool: agentpool = self.set_up_init_taints(agentpool) # set up artifact streaming agentpool = self.set_up_artifact_streaming(agentpool) + # set up managed gpu + agentpool = self.set_up_managed_gpu(agentpool) # set up skip_gpu_driver_install agentpool = self.set_up_skip_gpu_driver_install(agentpool) # set up gpu profile @@ -1704,6 +1745,29 @@ def update_artifact_streaming(self, agentpool: AgentPool) -> AgentPool: agentpool.artifact_streaming_profile.enabled = True return agentpool + def update_managed_gpu(self, agentpool: AgentPool) -> AgentPool: + """Update managed GPU property for the AgentPool object. + :return: the AgentPool object + """ + self._ensure_agentpool(agentpool) + + enable_managed_gpu = self.context.get_enable_managed_gpu() + if enable_managed_gpu is None: + return agentpool + + if enable_managed_gpu: + if agentpool.gpu_profile is None: + agentpool.gpu_profile = self.models.GPUProfile() # pylint: disable=no-member + if agentpool.gpu_profile.nvidia is None: + agentpool.gpu_profile.nvidia = self.models.NvidiaGPUProfile() # pylint: disable=no-member + agentpool.gpu_profile.nvidia.management_mode = CONST_GPU_MANAGEMENT_MODE_MANAGED + agentpool.gpu_profile.driver = CONST_GPU_DRIVER_INSTALL + else: + if agentpool.gpu_profile and agentpool.gpu_profile.nvidia: + agentpool.gpu_profile.nvidia.management_mode = CONST_GPU_MANAGEMENT_MODE_UNMANAGED + + return agentpool + def update_os_sku(self, agentpool: AgentPool) -> AgentPool: self._ensure_agentpool(agentpool) @@ -1828,6 +1892,9 @@ def update_agentpool_profile_preview(self, agentpools: List[AgentPool] = None) - # update artifact streaming agentpool = self.update_artifact_streaming(agentpool) + # update managed gpu + agentpool = self.update_managed_gpu(agentpool) + # update secure boot agentpool = self.update_secure_boot(agentpool) diff --git a/src/aks-preview/azext_aks_preview/custom.py b/src/aks-preview/azext_aks_preview/custom.py index a63a6e96202..f8894b88bec 100644 --- a/src/aks-preview/azext_aks_preview/custom.py +++ b/src/aks-preview/azext_aks_preview/custom.py @@ -1919,6 +1919,7 @@ def aks_agentpool_add( asg_ids=None, node_public_ip_tags=None, enable_artifact_streaming=False, + enable_managed_gpu=False, skip_gpu_driver_install=False, gpu_driver=None, driver_type=None, @@ -1993,6 +1994,7 @@ def aks_agentpool_update( allowed_host_ports=None, asg_ids=None, enable_artifact_streaming=False, + enable_managed_gpu=False, os_sku=None, ssh_access=None, yes=False, diff --git a/src/aks-preview/azext_aks_preview/tests/latest/test_agentpool_decorator.py b/src/aks-preview/azext_aks_preview/tests/latest/test_agentpool_decorator.py index f211311be2d..8bbe8fe9028 100644 --- a/src/aks-preview/azext_aks_preview/tests/latest/test_agentpool_decorator.py +++ b/src/aks-preview/azext_aks_preview/tests/latest/test_agentpool_decorator.py @@ -36,6 +36,7 @@ CONST_MANAGED_CLUSTER_SKU_NAME_BASE, CONST_MANAGED_CLUSTER_SKU_NAME_AUTOMATIC, CONST_GPU_DRIVER_NONE, + CONST_GPU_MANAGEMENT_MODE_MANAGED, CONST_NODEPOOL_MODE_MANAGEDSYSTEM, CONST_NODEPOOL_MODE_MACHINES, ) @@ -257,6 +258,45 @@ def common_get_enable_artifact_streaming(self): ctx_2.attach_agentpool(agentpool_2) self.assertEqual(ctx_2.get_enable_artifact_streaming(), None) + def common_get_enable_managed_gpu(self): + # default + ctx_1 = AKSPreviewAgentPoolContext( + self.cmd, + AKSAgentPoolParamDict({"enable_managed_gpu": None}), + self.models, + DecoratorMode.CREATE, + self.agentpool_decorator_mode, + ) + self.assertEqual(ctx_1.get_enable_managed_gpu(), None) + agentpool_1 = self.create_initialized_agentpool_instance( + gpu_profile=self.models.GPUProfile( + nvidia=self.models.NvidiaGPUProfile( + management_mode=CONST_GPU_MANAGEMENT_MODE_MANAGED + ) + ) + ) + ctx_1.attach_agentpool(agentpool_1) + self.assertEqual(ctx_1.get_enable_managed_gpu(), True) + + # default + ctx_2 = AKSPreviewAgentPoolContext( + self.cmd, + AKSAgentPoolParamDict({"enable_managed_gpu": None}), + self.models, + DecoratorMode.UPDATE, + self.agentpool_decorator_mode, + ) + self.assertEqual(ctx_2.get_enable_managed_gpu(), None) + agentpool_2 = self.create_initialized_agentpool_instance( + gpu_profile=self.models.GPUProfile( + nvidia=self.models.NvidiaGPUProfile( + management_mode=CONST_GPU_MANAGEMENT_MODE_MANAGED + ) + ) + ) + ctx_2.attach_agentpool(agentpool_2) + self.assertEqual(ctx_2.get_enable_managed_gpu(), None) + def common_get_pod_ip_allocation_mode(self): # default ctx_1 = AKSPreviewAgentPoolContext( @@ -1037,6 +1077,9 @@ def test_get_workload_runtime(self): def test_get_enable_artifact_streaming(self): self.common_get_enable_artifact_streaming() + def test_get_enable_managed_gpu(self): + self.common_get_enable_managed_gpu() + def test_get_pod_ip_allocation_mode(self): self.common_get_pod_ip_allocation_mode() @@ -1130,6 +1173,9 @@ def test_get_workload_runtime(self): def test_get_enable_artifact_streaming(self): self.common_get_enable_artifact_streaming() + + def test_get_enable_managed_gpu(self): + self.common_get_enable_managed_gpu() def test_get_pod_ip_allocation_mode(self): self.common_get_pod_ip_allocation_mode() @@ -1450,6 +1496,31 @@ def common_set_up_artifact_streaming(self): ) self.assertEqual(dec_agentpool_1, ground_truth_agentpool_1) + def common_set_up_managed_gpu(self): + dec_1 = AKSPreviewAgentPoolAddDecorator( + self.cmd, + self.client, + {"enable_managed_gpu": True}, + self.resource_type, + self.agentpool_decorator_mode, + ) + # fail on passing the wrong agentpool object + with self.assertRaises(CLIInternalError): + dec_1.set_up_managed_gpu(None) + agentpool_1 = self.create_initialized_agentpool_instance(restore_defaults=False) + dec_1.context.attach_agentpool(agentpool_1) + dec_agentpool_1 = dec_1.set_up_managed_gpu(agentpool_1) + dec_agentpool_1 = self._restore_defaults_in_agentpool(dec_agentpool_1) + ground_truth_agentpool_1 = self.create_initialized_agentpool_instance( + gpu_profile=self.models.GPUProfile( + driver=CONST_GPU_DRIVER_INSTALL, + nvidia=self.models.NvidiaGPUProfile( + management_mode=CONST_GPU_MANAGEMENT_MODE_MANAGED + ) + ) + ) + self.assertEqual(dec_agentpool_1, ground_truth_agentpool_1) + def common_set_up_skip_gpu_driver_install(self): dec_1 = AKSPreviewAgentPoolAddDecorator( self.cmd, @@ -1999,6 +2070,9 @@ def test_set_up_gpu_propertes(self): def test_set_up_artifact_streaming(self): self.common_set_up_artifact_streaming() + def test_set_up_managed_gpu(self): + self.common_set_up_managed_gpu() + def test_set_up_skip_gpu_driver_install(self): self.common_set_up_skip_gpu_driver_install() @@ -2144,6 +2218,9 @@ def test_set_up_gpu_propertes(self): def test_set_up_artifact_streaming(self): self.common_set_up_artifact_streaming() + + def test_set_up_managed_gpu(self): + self.common_set_up_managed_gpu() def test_set_up_skip_gpu_driver_install(self): self.common_set_up_skip_gpu_driver_install() @@ -2322,12 +2399,12 @@ def common_update_artifact_streaming(self): ) dec_1.context.attach_agentpool(agentpool_1) dec_agentpool_1 = dec_1.update_artifact_streaming(agentpool_1) - grond_truth_agentpool_1 = self.create_initialized_agentpool_instance( + ground_truth_agentpool_1 = self.create_initialized_agentpool_instance( artifact_streaming_profile=self.models.AgentPoolArtifactStreamingProfile( enabled=True ) ) - self.assertEqual(dec_agentpool_1, grond_truth_agentpool_1) + self.assertEqual(dec_agentpool_1, ground_truth_agentpool_1) dec_2 = AKSPreviewAgentPoolUpdateDecorator( self.cmd, @@ -2342,11 +2419,63 @@ def common_update_artifact_streaming(self): agentpool_2 = self.create_initialized_agentpool_instance() dec_2.context.attach_agentpool(agentpool_2) dec_agentpool_2 = dec_2.update_artifact_streaming(agentpool_2) - grond_truth_agentpool_2 = self.create_initialized_agentpool_instance( + ground_truth_agentpool_2 = self.create_initialized_agentpool_instance( artifact_streaming_profile=self.models.AgentPoolArtifactStreamingProfile( enabled=True ) ) + self.assertEqual(dec_agentpool_2, ground_truth_agentpool_2) + + def common_update_managed_gpu(self): + dec_1 = AKSPreviewAgentPoolUpdateDecorator( + self.cmd, + self.client, + {"enable_managed_gpu": None}, + self.resource_type, + self.agentpool_decorator_mode, + ) + # fail on passing the wrong agentpool object + with self.assertRaises(CLIInternalError): + dec_1.update_managed_gpu(None) + agentpool_1 = self.create_initialized_agentpool_instance( + gpu_profile=self.models.GPUProfile( + nvidia=self.models.NvidiaGPUProfile( + management_mode=CONST_GPU_MANAGEMENT_MODE_MANAGED + ) + ) + ) + dec_1.context.attach_agentpool(agentpool_1) + dec_agentpool_1 = dec_1.update_managed_gpu(agentpool_1) + ground_truth_agentpool_1 = self.create_initialized_agentpool_instance( + gpu_profile=self.models.GPUProfile( + nvidia=self.models.NvidiaGPUProfile( + management_mode=CONST_GPU_MANAGEMENT_MODE_MANAGED + ) + ) + ) + self.assertEqual(dec_agentpool_1, ground_truth_agentpool_1) + + dec_2 = AKSPreviewAgentPoolUpdateDecorator( + self.cmd, + self.client, + {"enable_managed_gpu": True}, + self.resource_type, + self.agentpool_decorator_mode, + ) + # fail on passing the wrong agentpool object + with self.assertRaises(CLIInternalError): + dec_2.update_managed_gpu(None) + agentpool_2 = self.create_initialized_agentpool_instance() + dec_2.context.attach_agentpool(agentpool_2) + dec_agentpool_2 = dec_2.update_managed_gpu(agentpool_2) + grond_truth_agentpool_2 = self.create_initialized_agentpool_instance( + gpu_profile=self.models.GPUProfile( + driver=CONST_GPU_DRIVER_INSTALL, + nvidia=self.models.NvidiaGPUProfile( + management_mode=CONST_GPU_MANAGEMENT_MODE_MANAGED + ) + ) + ) self.assertEqual(dec_agentpool_2, grond_truth_agentpool_2) def common_update_secure_boot(self): @@ -2849,6 +2978,9 @@ def setUp(self): def test_update_artifact_streaming(self): self.common_update_artifact_streaming() + def test_update_managed_gpu(self): + self.common_update_managed_gpu() + def test_update_secure_boot(self): self.common_update_secure_boot() @@ -2941,6 +3073,9 @@ def setUp(self): def test_update_artifact_streaming(self): self.common_update_artifact_streaming() + + def test_update_managed_gpu(self): + self.common_update_managed_gpu() def test_update_secure_boot(self): self.common_update_secure_boot() diff --git a/src/aks-preview/azext_aks_preview/tests/latest/test_aks_commands.py b/src/aks-preview/azext_aks_preview/tests/latest/test_aks_commands.py index 09d8f18bf65..9ffe956d6d6 100644 --- a/src/aks-preview/azext_aks_preview/tests/latest/test_aks_commands.py +++ b/src/aks-preview/azext_aks_preview/tests/latest/test_aks_commands.py @@ -6868,6 +6868,61 @@ def test_aks_nodepool_add_with_artifact_streaming( checks=[self.is_empty()], ) + @live_only() + @AllowLargeResponse() + @AKSCustomResourceGroupPreparer( + random_name_length=17, name_prefix="clitest", location="westus3" + ) + def test_aks_nodepool_add_with_enable_managed_gpu( + self, resource_group, resource_group_location + ): + aks_name = self.create_random_name("cliakstest", 16) + nodepool_name = self.create_random_name("n", 6) + + self.kwargs.update( + { + "resource_group": resource_group, + "name": aks_name, + "location": resource_group_location, + "ssh_key_value": self.generate_ssh_keys(), + "node_pool_name": nodepool_name, + "node_vm_size": "Standard_NC6s_v3", + } + ) + + # create + create_cmd = ( + "aks create --resource-group={resource_group} --name={name} " + "--ssh-key-value={ssh_key_value} " + ) + + self.cmd( + create_cmd, + checks=[ + self.check("provisioningState", "Succeeded"), + ], + ) + + # nodepool add + self.cmd( + "aks nodepool add --resource-group={resource_group} --cluster-name={name} --name={node_pool_name} " + "--node-vm-size={node_vm_size} --node-count 1 " + " --enable-managed-gpu", + checks=[ + self.check("provisioningState", "Succeeded"), + self.check("gpuProfile.driver", "Install"), + self.check( + "gpuProfile.nvidia.managementMode", "Managed" + ), + ], + ) + + # delete + self.cmd( + "aks delete -g {resource_group} -n {name} --yes --no-wait", + checks=[self.is_empty()], + ) + @AllowLargeResponse() @AKSCustomResourceGroupPreparer( random_name_length=17, name_prefix="clitest", location="eastus" @@ -16443,6 +16498,68 @@ def test_aks_nodepool_update_with_artifact_streaming( ], ) + @live_only() + @AllowLargeResponse() + @AKSCustomResourceGroupPreparer( + random_name_length=17, name_prefix="clitest", location="westus3" + ) + def test_aks_nodepool_update_with_enable_managed_gpu( + self, resource_group, resource_group_location + ): + aks_name = self.create_random_name("cliakstest", 16) + nodepool_name = self.create_random_name("n", 6) + + self.kwargs.update( + { + "resource_group": resource_group, + "name": aks_name, + "location": resource_group_location, + "ssh_key_value": self.generate_ssh_keys(), + "node_pool_name": nodepool_name, + "node_vm_size": "Standard_NC6s_v3", + } + ) + + self.cmd( + "aks create " + "--resource-group={resource_group} " + "--name={name} " + "--location={location} " + "--ssh-key-value={ssh_key_value} " + "--nodepool-name={node_pool_name} " + "--node-count=1 " + "--node-vm-size={node_vm_size}", + checks=[ + self.check("provisioningState", "Succeeded"), + ], + ) + + self.cmd( + "aks nodepool update " + "--resource-group={resource_group} " + "--cluster-name={name} " + "--name={node_pool_name} " + "--enable-managed-gpu", + checks=[ + self.check("provisioningState", "Succeeded"), + self.check("gpuProfile.driver", "Install"), + self.check( + "gpuProfile.nvidia.managementMode", "Managed" + ), + ], + ) + + # delete + cmd = ( + "aks delete --resource-group={resource_group} --name={name} --yes --no-wait" + ) + self.cmd( + cmd, + checks=[ + self.is_empty(), + ], + ) + @AllowLargeResponse() @AKSCustomResourceGroupPreparer( random_name_length=17, name_prefix="clitest", location="westus2" diff --git a/src/aks-preview/azext_aks_preview/tests/latest/test_update_agentpool_profile_preview.py b/src/aks-preview/azext_aks_preview/tests/latest/test_update_agentpool_profile_preview.py index 7df5619d3ac..a9d5f9548ab 100644 --- a/src/aks-preview/azext_aks_preview/tests/latest/test_update_agentpool_profile_preview.py +++ b/src/aks-preview/azext_aks_preview/tests/latest/test_update_agentpool_profile_preview.py @@ -123,6 +123,7 @@ def test_update_agentpool_profile_preview_default_behavior(self): # Mock all the update methods to return the agentpool unchanged decorator.update_network_profile = Mock(return_value=agentpool) decorator.update_artifact_streaming = Mock(return_value=agentpool) + decorator.update_managed_gpu = Mock(return_value=agentpool) decorator.update_secure_boot = Mock(return_value=agentpool) decorator.update_vtpm = Mock(return_value=agentpool) decorator.update_os_sku = Mock(return_value=agentpool) @@ -146,6 +147,7 @@ def test_update_agentpool_profile_preview_default_behavior(self): # Verify that all update methods were called decorator.update_network_profile.assert_called_once_with(agentpool) decorator.update_artifact_streaming.assert_called_once_with(agentpool) + decorator.update_managed_gpu.assert_called_once_with(agentpool) decorator.update_secure_boot.assert_called_once_with(agentpool) decorator.update_vtpm.assert_called_once_with(agentpool) decorator.update_os_sku.assert_called_once_with(agentpool) @@ -187,6 +189,7 @@ def test_update_agentpool_profile_preview_with_agentpools_parameter(self): # Mock all the update methods to return the agentpool unchanged decorator.update_network_profile = Mock(return_value=agentpool) decorator.update_artifact_streaming = Mock(return_value=agentpool) + decorator.update_managed_gpu = Mock(return_value=agentpool) decorator.update_secure_boot = Mock(return_value=agentpool) decorator.update_vtpm = Mock(return_value=agentpool) decorator.update_os_sku = Mock(return_value=agentpool) @@ -238,6 +241,7 @@ def test_update_agentpool_profile_preview_managed_system_mode(self): # Mock all the update methods (they should not be called for ManagedSystem mode) decorator.update_network_profile = Mock() decorator.update_artifact_streaming = Mock() + decorator.update_managed_gpu = Mock() decorator.update_secure_boot = Mock() decorator.update_vtpm = Mock() decorator.update_os_sku = Mock() @@ -267,6 +271,7 @@ def test_update_agentpool_profile_preview_managed_system_mode(self): # Verify that none of the update methods were called for ManagedSystem mode decorator.update_network_profile.assert_not_called() decorator.update_artifact_streaming.assert_not_called() + decorator.update_managed_gpu.assert_not_called() decorator.update_secure_boot.assert_not_called() decorator.update_vtpm.assert_not_called() decorator.update_os_sku.assert_not_called() @@ -345,6 +350,7 @@ def test_update_agentpool_profile_preview_system_mode_regular_flow(self): # Mock all the update methods to return the agentpool unchanged decorator.update_network_profile = Mock(return_value=agentpool) decorator.update_artifact_streaming = Mock(return_value=agentpool) + decorator.update_managed_gpu = Mock(return_value=agentpool) decorator.update_secure_boot = Mock(return_value=agentpool) decorator.update_vtpm = Mock(return_value=agentpool) decorator.update_os_sku = Mock(return_value=agentpool) @@ -366,6 +372,7 @@ def test_update_agentpool_profile_preview_system_mode_regular_flow(self): # Verify that all update methods were called for System mode decorator.update_network_profile.assert_called_once_with(agentpool) decorator.update_artifact_streaming.assert_called_once_with(agentpool) + decorator.update_managed_gpu.assert_called_once_with(agentpool) decorator.update_secure_boot.assert_called_once_with(agentpool) decorator.update_vtpm.assert_called_once_with(agentpool) decorator.update_os_sku.assert_called_once_with(agentpool) @@ -412,6 +419,7 @@ def mock_method(pool): decorator.update_network_profile = create_mock_update_method("update_network_profile") decorator.update_artifact_streaming = create_mock_update_method("update_artifact_streaming") + decorator.update_managed_gpu = create_mock_update_method("update_managed_gpu") decorator.update_secure_boot = create_mock_update_method("update_secure_boot") decorator.update_vtpm = create_mock_update_method("update_vtpm") decorator.update_os_sku = create_mock_update_method("update_os_sku") @@ -430,6 +438,7 @@ def mock_method(pool): expected_order = [ "update_network_profile", "update_artifact_streaming", + "update_managed_gpu", "update_secure_boot", "update_vtpm", "update_os_sku", @@ -478,6 +487,7 @@ def track_and_return(pool): decorator.update_network_profile = create_tracking_mock("update_network_profile") decorator.update_artifact_streaming = create_tracking_mock("update_artifact_streaming") + decorator.update_managed_gpu = create_tracking_mock("update_managed_gpu") decorator.update_secure_boot = create_tracking_mock("update_secure_boot") decorator.update_vtpm = create_tracking_mock("update_vtpm") decorator.update_os_sku = create_tracking_mock("update_os_sku") @@ -547,7 +557,7 @@ def test_update_agentpool_profile_preview_mixed_modes_scenario(self): # Mock all update methods update_methods = [ - 'update_network_profile', 'update_artifact_streaming', + 'update_network_profile', 'update_artifact_streaming', 'update_managed_gpu', 'update_secure_boot', 'update_vtpm', 'update_os_sku', 'update_fips_image', 'update_ssh_access', 'update_localdns_profile', 'update_auto_scaler_properties_vms', 'update_upgrade_strategy', 'update_blue_green_upgrade_settings', 'update_gpu_profile' @@ -613,6 +623,7 @@ def test_update_agentpool_profile_preview_managed_cluster_mode(self): # Mock all the update methods decorator.update_network_profile = Mock(return_value=agentpool) decorator.update_artifact_streaming = Mock(return_value=agentpool) + decorator.update_managed_gpu = Mock(return_value=agentpool) decorator.update_secure_boot = Mock(return_value=agentpool) decorator.update_vtpm = Mock(return_value=agentpool) decorator.update_os_sku = Mock(return_value=agentpool)