diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 56fe901667..076bc13c2e 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -611,9 +611,12 @@ def delete_volume(self, volume: Volume): raise e logger.debug("Deleted EBS volume %s", volume.configuration.name) - def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData: + def attach_volume( + self, volume: Volume, provisioning_data: JobProvisioningData + ) -> VolumeAttachmentData: ec2_client = self.session.client("ec2", region_name=volume.configuration.region) + instance_id = provisioning_data.instance_id device_names = aws_resources.list_available_device_names( ec2_client=ec2_client, instance_id=instance_id ) @@ -646,9 +649,12 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat logger.debug("Attached EBS volume %s to instance %s", volume.volume_id, instance_id) return VolumeAttachmentData(device_name=device_name) - def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): + def detach_volume( + self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False + ): ec2_client = self.session.client("ec2", region_name=volume.configuration.region) + instance_id = provisioning_data.instance_id logger.debug("Detaching EBS volume %s from instance %s", volume.volume_id, instance_id) attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id)) try: @@ -667,9 +673,10 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): raise e logger.debug("Detached EBS volume %s from instance %s", volume.volume_id, instance_id) - def is_volume_detached(self, volume: Volume, instance_id: str) -> bool: + def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool: ec2_client = self.session.client("ec2", 
region_name=volume.configuration.region) + instance_id = provisioning_data.instance_id logger.debug("Getting EBS volume %s status", volume.volume_id) response = ec2_client.describe_volumes(VolumeIds=[volume.volume_id]) volumes_infos = response.get("Volumes") diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 1d749793d1..4efcbc851c 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -336,7 +336,9 @@ def delete_volume(self, volume: Volume): """ raise NotImplementedError() - def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData: + def attach_volume( + self, volume: Volume, provisioning_data: JobProvisioningData + ) -> VolumeAttachmentData: """ Attaches a volume to the instance. If the volume is not found, it should raise `ComputeError()`. @@ -345,7 +347,9 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat """ raise NotImplementedError() - def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): + def detach_volume( + self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False + ): """ Detaches a volume from the instance. Implement only if compute may return `VolumeProvisioningData.detachable`. @@ -353,7 +357,7 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): """ raise NotImplementedError() - def is_volume_detached(self, volume: Volume, instance_id: str) -> bool: + def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool: """ Checks if a volume was detached from the instance. 
If `detach_volume()` may fail to detach volume, diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index e216a6aef7..7e60afff43 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -649,13 +649,24 @@ def delete_volume(self, volume: Volume): pass logger.debug("Deleted persistent disk for volume %s", volume.name) - def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData: + def attach_volume( + self, volume: Volume, provisioning_data: JobProvisioningData + ) -> VolumeAttachmentData: + instance_id = provisioning_data.instance_id logger.debug( "Attaching persistent disk for volume %s to instance %s", volume.volume_id, instance_id, ) + if not gcp_resources.instance_type_supports_persistent_disk( + provisioning_data.instance_type.name + ): + raise ComputeError( + f"Instance type {provisioning_data.instance_type.name} does not support Persistent disk volumes" + ) + zone = get_or_error(volume.provisioning_data).availability_zone + is_tpu = _is_tpu_provisioning_data(provisioning_data) try: disk = self.disk_client.get( project=self.config.project_id, @@ -663,18 +674,16 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat disk=volume.volume_id, ) disk_url = disk.self_link + except google.api_core.exceptions.NotFound: + raise ComputeError("Persistent disk not found") - # This method has no information if the instance is a TPU or a VM, - # so we first try to see if there is a TPU with such name - try: + try: + if is_tpu: get_node_request = tpu_v2.GetNodeRequest( name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}", ) tpu_node = self.tpu_client.get_node(get_node_request) - except google.api_core.exceptions.NotFound: - tpu_node = None - if tpu_node is not None: # Python API to attach a disk to a TPU is not documented, # so we follow the code from the gcloud CLI: # 
https://github.com/twistedpair/google-cloud-sdk/blob/26ab5a281d56b384cc25750f3279a27afe5b499f/google-cloud-sdk/lib/googlecloudsdk/command_lib/compute/tpus/tpu_vm/util.py#L113 @@ -711,7 +720,6 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat attached_disk.auto_delete = False attached_disk.device_name = f"pd-{volume.volume_id}" device_name = attached_disk.device_name - operation = self.instances_client.attach_disk( project=self.config.project_id, zone=zone, @@ -720,13 +728,16 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat ) gcp_resources.wait_for_extended_operation(operation, "persistent disk attachment") except google.api_core.exceptions.NotFound: - raise ComputeError("Persistent disk or instance not found") + raise ComputeError("Disk or instance not found") logger.debug( "Attached persistent disk for volume %s to instance %s", volume.volume_id, instance_id ) return VolumeAttachmentData(device_name=device_name) - def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): + def detach_volume( + self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False + ): + instance_id = provisioning_data.instance_id logger.debug( "Detaching persistent disk for volume %s from instance %s", volume.volume_id, @@ -734,17 +745,16 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): ) zone = get_or_error(volume.provisioning_data).availability_zone attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id)) - # This method has no information if the instance is a TPU or a VM, - # so we first try to see if there is a TPU with such name - try: - get_node_request = tpu_v2.GetNodeRequest( - name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}", - ) - tpu_node = self.tpu_client.get_node(get_node_request) - except google.api_core.exceptions.NotFound: - tpu_node = None + is_tpu = 
_is_tpu_provisioning_data(provisioning_data) + if is_tpu: + try: + get_node_request = tpu_v2.GetNodeRequest( + name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}", + ) + tpu_node = self.tpu_client.get_node(get_node_request) + except google.api_core.exceptions.NotFound: + raise ComputeError("Instance not found") - if tpu_node is not None: source_disk = ( f"projects/{self.config.project_id}/zones/{zone}/disks/{volume.volume_id}" ) @@ -815,6 +825,11 @@ def _filter(offer: InstanceOffer) -> bool: if _is_tpu(offer.instance.name) and not _is_single_host_tpu(offer.instance.name): return False for family in [ + "m4-", + "c4-", + "n4-", + "h3-", + "n2-", "e2-medium", "e2-standard-", "e2-highmem-", @@ -1001,3 +1016,11 @@ def _get_tpu_data_disk_for_volume(project_id: str, volume: Volume) -> tpu_v2.Att mode=tpu_v2.AttachedDisk.DiskMode.READ_WRITE, ) return attached_disk + + +def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool: + is_tpu = False + if provisioning_data.backend_data: + backend_data_dict = json.loads(provisioning_data.backend_data) + is_tpu = backend_data_dict.get("is_tpu", False) + return is_tpu diff --git a/src/dstack/_internal/core/backends/gcp/resources.py b/src/dstack/_internal/core/backends/gcp/resources.py index 58c42c0386..c56caddf99 100644 --- a/src/dstack/_internal/core/backends/gcp/resources.py +++ b/src/dstack/_internal/core/backends/gcp/resources.py @@ -140,7 +140,10 @@ def create_instance_struct( initialize_params = compute_v1.AttachedDiskInitializeParams() initialize_params.source_image = image_id initialize_params.disk_size_gb = disk_size - initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced" + if instance_type_supports_persistent_disk(machine_type): + initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced" + else: + initialize_params.disk_type = f"zones/{zone}/diskTypes/hyperdisk-balanced" disk.initialize_params = initialize_params instance.disks = [disk] @@ -421,7 
+424,7 @@ def wait_for_extended_operation( if operation.error_code: # Write only debug logs here. - # The unexpected errors will be propagated and logged appropriatly by the caller. + # The unexpected errors will be propagated and logged appropriately by the caller. logger.debug( "Error during %s: [Code: %s]: %s", verbose_name, @@ -462,3 +465,16 @@ def get_placement_policy_resource_name( placement_policy: str, ) -> str: return f"projects/{project_id}/regions/{region}/resourcePolicies/{placement_policy}" + + +def instance_type_supports_persistent_disk(instance_type_name: str) -> bool: + return not any( + instance_type_name.startswith(series) + for series in [ + "m4-", + "c4-", + "n4-", + "h3-", + "v6e", + ] + ) diff --git a/src/dstack/_internal/core/backends/local/compute.py b/src/dstack/_internal/core/backends/local/compute.py index 47afd58bb6..7f9e257f35 100644 --- a/src/dstack/_internal/core/backends/local/compute.py +++ b/src/dstack/_internal/core/backends/local/compute.py @@ -110,8 +110,10 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData: def delete_volume(self, volume: Volume): pass - def attach_volume(self, volume: Volume, instance_id: str): + def attach_volume(self, volume: Volume, provisioning_data: JobProvisioningData): pass - def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): + def detach_volume( + self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False + ): pass diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 99613e493b..cce9c89a74 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -659,7 +659,7 @@ async def _attach_volumes( backend=backend, volume_model=volume_model, instance=instance, - instance_id=job_provisioning_data.instance_id, + 
jpd=job_provisioning_data, ) job_runtime_data.volume_names.append(volume.name) break # attach next mount point @@ -685,7 +685,7 @@ async def _attach_volume( backend: Backend, volume_model: VolumeModel, instance: InstanceModel, - instance_id: str, + jpd: JobProvisioningData, ): compute = backend.compute() assert isinstance(compute, ComputeWithVolumeSupport) @@ -697,7 +697,7 @@ async def _attach_volume( attachment_data = await common_utils.run_async( compute.attach_volume, volume=volume, - instance_id=instance_id, + provisioning_data=jpd, ) volume_attachment_model = VolumeAttachmentModel( volume=volume_model, diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index c4ebbd79af..f25c193f87 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -470,20 +470,20 @@ async def _detach_volume_from_job_instance( await run_async( compute.detach_volume, volume=volume, - instance_id=jpd.instance_id, + provisioning_data=jpd, force=False, ) # For some backends, the volume may be detached immediately detached = await run_async( compute.is_volume_detached, volume=volume, - instance_id=jpd.instance_id, + provisioning_data=jpd, ) else: detached = await run_async( compute.is_volume_detached, volume=volume, - instance_id=jpd.instance_id, + provisioning_data=jpd, ) if not detached and _should_force_detach_volume(job_model, job_spec.stop_duration): logger.info( @@ -494,7 +494,7 @@ async def _detach_volume_from_job_instance( await run_async( compute.detach_volume, volume=volume, - instance_id=jpd.instance_id, + provisioning_data=jpd, force=True, ) # Let the next iteration check if force detach worked diff --git a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py index 0346de5df6..1d1c143d4f 100644 --- 
a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py @@ -190,7 +190,7 @@ async def test_force_detaches_job_volumes(self, session: AsyncSession): m.assert_awaited_once() backend_mock.compute.return_value.detach_volume.assert_called_once_with( volume=volume_model_to_volume(volume), - instance_id=job_provisioning_data.instance_id, + provisioning_data=job_provisioning_data, force=True, ) backend_mock.compute.return_value.is_volume_detached.assert_called_once()