88import google .cloud .compute_v1 as compute_v1
99from cachetools import TTLCache , cachedmethod
1010from google .cloud import tpu_v2
11+ from google .cloud .compute_v1 .types .compute import Instance
1112from gpuhunt import KNOWN_TPUS
1213
1314import dstack ._internal .core .backends .gcp .auth as auth
1920 ComputeWithGatewaySupport ,
2021 ComputeWithMultinodeSupport ,
2122 ComputeWithPlacementGroupSupport ,
23+ ComputeWithPrivateGatewaySupport ,
2224 ComputeWithVolumeSupport ,
2325 generate_unique_gateway_instance_name ,
2426 generate_unique_instance_name ,
@@ -83,6 +85,7 @@ class GCPCompute(
8385 ComputeWithMultinodeSupport ,
8486 ComputeWithPlacementGroupSupport ,
8587 ComputeWithGatewaySupport ,
88+ ComputeWithPrivateGatewaySupport ,
8689 ComputeWithVolumeSupport ,
8790 Compute ,
8891):
@@ -395,11 +398,7 @@ def update_provisioning_data(
395398 if instance .status in ["PROVISIONING" , "STAGING" ]:
396399 return
397400 if instance .status == "RUNNING" :
398- if allocate_public_ip :
399- hostname = instance .network_interfaces [0 ].access_configs [0 ].nat_i_p
400- else :
401- hostname = instance .network_interfaces [0 ].network_i_p
402- provisioning_data .hostname = hostname
401+ provisioning_data .hostname = _get_instance_ip (instance , allocate_public_ip )
403402 provisioning_data .internal_ip = instance .network_interfaces [0 ].network_i_p
404403 return
405404 raise ProvisioningError (
@@ -512,6 +511,7 @@ def create_gateway(
512511 service_account = self .config .vm_service_account ,
513512 network = self .config .vpc_resource_name ,
514513 subnetwork = subnetwork ,
514+ allocate_public_ip = configuration .public_ip ,
515515 )
516516 operation = self .instances_client .insert (request = request )
517517 gcp_resources .wait_for_extended_operation (operation , "instance creation" )
@@ -522,7 +522,7 @@ def create_gateway(
522522 instance_id = instance_name ,
523523 region = configuration .region , # used for instance termination
524524 availability_zone = zone ,
525- ip_address = instance . network_interfaces [ 0 ]. access_configs [ 0 ]. nat_i_p ,
525+ ip_address = _get_instance_ip ( instance , configuration . public_ip ) ,
526526 backend_data = json .dumps ({"zone" : zone }),
527527 )
528528
@@ -1024,3 +1024,9 @@ def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
10241024 backend_data_dict = json .loads (provisioning_data .backend_data )
10251025 is_tpu = backend_data_dict .get ("is_tpu" , False )
10261026 return is_tpu
1027+
1028+
1029+ def _get_instance_ip (instance : Instance , public_ip : bool ) -> str :
1030+ if public_ip :
1031+ return instance .network_interfaces [0 ].access_configs [0 ].nat_i_p
1032+ return instance .network_interfaces [0 ].network_i_p
0 commit comments