Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/packages/jumpstarter-driver-flashers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ Options:
--insecure-tls Skip TLS certificate verification
--header TEXT Custom HTTP header in 'Key: Value' format
--bearer TEXT Bearer token for HTTP authentication
--oci-username TEXT OCI registry username (or OCI_USERNAME environment variable)
--oci-password TEXT OCI registry password (or OCI_PASSWORD environment variable)
--retries INTEGER Number of retry attempts for flash operation
(default: 3)
--method [fls|shell] Method to use for flash operation (default:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import hashlib
import json
import os
Expand Down Expand Up @@ -129,14 +130,13 @@ def flash( # noqa: C901
oci_username: str | None = None,
oci_password: str | None = None,
):
"""Flash image to DUT"""
if bearer_token:
bearer_token = self._validate_bearer_token(bearer_token)

if headers:
headers = self._validate_header_dict(headers)
oci_username, oci_password = self._validate_oci_credentials(oci_username, oci_password)

"""Flash image to DUT"""
oci_username, oci_password = self._resolve_oci_credentials(path, oci_username, oci_password)
should_download_to_httpd = True
image_url = ""
original_http_url = None
Expand Down Expand Up @@ -653,7 +653,7 @@ def _flash_with_fls(

# Flash the image
creds_file = None
with self._redaction_scope([oci_username, oci_password]):
with self._redaction_scope([oci_password]):
if str(path).startswith("oci://") and oci_username:
creds_file = self._setup_fls_oci_credential_file(console, prompt, oci_username, oci_password or "")

Expand All @@ -664,12 +664,8 @@ def _flash_with_fls(
)
console.sendline(flash_cmd)

try:
# Start monitoring the flash operation
self._monitor_fls_progress(console, prompt)
finally:
if creds_file:
self._cleanup_fls_oci_credential_file(console, prompt, creds_file)
# Start monitoring the flash operation
self._monitor_fls_progress(console, prompt)

self.logger.info("Flushing buffers")
console.sendline("sync")
Expand Down Expand Up @@ -1304,12 +1300,24 @@ def _validate_oci_credentials(

if bool(username) != bool(password):
raise click.ClickException(
"OCI authentication requires both --oci-username and --oci-password "
"(or OCI_USERNAME and OCI_PASSWORD environment variables)"
"OCI authentication requires both OCI_USERNAME and OCI_PASSWORD "
"environment variables (or both oci_username and oci_password arguments)"
)

return username, password

def _resolve_oci_credentials(
self, path: PathBuf, username: str | None, password: str | None
) -> tuple[str | None, str | None]:
if username is None and password is None and path.startswith("oci://"):
username = os.environ.get("OCI_USERNAME")
password = os.environ.get("OCI_PASSWORD")

if username or password:
self.logger.info("Using OCI registry credentials from environment variables")

return self._validate_oci_credentials(username, password)

def _fls_oci_auth_env(self, path: PathBuf, creds_file: str | None) -> str:
if not str(path).startswith("oci://") or not creds_file:
return ""
Expand Down Expand Up @@ -1345,20 +1353,35 @@ def _temporarily_disable_console_debug_stream(self, console):
def _setup_fls_oci_credential_file(
self, console, prompt: str, oci_username: str, oci_password: str, creds_file: str = "/tmp/fls_creds"
) -> str:
# Write credential file using base64-encoded chunks to avoid serial
# console line buffer overflow with long tokens (e.g. 1400+ char JWTs).
creds_content = (
f"FLS_REGISTRY_USERNAME={shlex.quote(oci_username)}\n"
f"FLS_REGISTRY_PASSWORD={shlex.quote(oci_password)}\n"
)
encoded = base64.b64encode(creds_content.encode()).decode()

chunk_size = 512
with self._temporarily_disable_console_debug_stream(console):
console.sendline(f"umask 077 && cat > {shlex.quote(creds_file)} <<'EOF_FLS_CREDS'")
console.sendline(f"FLS_REGISTRY_USERNAME={shlex.quote(oci_username)}")
console.sendline(f"FLS_REGISTRY_PASSWORD={shlex.quote(oci_password)}")
console.sendline("EOF_FLS_CREDS")
console.sendline(f"true > {shlex.quote(creds_file)}")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
console.sendline(f"chmod 600 {shlex.quote(creds_file)}")

# Write base64 data in chunks to a temp file
b64_file = f"{creds_file}.b64"
console.sendline(f"true > {shlex.quote(b64_file)}")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
return creds_file

def _cleanup_fls_oci_credential_file(self, console, prompt: str, creds_file: str):
with self._temporarily_disable_console_debug_stream(console):
console.sendline(f"rm -f {shlex.quote(creds_file)}")
for i in range(0, len(encoded), chunk_size):
chunk = encoded[i : i + chunk_size]
console.sendline(f"printf '%s' {shlex.quote(chunk)} >> {shlex.quote(b64_file)}")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)

# Decode into the actual creds file
console.sendline(f"base64 -d {shlex.quote(b64_file)} > {shlex.quote(creds_file)}")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
console.sendline(f"chmod 600 {shlex.quote(creds_file)}")
console.expect(prompt, timeout=EXPECT_TIMEOUT_DEFAULT)
return creds_file
Comment thread
mangelajo marked this conversation as resolved.

def _resolve_flash_parameters(
self, file: str | None, partitions: tuple[str, ...] | None, block_device: str | None
Expand Down Expand Up @@ -1425,6 +1448,7 @@ def _resolve_flash_parameters(

return flash_ops


def cli(self):
@driver_click_group(self)
def base():
Expand Down Expand Up @@ -1465,18 +1489,6 @@ def base():
type=str,
help="Bearer token for HTTP authentication",
)
@click.option(
"--oci-username",
type=str,
envvar="OCI_USERNAME",
help="OCI registry username (or OCI_USERNAME environment variable)",
)
@click.option(
"--oci-password",
type=str,
envvar="OCI_PASSWORD",
help="OCI registry password (or OCI_PASSWORD environment variable)",
)
@click.option(
"--retries",
type=int,
Expand Down Expand Up @@ -1514,8 +1526,6 @@ def flash(
insecure_tls,
header,
bearer,
oci_username,
oci_password,
retries,
method,
fls_version,
Expand Down Expand Up @@ -1576,8 +1586,6 @@ def flash(
insecure_tls=insecure_tls,
headers=headers_dict,
bearer_token=bearer,
oci_username=oci_username,
oci_password=oci_password,
retries=retries,
method=method,
fls_version=fls_version,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,38 @@ def test_validate_oci_credentials_accepts_pair_and_strips_whitespace():
assert password == "mypassword"


def test_resolve_oci_credentials_reads_env_for_oci_path(monkeypatch):
"""Test OCI credentials are read from environment for OCI paths."""
client = MockFlasherClient()
monkeypatch.setenv("OCI_USERNAME", "env-user")
monkeypatch.setenv("OCI_PASSWORD", "env-pass")

username, password = client._resolve_oci_credentials("oci://quay.io/org/image:tag", None, None)
assert username == "env-user"
assert password == "env-pass"


def test_resolve_oci_credentials_ignores_env_for_non_oci_path(monkeypatch):
"""Test OCI credential env vars are ignored for non-OCI image paths."""
client = MockFlasherClient()
monkeypatch.setenv("OCI_USERNAME", "env-user")
monkeypatch.setenv("OCI_PASSWORD", "env-pass")

username, password = client._resolve_oci_credentials("https://example.com/image.raw.xz", None, None)
assert username is None
assert password is None


def test_resolve_oci_credentials_rejects_partial_env_for_oci_path(monkeypatch):
"""Test partial OCI env credentials are rejected for OCI paths."""
client = MockFlasherClient()
monkeypatch.setenv("OCI_USERNAME", "env-user")
monkeypatch.delenv("OCI_PASSWORD", raising=False)

with pytest.raises(click.ClickException, match="OCI authentication requires both"):
client._resolve_oci_credentials("oci://quay.io/org/image:tag", None, None)


def test_fls_oci_auth_env_sources_credentials_file():
"""Test OCI auth shell snippet sources the on-target credentials file"""
client = MockFlasherClient()
Expand Down Expand Up @@ -101,8 +133,8 @@ def test_redact_sensitive_values_masks_username_and_password():
assert result == "user=*** pass=***"


def test_setup_and_cleanup_fls_oci_credential_file():
"""Test secure credentials file setup and cleanup commands."""
def test_setup_fls_oci_credential_file():
"""Test secure credentials file setup commands."""
client = MockFlasherClient()

class MockConsole:
Expand All @@ -120,14 +152,75 @@ def expect(self, prompt, timeout=None):
console = MockConsole()
creds_path = client._setup_fls_oci_credential_file(console, "#", "myuser", "my'password")
assert creds_path == "/tmp/fls_creds"
assert "cat > /tmp/fls_creds <<'EOF_FLS_CREDS'" in console.sent_lines[0]
assert "FLS_REGISTRY_USERNAME=myuser" in console.sent_lines[1]
assert "FLS_REGISTRY_PASSWORD='my'\"'\"'password'" in console.sent_lines[2]
assert "chmod 600 /tmp/fls_creds" in console.sent_lines[4]

# Verify chunked base64 approach: creates file, writes b64 chunks, decodes, cleans up
assert "true > /tmp/fls_creds" in console.sent_lines[0]
assert "true > /tmp/fls_creds.b64" in console.sent_lines[1]

# Find the base64 chunk lines (printf commands)
b64_lines = [line for line in console.sent_lines if "printf" in line and ".b64" in line]
assert len(b64_lines) >= 1

# Verify decode step
assert any("base64 -d" in line for line in console.sent_lines)
assert any("chmod 600 /tmp/fls_creds" in line for line in console.sent_lines)

# Verify the decoded content is correct
import base64

b64_data = ""
for line in b64_lines:
# Extract the base64 chunk from: printf '%s' <chunk> >> /tmp/fls_creds.b64
parts = shlex.split(line)
b64_data += parts[2] # the chunk argument
decoded = base64.b64decode(b64_data).decode()
assert "FLS_REGISTRY_USERNAME=myuser" in decoded
assert "FLS_REGISTRY_PASSWORD='my'\"'\"'password'" in decoded

assert console.logfile_read is not None

client._cleanup_fls_oci_credential_file(console, "#", "/tmp/fls_creds")
assert "rm -f /tmp/fls_creds" in console.sent_lines[-1]

def test_setup_fls_oci_credential_file_chunks_long_tokens():
"""Test that long JWT tokens are split into multiple base64 chunks."""
client = MockFlasherClient()

class MockConsole:
def __init__(self):
self.logfile_read = object()
self.sent_lines = []
self.expect_calls = []

def sendline(self, line):
self.sent_lines.append(line)

def expect(self, prompt, timeout=None):
self.expect_calls.append((prompt, timeout))

console = MockConsole()
# Simulate a 1400-char JWT token (similar to real Kubernetes service account tokens)
long_token = "eyJ" + "a" * 1397

creds_path = client._setup_fls_oci_credential_file(console, "#", "serviceaccount", long_token)
assert creds_path == "/tmp/fls_creds"

# With a 1400+ char token, the base64 encoding should produce multiple chunks
b64_lines = [line for line in console.sent_lines if "printf" in line and ".b64" in line]
assert len(b64_lines) > 1, f"Expected multiple chunks for long token, got {len(b64_lines)}"

# Each printf line should be well under serial buffer limits
for line in b64_lines:
assert len(line) < 600, f"Chunk line too long ({len(line)} chars): {line[:80]}..."

# Verify roundtrip: reassemble and decode
import base64

b64_data = ""
for line in b64_lines:
parts = shlex.split(line)
b64_data += parts[2]
decoded = base64.b64decode(b64_data).decode()
assert f"FLS_REGISTRY_PASSWORD={long_token}" in decoded
assert "FLS_REGISTRY_USERNAME=serviceaccount" in decoded


def test_flash_http_url_with_oci_credentials_still_uses_direct_http_path():
Expand Down
Loading