Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/packages/jumpstarter-cli/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import pytest


@pytest.fixture
def anyio_backend():
    """Name the async backend that anyio-based tests in this package use.

    The fixture simply yields the string ``"asyncio"``.
    NOTE(review): presumably consumed by anyio's pytest plugin, which looks
    up a fixture with this exact name - confirm against the anyio docs.
    """
    backend_name = "asyncio"
    return backend_name
182 changes: 169 additions & 13 deletions python/packages/jumpstarter-cli/jumpstarter_cli/shell.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import sys
from datetime import timedelta

Expand All @@ -8,6 +9,8 @@
from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication
from jumpstarter_cli_common.oidc import (
TOKEN_EXPIRY_WARNING_SECONDS,
Config,
decode_jwt_issuer,
format_duration,
get_token_remaining_seconds,
)
Expand All @@ -19,6 +22,11 @@
from jumpstarter.config.client import ClientConfigV1Alpha1
from jumpstarter.config.exporter import ExporterConfigV1Alpha1

# Module-level logger; used below to report token refresh/reload failures.
logger = logging.getLogger(__name__)

# Refresh token when less than this many seconds remain.
# The expiry monitor compares the token's remaining lifetime against this
# value both to decide when to attempt recovery and to switch from a
# 30 second to a 5 second polling interval.
_TOKEN_REFRESH_THRESHOLD_SECONDS = 120


def _warn_about_expired_token(lease_name: str, selector: str) -> None:
"""Warn user that lease won't be cleaned up due to expired token."""
Expand All @@ -27,36 +35,184 @@ def _warn_about_expired_token(lease_name: str, selector: str) -> None:
click.echo(click.style(f"To reconnect: JMP_LEASE={lease_name} jmp shell", fg="cyan"))


async def _monitor_token_expiry(config, cancel_scope) -> None:
"""Monitor token expiry and warn user."""
async def _update_lease_channel(config, lease) -> None:
"""Update the lease's gRPC channel with the current config credentials."""
if lease is not None:
new_channel = await config.channel()
lease.refresh_channel(new_channel)


async def _try_refresh_token(config, lease) -> bool:
"""Attempt to refresh the token and update the lease channel.

Returns True if refresh succeeded, False otherwise.
"""
refresh_token = getattr(config, "refresh_token", None)
if not refresh_token:
return False

old_token = config.token
old_refresh_token = config.refresh_token
try:
issuer = decode_jwt_issuer(config.token)
oidc = Config(
issuer=issuer,
client_id="jumpstarter-cli",
offline_access=True,
)

tokens = await oidc.refresh_token_grant(refresh_token)
config.token = tokens["access_token"]
new_refresh_token = tokens.get("refresh_token")
if new_refresh_token is not None:
config.refresh_token = new_refresh_token

# Update the lease channel first (critical for the running session)
await _update_lease_channel(config, lease)

# Persist to disk (best-effort, uses original config path)
try:
ClientConfigV1Alpha1.save(config, path=config.path)
except Exception as e:
logger.warning("Failed to save refreshed token to disk: %s", e)

return True
except Exception as e:
# Restore old token so the monitor doesn't think we succeeded
config.token = old_token
config.refresh_token = old_refresh_token
logger.debug("Token refresh failed: %s", e)
return False


async def _try_reload_token_from_disk(config, lease) -> bool:
"""Check if the config on disk has a newer/valid token (e.g. from 'jmp login').

If a valid token is found on disk, updates the in-memory config and lease channel.
Returns True if a valid token was loaded, False otherwise.
"""
config_path = getattr(config, "path", None)
if not config_path:
return False

old_token = config.token
old_refresh_token = config.refresh_token
try:
disk_config = ClientConfigV1Alpha1.from_file(config_path)
disk_token = getattr(disk_config, "token", None)
if not disk_token or disk_token == config.token:
return False

# Check if the token on disk is actually valid
disk_remaining = get_token_remaining_seconds(disk_token)
if disk_remaining is None or disk_remaining <= 0:
return False

# Token on disk is valid and different - use it
config.token = disk_token
disk_refresh = getattr(disk_config, "refresh_token", None)
if disk_refresh is not None:
config.refresh_token = disk_refresh

# Update the lease channel (critical for the running session)
await _update_lease_channel(config, lease)

