@@ -649,32 +649,41 @@ def delete_volume(self, volume: Volume):
649649 pass
650650 logger .debug ("Deleted persistent disk for volume %s" , volume .name )
651651
652- def attach_volume (self , volume : Volume , instance_id : str ) -> VolumeAttachmentData :
652+ def attach_volume (
653+ self , volume : Volume , provisioning_data : JobProvisioningData
654+ ) -> VolumeAttachmentData :
655+ instance_id = provisioning_data .instance_id
653656 logger .debug (
654657 "Attaching persistent disk for volume %s to instance %s" ,
655658 volume .volume_id ,
656659 instance_id ,
657660 )
661+ if not gcp_resources .instance_type_supports_persistent_disk (
662+ provisioning_data .instance_type .name
663+ ):
664+ raise ComputeError (
665+ f"Instance type { provisioning_data .instance_type .name } does not support Persistent disk volumes"
666+ )
667+
658668 zone = get_or_error (volume .provisioning_data ).availability_zone
669+ is_tpu = _is_tpu_provisioning_data (provisioning_data )
659670 try :
660671 disk = self .disk_client .get (
661672 project = self .config .project_id ,
662673 zone = zone ,
663674 disk = volume .volume_id ,
664675 )
665676 disk_url = disk .self_link
677+ except google .api_core .exceptions .NotFound :
678+             raise ComputeError ("Persistent disk not found" )
666679
667- # This method has no information if the instance is a TPU or a VM,
668- # so we first try to see if there is a TPU with such name
669- try :
680+ try :
681+ if is_tpu :
670682 get_node_request = tpu_v2 .GetNodeRequest (
671683 name = f"projects/{ self .config .project_id } /locations/{ zone } /nodes/{ instance_id } " ,
672684 )
673685 tpu_node = self .tpu_client .get_node (get_node_request )
674- except google .api_core .exceptions .NotFound :
675- tpu_node = None
676686
677- if tpu_node is not None :
678687 # Python API to attach a disk to a TPU is not documented,
679688 # so we follow the code from the gcloud CLI:
680689 # https://github.com/twistedpair/google-cloud-sdk/blob/26ab5a281d56b384cc25750f3279a27afe5b499f/google-cloud-sdk/lib/googlecloudsdk/command_lib/compute/tpus/tpu_vm/util.py#L113
@@ -711,7 +720,6 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat
711720 attached_disk .auto_delete = False
712721 attached_disk .device_name = f"pd-{ volume .volume_id } "
713722 device_name = attached_disk .device_name
714-
715723 operation = self .instances_client .attach_disk (
716724 project = self .config .project_id ,
717725 zone = zone ,
@@ -720,31 +728,33 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat
720728 )
721729 gcp_resources .wait_for_extended_operation (operation , "persistent disk attachment" )
722730 except google .api_core .exceptions .NotFound :
723- raise ComputeError ("Persistent disk or instance not found" )
731+ raise ComputeError ("Disk or instance not found" )
724732 logger .debug (
725733 "Attached persistent disk for volume %s to instance %s" , volume .volume_id , instance_id
726734 )
727735 return VolumeAttachmentData (device_name = device_name )
728736
729- def detach_volume (self , volume : Volume , instance_id : str , force : bool = False ):
737+ def detach_volume (
738+ self , volume : Volume , provisioning_data : JobProvisioningData , force : bool = False
739+ ):
740+ instance_id = provisioning_data .instance_id
730741 logger .debug (
731742 "Detaching persistent disk for volume %s from instance %s" ,
732743 volume .volume_id ,
733744 instance_id ,
734745 )
735746 zone = get_or_error (volume .provisioning_data ).availability_zone
736747 attachment_data = get_or_error (volume .get_attachment_data_for_instance (instance_id ))
737- # This method has no information if the instance is a TPU or a VM,
738- # so we first try to see if there is a TPU with such name
739- try :
740- get_node_request = tpu_v2 .GetNodeRequest (
741- name = f"projects/{ self .config .project_id } /locations/{ zone } /nodes/{ instance_id } " ,
742- )
743- tpu_node = self .tpu_client .get_node (get_node_request )
744- except google .api_core .exceptions .NotFound :
745- tpu_node = None
748+ is_tpu = _is_tpu_provisioning_data ( provisioning_data )
749+ if is_tpu :
750+ try :
751+ get_node_request = tpu_v2 .GetNodeRequest (
752+ name = f"projects/{ self .config .project_id } /locations/{ zone } /nodes/{ instance_id } " ,
753+ )
754+ tpu_node = self .tpu_client .get_node (get_node_request )
755+ except google .api_core .exceptions .NotFound :
756+ raise ComputeError ( "Instance not found" )
746757
747- if tpu_node is not None :
748758 source_disk = (
749759 f"projects/{ self .config .project_id } /zones/{ zone } /disks/{ volume .volume_id } "
750760 )
@@ -815,6 +825,11 @@ def _filter(offer: InstanceOffer) -> bool:
815825 if _is_tpu (offer .instance .name ) and not _is_single_host_tpu (offer .instance .name ):
816826 return False
817827 for family in [
828+ "m4-" ,
829+ "c4-" ,
830+ "n4-" ,
831+ "h3-" ,
832+ "n2-" ,
818833 "e2-medium" ,
819834 "e2-standard-" ,
820835 "e2-highmem-" ,
@@ -1001,3 +1016,11 @@ def _get_tpu_data_disk_for_volume(project_id: str, volume: Volume) -> tpu_v2.Att
10011016 mode = tpu_v2 .AttachedDisk .DiskMode .READ_WRITE ,
10021017 )
10031018 return attached_disk
1019+
1020+
def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
    """Return whether the provisioned instance is a TPU node.

    The TPU flag is stored as ``"is_tpu"`` inside the JSON-encoded
    ``backend_data`` blob; absent backend_data means a regular VM.
    """
    if not provisioning_data.backend_data:
        return False
    backend_data = json.loads(provisioning_data.backend_data)
    return backend_data.get("is_tpu", False)
0 commit comments