|
| 1 | +import argparse |
| 2 | + |
| 3 | +from requests import HTTPError |
| 4 | +from rich.table import Table |
| 5 | + |
| 6 | +import dstack.api.server |
| 7 | +from dstack._internal.cli.commands import BaseCommand |
| 8 | +from dstack._internal.cli.utils.common import confirm_ask, console |
| 9 | +from dstack._internal.core.errors import ClientError, CLIError |
| 10 | +from dstack._internal.core.services.configs import ConfigManager |
| 11 | +from dstack._internal.utils.logging import get_logger |
| 12 | + |
| 13 | +logger = get_logger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +class ProjectCommand(BaseCommand): |
| 17 | + NAME = "project" |
| 18 | + DESCRIPTION = "Manage projects" |
| 19 | + |
| 20 | + def _register(self): |
| 21 | + super()._register() |
| 22 | + subparsers = self._parser.add_subparsers(dest="subcommand", help="Command to execute") |
| 23 | + |
| 24 | + # Add subcommand |
| 25 | + add_parser = subparsers.add_parser("add", help="Add or update a project") |
| 26 | + add_parser.add_argument( |
| 27 | + "--name", type=str, help="The name of the project to configure", required=True |
| 28 | + ) |
| 29 | + add_parser.add_argument("--url", type=str, help="Server url", required=True) |
| 30 | + add_parser.add_argument("--token", type=str, help="User token", required=True) |
| 31 | + add_parser.add_argument( |
| 32 | + "-y", |
| 33 | + "--yes", |
| 34 | + help="Don't ask for confirmation (e.g. update the config)", |
| 35 | + action="store_true", |
| 36 | + ) |
| 37 | + add_parser.add_argument( |
| 38 | + "-n", |
| 39 | + "--no", |
| 40 | + help="Don't ask for confirmation (e.g. do not update the config)", |
| 41 | + action="store_true", |
| 42 | + ) |
| 43 | + add_parser.set_defaults(subfunc=self._add) |
| 44 | + |
| 45 | + # Delete subcommand |
| 46 | + delete_parser = subparsers.add_parser("delete", help="Delete a project") |
| 47 | + delete_parser.add_argument( |
| 48 | + "--name", type=str, help="The name of the project to delete", required=True |
| 49 | + ) |
| 50 | + delete_parser.add_argument( |
| 51 | + "-y", |
| 52 | + "--yes", |
| 53 | + help="Don't ask for confirmation", |
| 54 | + action="store_true", |
| 55 | + ) |
| 56 | + delete_parser.set_defaults(subfunc=self._delete) |
| 57 | + |
| 58 | + # List subcommand |
| 59 | + list_parser = subparsers.add_parser("list", help="List configured projects") |
| 60 | + list_parser.set_defaults(subfunc=self._list) |
| 61 | + |
| 62 | + # Set default subcommand |
| 63 | + set_default_parser = subparsers.add_parser("set-default", help="Set default project") |
| 64 | + set_default_parser.add_argument( |
| 65 | + "name", type=str, help="The name of the project to set as default" |
| 66 | + ) |
| 67 | + set_default_parser.set_defaults(subfunc=self._set_default) |
| 68 | + |
| 69 | + def _command(self, args: argparse.Namespace): |
| 70 | + if not hasattr(args, "subfunc"): |
| 71 | + args.subfunc = self._list |
| 72 | + args.subfunc(args) |
| 73 | + |
| 74 | + def _add(self, args: argparse.Namespace): |
| 75 | + config_manager = ConfigManager() |
| 76 | + api_client = dstack.api.server.APIClient(base_url=args.url, token=args.token) |
| 77 | + try: |
| 78 | + api_client.projects.get(args.name) |
| 79 | + except HTTPError as e: |
| 80 | + if e.response.status_code == 403: |
| 81 | + raise CLIError("Forbidden. Ensure the token is valid.") |
| 82 | + elif e.response.status_code == 404: |
| 83 | + raise CLIError(f"Project '{args.name}' not found.") |
| 84 | + else: |
| 85 | + raise e |
| 86 | + default_project = config_manager.get_project_config() |
| 87 | + if ( |
| 88 | + default_project is None |
| 89 | + or default_project.name != args.name |
| 90 | + or default_project.url != args.url |
| 91 | + or default_project.token != args.token |
| 92 | + ): |
| 93 | + set_it_as_default = ( |
| 94 | + ( |
| 95 | + args.yes |
| 96 | + or not default_project |
| 97 | + or confirm_ask(f"Set '{args.name}' as your default project?") |
| 98 | + ) |
| 99 | + if not args.no |
| 100 | + else False |
| 101 | + ) |
| 102 | + config_manager.configure_project( |
| 103 | + name=args.name, url=args.url, token=args.token, default=set_it_as_default |
| 104 | + ) |
| 105 | + config_manager.save() |
| 106 | + logger.info( |
| 107 | + f"Configuration updated at {config_manager.config_filepath}", {"show_path": False} |
| 108 | + ) |
| 109 | + |
| 110 | + def _delete(self, args: argparse.Namespace): |
| 111 | + config_manager = ConfigManager() |
| 112 | + if args.yes or confirm_ask(f"Are you sure you want to delete project '{args.name}'?"): |
| 113 | + config_manager.delete_project(args.name) |
| 114 | + config_manager.save() |
| 115 | + console.print("[grey58]OK[/]") |
| 116 | + |
| 117 | + def _list(self, args: argparse.Namespace): |
| 118 | + config_manager = ConfigManager() |
| 119 | + default_project = config_manager.get_project_config() |
| 120 | + |
| 121 | + table = Table(box=None) |
| 122 | + table.add_column("PROJECT", style="bold", no_wrap=True) |
| 123 | + table.add_column("URL", style="grey58") |
| 124 | + table.add_column("USER", style="grey58") |
| 125 | + table.add_column("DEFAULT", justify="center") |
| 126 | + |
| 127 | + for project_name in config_manager.list_projects(): |
| 128 | + project_config = config_manager.get_project_config(project_name) |
| 129 | + is_default = project_name == default_project.name if default_project else False |
| 130 | + |
| 131 | + # Get username from API |
| 132 | + try: |
| 133 | + api_client = dstack.api.server.APIClient( |
| 134 | + base_url=project_config.url, token=project_config.token |
| 135 | + ) |
| 136 | + user_info = api_client.users.get_my_user() |
| 137 | + username = user_info.username |
| 138 | + except ClientError: |
| 139 | + username = "(invalid token)" |
| 140 | + |
| 141 | + table.add_row( |
| 142 | + project_name, |
| 143 | + project_config.url, |
| 144 | + username, |
| 145 | + "✓" if is_default else "", |
| 146 | + style="bold" if is_default else None, |
| 147 | + ) |
| 148 | + |
| 149 | + console.print(table) |
| 150 | + |
| 151 | + def _set_default(self, args: argparse.Namespace): |
| 152 | + config_manager = ConfigManager() |
| 153 | + project_config = config_manager.get_project_config(args.name) |
| 154 | + if project_config is None: |
| 155 | + raise CLIError(f"Project '{args.name}' not found") |
| 156 | + |
| 157 | + config_manager.configure_project( |
| 158 | + name=args.name, url=project_config.url, token=project_config.token, default=True |
| 159 | + ) |
| 160 | + config_manager.save() |
| 161 | + console.print("[grey58]OK[/]") |
0 commit comments