@@ -56,6 +56,23 @@ def _CreateAgentEngineSandboxConfig_to_vertex(
5656 if getv (from_object , ["ttl" ]) is not None :
5757 setv (parent_object , ["ttl" ], getv (from_object , ["ttl" ]))
5858
59+ if getv (from_object , ["sandbox_environment_template" ]) is not None :
60+ setv (
61+ parent_object ,
62+ ["sandboxEnvironmentTemplate" ],
63+ getv (from_object , ["sandbox_environment_template" ]),
64+ )
65+
66+ if getv (from_object , ["sandbox_environment_snapshot" ]) is not None :
67+ setv (
68+ parent_object ,
69+ ["sandboxEnvironmentSnapshot" ],
70+ getv (from_object , ["sandbox_environment_snapshot" ]),
71+ )
72+
73+ if getv (from_object , ["owner" ]) is not None :
74+ setv (parent_object , ["owner" ], getv (from_object , ["owner" ]))
75+
5976 return to_object
6077
6178
@@ -820,7 +837,7 @@ def delete(
820837 def generate_access_token (
821838 self ,
822839 service_account_email : str ,
823- sandbox_id : str ,
840+ sandbox_hostname : str ,
824841 port : str = "8080" ,
825842 timeout : int = 3600 ,
826843 ) -> str :
@@ -829,8 +846,8 @@ def generate_access_token(
829846 Args:
830847 service_account_email (str):
831848 Required. The email of the service account to use for signing.
832- sandbox_id (str):
833- Required. The resource name of the sandbox to generate a token for.
849+ sandbox_hostname (str):
850+ Required. The hostname of the sandbox to generate a token for.
834851 port (str):
835852 Optional. The port to use for the token. Defaults to "8080".
836853 timeout (int):
@@ -841,13 +858,18 @@ def generate_access_token(
841858 """
842859 client = iam_credentials_v1 .IAMCredentialsClient ()
843860 name = f"projects/-/serviceAccounts/{ service_account_email } "
844- custom_claims = {"port" : port , "sandbox_id" : sandbox_id }
861+ custom_claims = {
862+ "hostname" : sandbox_hostname ,
863+ "port" : port ,
864+ "sandbox_id" : sandbox_id ,
865+ }
845866 payload = {
846867 "iat" : int (time .time ()),
847868 "exp" : int (time .time ()) + timeout ,
848869 "iss" : service_account_email ,
870+ "sub" : service_account_email ,
849871 "nonce" : secrets .randbelow (1000000000 ) + 1 ,
850- "aud" : "vmaas-proxy-api " , # default audience for sandbox proxy
872+ "aud" : "https://aiplatform.googleapis.com/ " , # default audience for sandbox proxy
851873 ** custom_claims ,
852874 }
853875 request = iam_credentials_v1 .SignJwtRequest (
@@ -862,7 +884,9 @@ def send_command(
862884 * ,
863885 http_method : str ,
864886 access_token : str ,
887+ routing_token : str ,
865888 sandbox_environment : types .SandboxEnvironment ,
889+ port : str = "8080" ,
866890 path : Optional [str ] = None ,
867891 query_params : Optional [dict [str , object ]] = None ,
868892 headers : Optional [dict [str , str ]] = None ,
@@ -875,8 +899,12 @@ def send_command(
875899 Required. The HTTP method to use for the command.
876900 access_token (str):
877901 Required. The access token to use for authorization.
902+ routing_token (str):
903+ Required. The routing token to use for authorization. This can be found in the sandbox environment's connection_info.
878904 sandbox_environment (types.SandboxEnvironment):
879905 Required. The sandbox environment to send the command to.
906+ port (str):
907+ Optional. The port to use for the token. Defaults to "8080". This should be one of the ports specified during template creation.
880908 path (str):
881909 Optional. The path to send the command to.
882910 query_params (dict[str, object]):
@@ -905,6 +933,8 @@ def send_command(
905933 if query_params :
906934 path = f"{ path } ?{ urlencode (query_params )} "
907935 headers ["Authorization" ] = f"Bearer { access_token } "
936+ headers ["X-Sandbox-Routing-Token" ] = routing_token
937+ headers ["X-Sandbox-Port" ] = port
908938 endpoint = endpoint + path if path .startswith ("/" ) else endpoint + "/" + path
909939 http_options = genai_types .HttpOptions (headers = headers , base_url = endpoint )
910940 http_client = genai .Client (vertexai = True , http_options = http_options )
@@ -920,6 +950,8 @@ def generate_browser_ws_headers(
920950 self ,
921951 sandbox_environment : types .SandboxEnvironment ,
922952 service_account_email : str ,
953+ routing_token : str ,
954+ port : str = "8080" ,
923955 timeout : int = 3600 ,
924956 ) -> tuple [str , dict [str , str ]]:
925957 """Generates the websocket upgrade headers for the browser.
@@ -929,47 +961,61 @@ def generate_browser_ws_headers(
929961 Required. The sandbox environment to generate websocket headers for.
930962 service_account_email (str):
931963 Required. The email of the service account to use for signing.
964+ routing_token (str):
965+ Required. The routing token to use for authorization. This can be
966+ found in the sandbox environment's connection_info.
967+ port (str):
968+ Optional. The port to use for the token. Defaults to "8080". This
969+ should be one of the ports specified during template creation.
932970 timeout (int):
933971 Optional. The timeout in seconds for the token. Defaults to 3600.
934972
935973 Returns:
936974 tuple[str, dict[str, str]]: A tuple containing the websocket URL and
937975 the headers for websocket upgrade.
938976 """
939- sandbox_id = sandbox_environment .name
977+ if not sandbox_environment .connection_info :
978+ raise ValueError ("Connection info is not available." )
979+
980+ ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
981+ connection_info = sandbox_environment .connection_info
982+ if connection_info .load_balancer_hostname :
983+ ws_base_url = "wss://" + connection_info .load_balancer_hostname
984+ elif connection_info .load_balancer_ip :
985+ ws_base_url = "ws://" + connection_info .load_balancer_ip
986+ else :
987+ raise ValueError ("Load balancer hostname or ip is not available." )
988+
940989 # port 8080 is the default port for http endpoint.
941990 http_access_token = self .generate_access_token (
942- service_account_email , sandbox_id , "8080" , timeout
991+ service_account_email , connection_info . load_balancer_hostname , port , timeout
943992 )
944993 response = self .send_command (
945994 http_method = "GET" ,
946995 access_token = http_access_token ,
996+ routing_token = routing_token ,
947997 sandbox_environment = sandbox_environment ,
998+ port = port ,
948999 path = "/cdp_ws_endpoint" ,
9491000 )
9501001 if not response :
9511002 raise ValueError ("Failed to get the websocket endpoint." )
9521003 body_dict = json .loads (response .body )
9531004 ws_path = body_dict ["endpoint" ]
954-
955- ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
956- if sandbox_environment and sandbox_environment .connection_info :
957- connection_info = sandbox_environment .connection_info
958- if connection_info .load_balancer_hostname :
959- ws_url = "wss://" + connection_info .load_balancer_hostname
960- elif connection_info .load_balancer_ip :
961- ws_url = "ws://" + connection_info .load_balancer_ip
962- else :
963- raise ValueError ("Load balancer hostname or ip is not available." )
964- ws_url = ws_url + "/" + ws_path
1005+ ws_url = ws_base_url + "/" + ws_path
9651006
9661007 # port 9222 is the default port for the browser websocket endpoint.
9671008 ws_access_token = self .generate_access_token (
968- service_account_email , sandbox_id , "9222" , timeout
1009+ service_account_email ,
1010+ connection_info .load_balancer_hostname ,
1011+ "9222" ,
1012+ timeout ,
9691013 )
9701014
9711015 headers = {}
972- headers ["Sec-WebSocket-Protocol" ] = f"binary, { ws_access_token } "
1016+ headers ["Sec-WebSocket-Protocol" ] = (
1017+ f"v1.stream, { ws_access_token } , { routing_token } , { port } "
1018+ )
9731019 return ws_url , headers
9741020
9751021
0 commit comments