|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | import os |
18 | | -import subprocess |
19 | 18 | from typing import Optional |
20 | | -from typing import Tuple |
21 | 19 |
|
22 | 20 | import click |
23 | 21 |
|
24 | 22 | from ..apps.app import validate_app_name |
25 | | -from .utils import gcp_utils |
| 23 | +from .utils import _onboarding |
26 | 24 |
|
27 | 25 | _INIT_PY_TEMPLATE = """\ |
28 | 26 | from . import agent |
|
48 | 46 | """ |
49 | 47 |
|
50 | 48 |
|
51 | | -_GOOGLE_API_MSG = """ |
52 | | -Don't have API Key? Create one in AI Studio: https://aistudio.google.com/apikey |
53 | | -""" |
54 | | - |
55 | | -_GOOGLE_CLOUD_SETUP_MSG = """ |
56 | | -You need an existing Google Cloud account and project, check out this link for details: |
57 | | -https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai |
58 | | -""" |
59 | | - |
60 | 49 | _OTHER_MODEL_MSG = """ |
61 | 50 | Please see below guide to configure other models: |
62 | 51 | https://google.github.io/adk-docs/agents/models |
63 | 52 | """ |
64 | 53 |
|
65 | | -_EXPRESS_TOS_MSG = """ |
66 | | -Google Cloud Express Mode Terms of Service: https://cloud.google.com/terms/google-cloud-express |
67 | | -By using this application, you agree to the Google Cloud Express Mode terms of service and any |
68 | | -applicable services and APIs: https://console.cloud.google.com/terms. You also agree to only use |
69 | | -this application for your trade, business, craft, or profession. |
70 | | -""" |
71 | | - |
72 | | -_NOT_ELIGIBLE_MSG = """ |
73 | | -You are not eligible for Express Mode. |
74 | | -Please follow these instructions to set up a full Google Cloud project: |
75 | | -https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai |
76 | | -""" |
77 | | - |
78 | 54 | _SUCCESS_MSG_CODE = """ |
79 | 55 | Agent created in {agent_folder}: |
80 | 56 | - .env |
|
96 | 72 | """ |
97 | 73 |
|
98 | 74 |
|
99 | | -def _get_gcp_project_from_gcloud() -> str: |
100 | | - """Uses gcloud to get default project.""" |
101 | | - try: |
102 | | - result = subprocess.run( |
103 | | - ["gcloud", "config", "get-value", "project"], |
104 | | - capture_output=True, |
105 | | - text=True, |
106 | | - check=True, |
107 | | - ) |
108 | | - return result.stdout.strip() |
109 | | - except (subprocess.CalledProcessError, FileNotFoundError): |
110 | | - return "" |
111 | | - |
112 | | - |
113 | | -def _get_gcp_region_from_gcloud() -> str: |
114 | | - """Uses gcloud to get default region.""" |
115 | | - try: |
116 | | - result = subprocess.run( |
117 | | - ["gcloud", "config", "get-value", "compute/region"], |
118 | | - capture_output=True, |
119 | | - text=True, |
120 | | - check=True, |
121 | | - ) |
122 | | - return result.stdout.strip() |
123 | | - except (subprocess.CalledProcessError, FileNotFoundError): |
124 | | - return "" |
125 | | - |
126 | | - |
127 | | -def _prompt_str( |
128 | | - prompt_prefix: str, |
129 | | - *, |
130 | | - prior_msg: Optional[str] = None, |
131 | | - default_value: Optional[str] = None, |
132 | | -) -> str: |
133 | | - if prior_msg: |
134 | | - click.secho(prior_msg, fg="green") |
135 | | - while True: |
136 | | - value: str = click.prompt( |
137 | | - prompt_prefix, default=default_value or None, type=str |
138 | | - ) |
139 | | - if value and value.strip(): |
140 | | - return value.strip() |
141 | | - |
142 | | - |
143 | | -def _prompt_for_google_cloud( |
144 | | - google_cloud_project: Optional[str], |
145 | | -) -> str: |
146 | | - """Prompts user for Google Cloud project ID.""" |
147 | | - google_cloud_project = ( |
148 | | - google_cloud_project |
149 | | - or os.environ.get("GOOGLE_CLOUD_PROJECT", None) |
150 | | - or _get_gcp_project_from_gcloud() |
151 | | - ) |
152 | | - |
153 | | - google_cloud_project = _prompt_str( |
154 | | - "Enter Google Cloud project ID", default_value=google_cloud_project |
155 | | - ) |
156 | | - |
157 | | - return google_cloud_project |
158 | | - |
159 | | - |
160 | | -def _prompt_for_google_cloud_region( |
161 | | - google_cloud_region: Optional[str], |
162 | | -) -> str: |
163 | | - """Prompts user for Google Cloud region.""" |
164 | | - google_cloud_region = ( |
165 | | - google_cloud_region |
166 | | - or os.environ.get("GOOGLE_CLOUD_LOCATION", None) |
167 | | - or _get_gcp_region_from_gcloud() |
168 | | - ) |
169 | | - |
170 | | - google_cloud_region = _prompt_str( |
171 | | - "Enter Google Cloud region", |
172 | | - default_value=google_cloud_region or "us-central1", |
173 | | - ) |
174 | | - return google_cloud_region |
175 | | - |
176 | | - |
177 | | -def _prompt_for_google_api_key( |
178 | | - google_api_key: Optional[str], |
179 | | -) -> str: |
180 | | - """Prompts user for Google API key.""" |
181 | | - google_api_key = google_api_key or os.environ.get("GOOGLE_API_KEY", None) |
182 | | - |
183 | | - google_api_key = _prompt_str( |
184 | | - "Enter Google API key", |
185 | | - prior_msg=_GOOGLE_API_MSG, |
186 | | - default_value=google_api_key, |
187 | | - ) |
188 | | - return google_api_key |
189 | | - |
190 | | - |
191 | 75 | def _generate_files( |
192 | 76 | agent_folder: str, |
193 | 77 | *, |
@@ -256,155 +140,6 @@ def _prompt_for_model() -> str: |
256 | 140 | return "<FILL_IN_MODEL>" |
257 | 141 |
|
258 | 142 |
|
259 | | -def _prompt_to_choose_backend( |
260 | | - google_api_key: Optional[str], |
261 | | - google_cloud_project: Optional[str], |
262 | | - google_cloud_region: Optional[str], |
263 | | -) -> Tuple[Optional[str], Optional[str], Optional[str]]: |
264 | | - """Prompts user to choose backend. |
265 | | -
|
266 | | - Returns: |
267 | | - A tuple of (google_api_key, google_cloud_project, google_cloud_region). |
268 | | - """ |
269 | | - backend_choice = click.prompt( |
270 | | - "1. Google AI\n2. Vertex AI\n3. Login with Google\nChoose a backend", |
271 | | - type=click.Choice(["1", "2", "3"]), |
272 | | - ) |
273 | | - if backend_choice == "1": |
274 | | - google_api_key = _prompt_for_google_api_key(google_api_key) |
275 | | - elif backend_choice == "2": |
276 | | - click.secho(_GOOGLE_CLOUD_SETUP_MSG, fg="green") |
277 | | - google_cloud_project = _prompt_for_google_cloud(google_cloud_project) |
278 | | - google_cloud_region = _prompt_for_google_cloud_region(google_cloud_region) |
279 | | - elif backend_choice == "3": |
280 | | - google_api_key, google_cloud_project, google_cloud_region = ( |
281 | | - _handle_login_with_google() |
282 | | - ) |
283 | | - return google_api_key, google_cloud_project, google_cloud_region |
284 | | - |
285 | | - |
286 | | -def _handle_login_with_google() -> ( |
287 | | - Tuple[Optional[str], Optional[str], Optional[str]] |
288 | | -): |
289 | | - """Handles the "Login with Google" flow.""" |
290 | | - if not gcp_utils.check_adc(): |
291 | | - click.secho( |
292 | | - "No Application Default Credentials found. " |
293 | | - "Opening browser for login...", |
294 | | - fg="yellow", |
295 | | - ) |
296 | | - try: |
297 | | - gcp_utils.login_adc() |
298 | | - except RuntimeError as e: |
299 | | - click.secho(str(e), fg="red") |
300 | | - raise click.Abort() |
301 | | - |
302 | | - # Check for existing Express project |
303 | | - express_project = gcp_utils.retrieve_express_project() |
304 | | - if express_project: |
305 | | - api_key = express_project.get("api_key") |
306 | | - project_id = express_project.get("project_id") |
307 | | - region = express_project.get("region", "us-central1") |
308 | | - if project_id: |
309 | | - click.secho(f"Using existing Express project: {project_id}", fg="green") |
310 | | - return api_key, project_id, region |
311 | | - |
312 | | - # Check for existing full GCP projects |
313 | | - projects = gcp_utils.list_gcp_projects(limit=20) |
314 | | - if projects: |
315 | | - click.secho("Recently created Google Cloud projects found:", fg="green") |
316 | | - click.echo("0. Enter project ID manually") |
317 | | - for i, (p_id, p_name) in enumerate(projects, 1): |
318 | | - click.echo(f"{i}. {p_name} ({p_id})") |
319 | | - |
320 | | - project_index = click.prompt( |
321 | | - "Select a project", |
322 | | - type=click.IntRange(0, len(projects)), |
323 | | - ) |
324 | | - if project_index == 0: |
325 | | - selected_project_id = _prompt_for_google_cloud(None) |
326 | | - else: |
327 | | - selected_project_id = projects[project_index - 1][0] |
328 | | - region = _prompt_for_google_cloud_region(None) |
329 | | - return None, selected_project_id, region |
330 | | - |
331 | | - click.secho( |
332 | | - "A Google Cloud project is required to continue. You can enter an" |
333 | | - " existing project ID or create an Express Mode project. Learn more:" |
334 | | - " https://cloud.google.com/resources/cloud-express-faqs", |
335 | | - fg="green", |
336 | | - ) |
337 | | - action = click.prompt( |
338 | | - "1. Enter an existing Google Cloud project ID\n" |
339 | | - "2. Create a new project (Express Mode)\n" |
340 | | - "3. Abandon\n" |
341 | | - "Choose an action", |
342 | | - type=click.Choice(["1", "2", "3"]), |
343 | | - ) |
344 | | - |
345 | | - if action == "3": |
346 | | - raise click.Abort() |
347 | | - |
348 | | - if action == "1": |
349 | | - google_cloud_project = _prompt_for_google_cloud(None) |
350 | | - google_cloud_region = _prompt_for_google_cloud_region(None) |
351 | | - return None, google_cloud_project, google_cloud_region |
352 | | - |
353 | | - elif action == "2": |
354 | | - if gcp_utils.check_express_eligibility(): |
355 | | - click.secho(_EXPRESS_TOS_MSG, fg="yellow") |
356 | | - if click.confirm("Do you accept the Terms of Service?", default=False): |
357 | | - selected_region = click.prompt( |
358 | | - """\ |
359 | | -Choose a region for Express Mode: |
360 | | -1. us-central1 |
361 | | -2. europe-west1 |
362 | | -3. asia-southeast1 |
363 | | -Choose region""", |
364 | | - type=click.Choice(["1", "2", "3"]), |
365 | | - default="1", |
366 | | - ) |
367 | | - region_map = { |
368 | | - "1": "us-central1", |
369 | | - "2": "europe-west1", |
370 | | - "3": "asia-southeast1", |
371 | | - } |
372 | | - region = region_map[selected_region] |
373 | | - express_info = gcp_utils.sign_up_express(location=region) |
374 | | - api_key = express_info.get("api_key") |
375 | | - project_id = express_info.get("project_id") |
376 | | - region = express_info.get("region", region) |
377 | | - click.secho( |
378 | | - f"Express Mode project created: {project_id}", |
379 | | - fg="green", |
380 | | - ) |
381 | | - current_proj = _get_gcp_project_from_gcloud() |
382 | | - if current_proj and current_proj != project_id: |
383 | | - click.secho( |
384 | | - "Warning: Your default gcloud project is set to" |
385 | | - f" '{current_proj}'. This might conflict with or override your" |
386 | | - f" Express Mode project '{project_id}'. We recommend" |
387 | | - " unsetting it.", |
388 | | - fg="yellow", |
389 | | - ) |
390 | | - if click.confirm("Run 'gcloud config unset project'?", default=True): |
391 | | - try: |
392 | | - subprocess.run( |
393 | | - ["gcloud", "config", "unset", "project"], |
394 | | - check=True, |
395 | | - capture_output=True, |
396 | | - ) |
397 | | - click.secho("Unset default gcloud project.", fg="green") |
398 | | - except Exception: |
399 | | - click.secho( |
400 | | - "Failed to unset project. Please do it manually.", fg="red" |
401 | | - ) |
402 | | - return api_key, project_id, region |
403 | | - |
404 | | - click.secho(_NOT_ELIGIBLE_MSG, fg="red") |
405 | | - raise click.Abort() |
406 | | - |
407 | | - |
408 | 143 | def _prompt_to_choose_type() -> str: |
409 | 144 | """Prompts user to choose type of agent to create.""" |
410 | 145 | type_choice = click.prompt( |
@@ -464,11 +199,18 @@ def run_cmd( |
464 | 199 |
|
465 | 200 | if not google_api_key and not (google_cloud_project and google_cloud_region): |
466 | 201 | if model.startswith("gemini"): |
467 | | - google_api_key, google_cloud_project, google_cloud_region = ( |
468 | | - _prompt_to_choose_backend( |
469 | | - google_api_key, google_cloud_project, google_cloud_region |
470 | | - ) |
| 202 | + auth_info = _onboarding.prompt_to_choose_backend( |
| 203 | + google_api_key, google_cloud_project, google_cloud_region |
471 | 204 | ) |
| 205 | + if isinstance(auth_info, _onboarding.GoogleAIAuth): |
| 206 | + google_api_key = auth_info.api_key |
| 207 | + elif isinstance(auth_info, _onboarding.VertexAIAuth): |
| 208 | + google_cloud_project = auth_info.project_id |
| 209 | + google_cloud_region = auth_info.region |
| 210 | + elif isinstance(auth_info, _onboarding.ExpressModeAuth): |
| 211 | + google_api_key = auth_info.api_key |
| 212 | + google_cloud_project = auth_info.project_id |
| 213 | + google_cloud_region = auth_info.region |
472 | 214 |
|
473 | 215 | if not type: |
474 | 216 | type = _prompt_to_choose_type() |
|
0 commit comments