22import json
33import threading
44from collections import defaultdict
5+ from dataclasses import dataclass
56from typing import Callable , Dict , List , Literal , Optional , Tuple
67
78import google .api_core .exceptions
@@ -285,16 +286,18 @@ def create_instance(
285286 )
286287 raise NoCapacityError ()
287288
289+ image = _get_image (
290+ instance_type_name = instance_offer .instance .name ,
291+ cuda = len (instance_offer .instance .resources .gpus ) > 0 ,
292+ )
293+
288294 for zone in zones :
289295 request = compute_v1 .InsertInstanceRequest ()
290296 request .zone = zone
291297 request .project = self .config .project_id
292298 request .instance_resource = gcp_resources .create_instance_struct (
293299 disk_size = disk_size ,
294- image_id = _get_image_id (
295- instance_type_name = instance_offer .instance .name ,
296- cuda = len (instance_offer .instance .resources .gpus ) > 0 ,
297- ),
300+ image_id = image .id ,
298301 machine_type = instance_offer .instance .name ,
299302 accelerators = gcp_resources .get_accelerators (
300303 project_id = self .config .project_id ,
@@ -305,6 +308,7 @@ def create_instance(
305308 user_data = _get_user_data (
306309 authorized_keys = authorized_keys ,
307310 instance_type_name = instance_offer .instance .name ,
311+ is_ufw_installed = image .is_ufw_installed ,
308312 ),
309313 authorized_keys = authorized_keys ,
310314 labels = labels ,
@@ -889,24 +893,41 @@ def _get_vpc_subnet(
889893 )
890894
891895
892- def _get_image_id (instance_type_name : str , cuda : bool ) -> str :
896+ @dataclass
897+ class GCPImage :
898+ id : str
899+ is_ufw_installed : bool
900+
901+
902+ def _get_image (instance_type_name : str , cuda : bool ) -> GCPImage :
893903 if instance_type_name == "a3-megagpu-8g" :
894904 image_name = "dstack-a3mega-5"
905+ is_ufw_installed = False
895906 elif instance_type_name in ["a3-edgegpu-8g" , "a3-highgpu-8g" ]:
896- return "projects/cos-cloud/global/images/cos-105-17412-535-78"
907+ return GCPImage (
908+ id = "projects/cos-cloud/global/images/cos-105-17412-535-78" ,
909+ is_ufw_installed = False ,
910+ )
897911 elif cuda :
898912 image_name = f"dstack-cuda-{ version .base_image } "
913+ is_ufw_installed = True
899914 else :
900915 image_name = f"dstack-{ version .base_image } "
916+ is_ufw_installed = True
901917 image_name = image_name .replace ("." , "-" )
902- return f"projects/dstack/global/images/{ image_name } "
918+ return GCPImage (
919+ id = f"projects/dstack/global/images/{ image_name } " ,
920+ is_ufw_installed = is_ufw_installed ,
921+ )
903922
904923
905924def _get_gateway_image_id () -> str :
906925 return "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230714"
907926
908927
909- def _get_user_data (authorized_keys : List [str ], instance_type_name : str ) -> str :
928+ def _get_user_data (
929+ authorized_keys : List [str ], instance_type_name : str , is_ufw_installed : bool
930+ ) -> str :
910931 base_path = None
911932 bin_path = None
912933 backend_shim_env = None
@@ -929,6 +950,9 @@ def _get_user_data(authorized_keys: List[str], instance_type_name: str) -> str:
929950 base_path = base_path ,
930951 bin_path = bin_path ,
931952 backend_shim_env = backend_shim_env ,
953+ # Instance-level firewall is optional on GCP. The main protection comes from GCP firewalls.
954+ # So only set up instance-level firewall if ufw is available.
955+ skip_firewall_setup = not is_ufw_installed ,
932956 )
933957
934958
0 commit comments