@@ -56,6 +56,23 @@ def _CreateAgentEngineSandboxConfig_to_vertex(
5656 if getv (from_object , ["ttl" ]) is not None :
5757 setv (parent_object , ["ttl" ], getv (from_object , ["ttl" ]))
5858
59+ if getv (from_object , ["sandbox_environment_template" ]) is not None :
60+ setv (
61+ parent_object ,
62+ ["sandboxEnvironmentTemplate" ],
63+ getv (from_object , ["sandbox_environment_template" ]),
64+ )
65+
66+ if getv (from_object , ["sandbox_environment_snapshot" ]) is not None :
67+ setv (
68+ parent_object ,
69+ ["sandboxEnvironmentSnapshot" ],
70+ getv (from_object , ["sandbox_environment_snapshot" ]),
71+ )
72+
73+ if getv (from_object , ["owner" ]) is not None :
74+ setv (parent_object , ["owner" ], getv (from_object , ["owner" ]))
75+
5976 return to_object
6077
6178
@@ -820,7 +837,7 @@ def delete(
820837 def generate_access_token (
821838 self ,
822839 service_account_email : str ,
823- sandbox_id : str ,
840+ sandbox_hostname : str ,
824841 port : str = "8080" ,
825842 timeout : int = 3600 ,
826843 ) -> str :
@@ -829,8 +846,8 @@ def generate_access_token(
829846 Args:
830847 service_account_email (str):
831848 Required. The email of the service account to use for signing.
832- sandbox_id (str):
833- Required. The resource name of the sandbox to generate a token for.
849+ sandbox_hostname (str):
850+ Required. The hostname of the sandbox to generate a token for.
834851 port (str):
835852 Optional. The port to use for the token. Defaults to "8080".
836853 timeout (int):
@@ -841,13 +858,14 @@ def generate_access_token(
841858 """
842859 client = iam_credentials_v1 .IAMCredentialsClient ()
843860 name = f"projects/-/serviceAccounts/{ service_account_email } "
844- custom_claims = {"port " : port , "sandbox_id " : sandbox_id }
861+ custom_claims = {"hostname " : sandbox_hostname , "port " : port }
845862 payload = {
846863 "iat" : int (time .time ()),
847864 "exp" : int (time .time ()) + timeout ,
848865 "iss" : service_account_email ,
866+ "sub" : service_account_email ,
849867 "nonce" : secrets .randbelow (1000000000 ) + 1 ,
850- "aud" : "vmaas-proxy-api " , # default audience for sandbox proxy
868+ "aud" : "https://aiplatform.googleapis.com/ " , # default audience for sandbox proxy
851869 ** custom_claims ,
852870 }
853871 request = iam_credentials_v1 .SignJwtRequest (
@@ -862,7 +880,9 @@ def send_command(
862880 * ,
863881 http_method : str ,
864882 access_token : str ,
883+ routing_token : str ,
865884 sandbox_environment : types .SandboxEnvironment ,
885+ port : str = "8080" ,
866886 path : Optional [str ] = None ,
867887 query_params : Optional [dict [str , object ]] = None ,
868888 headers : Optional [dict [str , str ]] = None ,
@@ -875,8 +895,12 @@ def send_command(
875895 Required. The HTTP method to use for the command.
876896 access_token (str):
877897 Required. The access token to use for authorization.
898+ routing_token (str):
899+ Required. The routing token to use for authorization. This can be found in the sandbox environment's connection_info.
878900 sandbox_environment (types.SandboxEnvironment):
879901 Required. The sandbox environment to send the command to.
902+ port (str):
903+ Optional. The port to use for the token. Defaults to "8080". This should be one of the ports specified during template creation.
880904 path (str):
881905 Optional. The path to send the command to.
882906 query_params (dict[str, object]):
@@ -905,6 +929,8 @@ def send_command(
905929 if query_params :
906930 path = f"{ path } ?{ urlencode (query_params )} "
907931 headers ["Authorization" ] = f"Bearer { access_token } "
932+ headers ["X-Sandbox-Routing-Token" ] = routing_token
933+ headers ["X-Sandbox-Port" ] = port
908934 endpoint = endpoint + path if path .startswith ("/" ) else endpoint + "/" + path
909935 http_options = genai_types .HttpOptions (headers = headers , base_url = endpoint )
910936 http_client = genai .Client (vertexai = True , http_options = http_options )
@@ -920,6 +946,8 @@ def generate_browser_ws_headers(
920946 self ,
921947 sandbox_environment : types .SandboxEnvironment ,
922948 service_account_email : str ,
949+ routing_token : str ,
950+ port : str = "8080" ,
923951 timeout : int = 3600 ,
924952 ) -> tuple [str , dict [str , str ]]:
925953 """Generates the websocket upgrade headers for the browser.
@@ -929,47 +957,61 @@ def generate_browser_ws_headers(
929957 Required. The sandbox environment to generate websocket headers for.
930958 service_account_email (str):
931959 Required. The email of the service account to use for signing.
960+ routing_token (str):
961+ Required. The routing token to use for authorization. This can be
962+ found in the sandbox environment's connection_info.
963+ port (str):
964+ Optional. The port to use for the token. Defaults to "8080". This
965+ should be one of the ports specified during template creation.
932966 timeout (int):
933967 Optional. The timeout in seconds for the token. Defaults to 3600.
934968
935969 Returns:
936970 tuple[str, dict[str, str]]: A tuple containing the websocket URL and
937971 the headers for websocket upgrade.
938972 """
939- sandbox_id = sandbox_environment .name
973+ if not sandbox_environment .connection_info :
974+ raise ValueError ("Connection info is not available." )
975+
976+ ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
977+ connection_info = sandbox_environment .connection_info
978+ if connection_info .load_balancer_hostname :
979+ ws_base_url = "wss://" + connection_info .load_balancer_hostname
980+ elif connection_info .load_balancer_ip :
981+ ws_base_url = "ws://" + connection_info .load_balancer_ip
982+ else :
983+ raise ValueError ("Load balancer hostname or ip is not available." )
984+
940985 # port 8080 is the default port for http endpoint.
941986 http_access_token = self .generate_access_token (
942- service_account_email , sandbox_id , "8080" , timeout
987+ service_account_email , connection_info . load_balancer_hostname , port , timeout
943988 )
944989 response = self .send_command (
945990 http_method = "GET" ,
946991 access_token = http_access_token ,
992+ routing_token = routing_token ,
947993 sandbox_environment = sandbox_environment ,
994+ port = port ,
948995 path = "/cdp_ws_endpoint" ,
949996 )
950997 if not response :
951998 raise ValueError ("Failed to get the websocket endpoint." )
952999 body_dict = json .loads (response .body )
9531000 ws_path = body_dict ["endpoint" ]
954-
955- ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
956- if sandbox_environment and sandbox_environment .connection_info :
957- connection_info = sandbox_environment .connection_info
958- if connection_info .load_balancer_hostname :
959- ws_url = "wss://" + connection_info .load_balancer_hostname
960- elif connection_info .load_balancer_ip :
961- ws_url = "ws://" + connection_info .load_balancer_ip
962- else :
963- raise ValueError ("Load balancer hostname or ip is not available." )
964- ws_url = ws_url + "/" + ws_path
1001+ ws_url = ws_base_url + "/" + ws_path
9651002
9661003 # port 9222 is the default port for the browser websocket endpoint.
9671004 ws_access_token = self .generate_access_token (
968- service_account_email , sandbox_id , "9222" , timeout
1005+ service_account_email ,
1006+ connection_info .load_balancer_hostname ,
1007+ "9222" ,
1008+ timeout ,
9691009 )
9701010
9711011 headers = {}
972- headers ["Sec-WebSocket-Protocol" ] = f"binary, { ws_access_token } "
1012+ headers ["Sec-WebSocket-Protocol" ] = (
1013+ f"v1.stream, { ws_access_token } , { routing_token } , { port } "
1014+ )
9731015 return ws_url , headers
9741016
9751017
0 commit comments