diff --git a/e2e/tests.bats b/e2e/tests.bats index 703e15fc4..b3f0f4631 100644 --- a/e2e/tests.bats +++ b/e2e/tests.bats @@ -69,18 +69,18 @@ wait_for_exporter() { } @test "can create clients with admin cli" { - jmp admin create client -n "${JS_NAMESPACE}" test-client-oidc --unsafe --out /dev/null \ + jmp admin create client -n "${JS_NAMESPACE}" test-client-oidc --unsafe --nointeractive \ --oidc-username dex:test-client-oidc - jmp admin create client -n "${JS_NAMESPACE}" test-client-sa --unsafe --out /dev/null \ + jmp admin create client -n "${JS_NAMESPACE}" test-client-sa --unsafe --nointeractive \ --oidc-username dex:system:serviceaccount:"${JS_NAMESPACE}":test-client-sa jmp admin create client -n "${JS_NAMESPACE}" test-client-legacy --unsafe --save } @test "can create exporters with admin cli" { - jmp admin create exporter -n "${JS_NAMESPACE}" test-exporter-oidc --out /dev/null \ + jmp admin create exporter -n "${JS_NAMESPACE}" test-exporter-oidc --nointeractive \ --oidc-username dex:test-exporter-oidc \ --label example.com/board=oidc - jmp admin create exporter -n "${JS_NAMESPACE}" test-exporter-sa --out /dev/null \ + jmp admin create exporter -n "${JS_NAMESPACE}" test-exporter-sa --nointeractive \ --oidc-username dex:system:serviceaccount:"${JS_NAMESPACE}":test-exporter-sa \ --label example.com/board=sa jmp admin create exporter -n "${JS_NAMESPACE}" test-exporter-legacy --save \ diff --git a/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py b/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py index 5f4115fbd..376c38fef 100644 --- a/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py +++ b/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py @@ -31,6 +31,11 @@ def opt_oidc(f): default=None, help="Port for OIDC callback server (0=random port)", ) + @click.option( + "--offline-access/--no-offline-access", + default=True, + help="Request offline_access scope (refresh token)", + ) @wraps(f) def wrapper(*args, **kwds): return f(*args, **kwds) @@ -42,6 +47,7 @@ def wrapper(*args, **kwds): class Config: issuer: str client_id: str + offline_access: bool = False scope: ClassVar[list[str]] = ["openid", "profile"] async def configuration(self): @@ -52,8 +58,13 @@ async def configuration(self): ) as response: return await response.json() + def _scopes(self) -> list[str]: + if self.offline_access: + return [*self.scope, "offline_access"] + return list(self.scope) + def client(self, **kwargs): - return OAuth2Session(client_id=self.client_id, scope=self.scope, **kwargs) + return OAuth2Session(client_id=self.client_id, scope=self._scopes(), **kwargs) async def token_exchange_grant(self, token: str, **kwargs): config = await self.configuration() @@ -71,6 +82,19 @@ async def token_exchange_grant(self, token: str, **kwargs): ) ) + async def refresh_token_grant(self, refresh_token: str): + config = await self.configuration() + + client = self.client() + + return await run_sync( + lambda: client.fetch_token( + config["token_endpoint"], + grant_type="refresh_token", + refresh_token=refresh_token, + ) + ) + async def password_grant(self, username: str, password: str): config = await self.configuration() diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/auth.py b/python/packages/jumpstarter-cli/jumpstarter_cli/auth.py index 630b7d11f..006210930 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/auth.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/auth.py @@ -1,14 +1,19 @@ from datetime import datetime, timezone import click +from jumpstarter_cli_common.blocking import blocking from jumpstarter_cli_common.config import opt_config from jumpstarter_cli_common.oidc import ( TOKEN_EXPIRY_WARNING_SECONDS, + Config, decode_jwt, + decode_jwt_issuer, format_duration, get_token_remaining_seconds, ) +from jumpstarter.config.client import ClientConfigV1Alpha1 + @click.group() def auth(): @@ -19,21 +24,52 @@ def _print_token_status(remaining: float) -> None: """Print token status message based on remaining time.""" duration = format_duration(remaining) + hint = "Run 'jmp login' to refresh your credentials." + if remaining < 0: click.echo(click.style(f"Status: EXPIRED ({duration} ago)", fg="red", bold=True)) - click.echo(click.style("Run 'jmp login --force' to refresh your credentials.", fg="yellow")) + click.echo(click.style(hint, fg="yellow")) elif remaining < TOKEN_EXPIRY_WARNING_SECONDS: click.echo(click.style(f"Status: EXPIRING SOON ({duration} remaining)", fg="red", bold=True)) - click.echo(click.style("Run 'jmp login --force' to refresh your credentials.", fg="yellow")) + click.echo(click.style(hint, fg="yellow")) elif remaining < 3600: click.echo(click.style(f"Status: Valid ({duration} remaining)", fg="yellow")) else: click.echo(click.style(f"Status: Valid ({duration} remaining)", fg="green")) +def _print_subject_issuer(payload: dict) -> None: + sub = payload.get("sub") + iss = payload.get("iss") + if sub: + click.echo(f"Subject: {sub}") + if iss: + click.echo(f"Issuer: {iss}") + + +def _print_timestamp(label: str, value: int | None) -> None: + if value is None: + return + dt = datetime.fromtimestamp(value, tz=timezone.utc) + click.echo(f"{label}: {dt.strftime('%Y-%m-%d %H:%M:%S %Z')}") + + +def _print_verbose_details(payload: dict, config) -> None: + iat = payload.get("iat") + auth_time = payload.get("auth_time") + if isinstance(iat, int): + _print_timestamp("Issued at", iat) + if isinstance(auth_time, int): + _print_timestamp("Auth time", auth_time) + + refresh_token = getattr(config, "refresh_token", None) + click.echo(f"Refresh token stored: {'yes' if refresh_token else 'no'}") + + @auth.command(name="status") +@click.option("--verbose", is_flag=True, help="Show additional token details") @opt_config(exporter=False) -def token_status(config): +def token_status(config, verbose: bool): """Display token status and expiry information.""" token_str = getattr(config, "token", None) @@ -58,10 +94,38 @@ def token_status(config): _print_token_status(remaining) - # Show additional token info - sub = payload.get("sub") - iss = payload.get("iss") - if sub: - click.echo(f"Subject: {sub}") - if iss: - click.echo(f"Issuer: {iss}") + _print_subject_issuer(payload) + + if verbose: + _print_verbose_details(payload, config) + + +@auth.command(name="refresh") +@opt_config(exporter=False) +@blocking +async def refresh_token(config): + """Refresh the access token using a stored refresh token.""" + refresh_token = getattr(config, "refresh_token", None) + if not refresh_token: + raise click.ClickException("No refresh token found. Run 'jmp login --offline-access'.") + + access_token = getattr(config, "token", None) + if not access_token: + raise click.ClickException("No access token found. Run 'jmp login --offline-access'.") + + try: + issuer = decode_jwt_issuer(access_token) + except Exception as e: + raise click.ClickException(f"Failed to decode JWT issuer: {e}") from e + + if issuer is None: + raise click.ClickException("Failed to determine issuer from access token.") + + oidc = Config(issuer=issuer, client_id="jumpstarter-cli") + tokens = await oidc.refresh_token_grant(refresh_token) + config.token = tokens["access_token"] + new_refresh_token = tokens.get("refresh_token") + if new_refresh_token is not None: + config.refresh_token = new_refresh_token + ClientConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] + click.echo("Access token refreshed.") diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/login.py b/python/packages/jumpstarter-cli/jumpstarter_cli/login.py index 60dcfc3ff..312fb8a62 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/login.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/login.py @@ -19,12 +19,6 @@ @click.option("-e", "--endpoint", type=str, help="Enter the Jumpstarter service endpoint.", default=None) @click.option("--namespace", type=str, help="Enter the Jumpstarter exporter namespace.", default=None) @click.option("--name", type=str, help="Enter the Jumpstarter exporter name.", default=None) -@click.option( - "--force", - is_flag=True, - help="Force fresh login", - default=False, -) @opt_oidc # client specific # TODO: warn if used with exporter @@ -54,11 +48,11 @@ async def login( # noqa: C901 client_id: str, connector_id: str, callback_port: int | None, + offline_access: bool, unsafe, insecure_tls_config: bool, nointeractive: bool, allow, - force: bool, ): """Login into a jumpstarter instance""" @@ -123,7 +117,38 @@ async def login( # noqa: C901 raise click.UsageError("Issuer is required in non-interactive mode.") issuer = click.prompt("Enter the OIDC issuer") - oidc = Config(issuer=issuer, client_id=client_id) + stored_refresh_token = getattr(config, "refresh_token", None) + oidc = Config( + issuer=issuer, + client_id=client_id, + offline_access=offline_access or stored_refresh_token is not None, + ) + + def save_config() -> None: + match config_kind: + case "client": + ClientConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] + case "client_config": + ClientConfigV1Alpha1.save(config, value) # ty: ignore[invalid-argument-type] + case "exporter": + ExporterConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] + case "exporter_config": + ExporterConfigV1Alpha1.save(config, value) # ty: ignore[invalid-argument-type] + + if stored_refresh_token and token is None and username is None and password is None: + try: + tokens = await oidc.refresh_token_grant(stored_refresh_token) + config.token = tokens["access_token"] + refresh_token = tokens.get("refresh_token") + if refresh_token is not None and isinstance(config, ClientConfigV1Alpha1): + config.refresh_token = refresh_token + save_config() + click.echo("Refreshed access token using stored refresh token.") + return + except Exception as e: + if nointeractive: + raise click.ClickException(f"Failed to refresh access token: {e}") from e + pass if token is not None: kwargs = {"connector_id": connector_id} if connector_id is not None else {} @@ -131,20 +156,16 @@ async def login( # noqa: C901 elif username is not None and password is not None: tokens = await oidc.password_grant(username, password) else: - prompt = "login" if force else None - tokens = await oidc.authorization_code_grant(callback_port=callback_port, prompt=prompt) + tokens = await oidc.authorization_code_grant(callback_port=callback_port) config.token = tokens["access_token"] + refresh_token = tokens.get("refresh_token") - match config_kind: - case "client": - ClientConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] - case "client_config": - ClientConfigV1Alpha1.save(config, value) # ty: ignore[invalid-argument-type] - case "exporter": - ExporterConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] - case "exporter_config": - ExporterConfigV1Alpha1.save(config, value) # ty: ignore[invalid-argument-type] + # only client configs support refresh_token + if refresh_token is not None and isinstance(config, ClientConfigV1Alpha1): + config.refresh_token = refresh_token + + save_config() @blocking @@ -157,9 +178,24 @@ async def relogin_client(config: ClientConfigV1Alpha1): raise ReauthenticationFailed(f"Failed to decode JWT issuer: {e}") from e try: - oidc = Config(issuer=issuer, client_id=client_id) + oidc = Config(issuer=issuer, client_id=client_id, offline_access=config.refresh_token is not None) + if config.refresh_token: + try: + tokens = await oidc.refresh_token_grant(config.refresh_token) + config.token = tokens["access_token"] + refresh_token = tokens.get("refresh_token") + if refresh_token is not None: + config.refresh_token = refresh_token + ClientConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] + return + except Exception: + pass + tokens = await oidc.authorization_code_grant() config.token = tokens["access_token"] + refresh_token = tokens.get("refresh_token") + if refresh_token is not None: + config.refresh_token = refresh_token ClientConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] except Exception as e: raise ReauthenticationFailed(f"Failed to re-authenticate: {e}") from e diff --git a/python/packages/jumpstarter/jumpstarter/config/client.py b/python/packages/jumpstarter/jumpstarter/config/client.py index 97f92c1ec..448f7d032 100644 --- a/python/packages/jumpstarter/jumpstarter/config/client.py +++ b/python/packages/jumpstarter/jumpstarter/config/client.py @@ -1,7 +1,9 @@ from __future__ import annotations import asyncio +import errno import os +import tempfile from contextlib import asynccontextmanager, contextmanager from datetime import datetime, timedelta from functools import wraps @@ -112,6 +114,7 @@ class ClientConfigV1Alpha1(BaseSettings): endpoint: str | None = Field(default=None) tls: TLSConfigV1Alpha1 = Field(default_factory=TLSConfigV1Alpha1) token: str | None = Field(default=None) + refresh_token: str | None = Field(default=None) grpcOptions: dict[str, str | int] | None = Field(default_factory=dict) drivers: ClientConfigV1Alpha1Drivers = Field(default_factory=ClientConfigV1Alpha1Drivers) @@ -344,13 +347,40 @@ def save(cls, config: Self, path: Optional[os.PathLike] = None) -> Path: config.path = cls._get_path(config.alias) else: config.path = Path(path) - with config.path.open(mode="w") as f: - yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), f, sort_keys=False) + config.path.parent.mkdir(parents=True, exist_ok=True) + payload = config.model_dump(mode="json", exclude={"path", "alias"}, exclude_none=True) + temp_fd, temp_path = tempfile.mkstemp(prefix=f".{config.path.name}.", dir=config.path.parent) + try: + os.fchmod(temp_fd, 0o600) + with os.fdopen(temp_fd, "w") as f: + yaml.safe_dump( + payload, + f, + sort_keys=False, + ) + f.flush() + os.fsync(f.fileno()) + os.replace(temp_path, config.path) + os.chmod(config.path, 0o600) + finally: + try: + os.unlink(temp_path) + except OSError as e: + if e.errno != errno.ENOENT: + raise return config.path @classmethod def dump_yaml(cls, config: Self) -> str: - return yaml.safe_dump(config.model_dump(mode="json", exclude={"path", "alias"}), sort_keys=False) + """Return YAML suitable for display""" + return yaml.safe_dump( + config.model_dump( + mode="json", + exclude={"path", "alias", "refresh_token"}, + exclude_none=True, + ), + sort_keys=False, + ) @classmethod def exists(cls, alias: str) -> bool: @@ -404,10 +434,21 @@ class ClientConfigListV1Alpha1(BaseModel): kind: Literal["ClientConfigList"] = Field(default="ClientConfigList") def dump_json(self): - return self.model_dump_json(indent=4, by_alias=True) + return self.model_dump_json( + indent=4, + by_alias=True, + exclude={"items": {"__all__": {"refresh_token"}}}, + ) def dump_yaml(self): - return yaml.safe_dump(self.model_dump(mode="json", by_alias=True), indent=2) + return yaml.safe_dump( + self.model_dump( + mode="json", + by_alias=True, + exclude={"items": {"__all__": {"refresh_token"}}}, + ), + indent=2, + ) model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)