return True
except Exception as e:
config.token = old_token
config.refresh_token = old_refresh_token
logger.debug("Failed to reload token from disk: %s", e)
return False


async def _attempt_token_recovery(config, lease, remaining) -> str | None:
    """Run the recovery strategies in order and report which one worked.

    An OIDC refresh is tried first. Only if that fails is the config file
    re-read from disk, which picks up a token obtained by running
    'jmp login' from the shell.

    Args:
        config: Client config to update in place.
        lease: Active lease (or ``None``) passed through to the helpers.
        remaining: Seconds left on the current token; accepted for the
            caller's convenience but not consulted here.

    Returns:
        A short user-facing message naming the method that succeeded, or
        ``None`` when both attempts failed.
    """
    refreshed = await _try_refresh_token(config, lease)
    if refreshed:
        return "Token refreshed automatically."

    reloaded = await _try_reload_token_from_disk(config, lease)
    return "Token reloaded from login." if reloaded else None


def _warn_refresh_failed(remaining: float) -> None:
    """Print a bold notice that automatic token recovery did not work.

    Args:
        remaining: Seconds left on the token. A positive value yields a
            yellow "expires in ..." notice; zero or less yields a red
            "expired" notice.
    """
    if remaining > 0:
        time_left = format_duration(remaining)
        notice = (
            f"\nToken expires in {time_left} and auto-refresh failed. "
            "Run 'jmp login' from this shell to refresh manually."
        )
        colour = "yellow"
    else:
        notice = (
            "\nToken expired and auto-refresh failed. "
            "New commands will fail until you run 'jmp login' from this shell."
        )
        colour = "red"

    click.echo(click.style(notice, fg=colour, bold=True))


async def _monitor_token_expiry(config, lease, cancel_scope) -> None:
"""Monitor token expiry, auto-refresh when possible, warn user otherwise.

this monitor:
1. Proactively refreshes the token before it expires using the refresh token
2. Updates the lease's gRPC channel with new credentials
3. If refresh fails, periodically checks the config on disk for a token
refreshed externally (e.g. via 'jmp login' from within the shell)
4. Never cancels the scope - the shell stays alive regardless
"""
token = getattr(config, "token", None)
if not token:
return

warned = False
warned_expiry = False
warned_refresh_failed = False
while not cancel_scope.cancel_called:
try:
remaining = get_token_remaining_seconds(token)
# Re-read config.token each iteration since it may have been refreshed
remaining = get_token_remaining_seconds(config.token)
if remaining is None:
return

if remaining <= 0:
click.echo(click.style("\nToken expired! Exiting shell.", fg="red", bold=True))
cancel_scope.cancel()
return
# Try to refresh proactively before the token expires
if remaining <= _TOKEN_REFRESH_THRESHOLD_SECONDS:
recovery_msg = await _attempt_token_recovery(config, lease, remaining)
if recovery_msg:
click.echo(click.style(f"\n{recovery_msg}", fg="green"))
warned_expiry = False
warned_refresh_failed = False
elif not warned_refresh_failed:
_warn_refresh_failed(remaining)
warned_refresh_failed = True

if remaining <= TOKEN_EXPIRY_WARNING_SECONDS and not warned:
elif remaining <= TOKEN_EXPIRY_WARNING_SECONDS and not warned_expiry:
duration = format_duration(remaining)
click.echo(
click.style(
f"\nToken expires in {duration}. Session will continue but cleanup may fail on exit.",
f"\nToken expires in {duration}. Will attempt auto-refresh.",
fg="yellow",
bold=True,
)
)
warned = True
warned_expiry = True

await anyio.sleep(30)
# Check more frequently as we approach expiry
if remaining <= _TOKEN_REFRESH_THRESHOLD_SECONDS:
await anyio.sleep(5)
else:
await anyio.sleep(30)
except Exception:
return

Expand Down Expand Up @@ -111,7 +267,7 @@ async def _shell_with_signal_handling(
lease_used = lease

# Start token monitoring only once we're in the shell
tg.start_soon(_monitor_token_expiry, config, tg.cancel_scope)
tg.start_soon(_monitor_token_expiry, config, lease, tg.cancel_scope)

exit_code = await anyio.to_thread.run_sync(
_run_shell_with_lease, lease, exporter_logs, config, command
Expand Down
Loading
Loading