Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions code-interpreter/app/services/executor_kubernetes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import base64
import io
import logging
import tarfile
Expand Down Expand Up @@ -180,36 +181,69 @@ def _create_tar_archive(
return tar_buffer.getvalue()

def _extract_workspace_snapshot(self, pod_name: str) -> tuple[WorkspaceEntry, ...]:
"""Extract files from the pod workspace after execution using tar."""
try:
exec_command = ["tar", "-c", "--exclude=__main__.py", "-C", "/workspace", "."]
"""Extract files from the pod workspace after execution using tar.

Uses base64 encoding to safely transmit binary tar data through the
text-based Kubernetes WebSocket stream.
"""
try:
# Use base64 to encode the tar output so it can safely pass through
# the text-based WebSocket stream without corruption
exec_command = [
"sh",
"-c",
"tar -c --exclude=__main__.py -C /workspace . | base64",
]

logger.info(f"Starting tar extraction from pod {pod_name}")
resp = stream.stream(
self.v1.connect_get_namespaced_pod_exec,
pod_name,
self.namespace,
command=exec_command,
stderr=False,
stderr=True,
stdin=False,
stdout=True,
tty=False,
_preload_content=False,
)

tar_data = b""
base64_data = ""
stderr_data = ""

while resp.is_open():
resp.update(timeout=1)

if resp.peek_stdout():
tar_data += resp.read_stdout().encode("latin-1")
base64_data += resp.read_stdout()

if resp.peek_stderr():
stderr_data += resp.read_stderr()

resp.close()

if not tar_data:
logger.info(f"Tar extraction complete. Received {len(base64_data)} base64 chars")
if stderr_data:
logger.warning(f"Tar extraction stderr: {stderr_data}")

if not base64_data:
logger.warning("No tar data received from workspace snapshot")
return tuple()

# Decode base64 to get the original tar binary data
tar_data = base64.b64decode(base64_data)
logger.info(f"Decoded to {len(tar_data)} bytes of tar data")

entries = []
logger.info("Parsing tar archive")
with tarfile.open(fileobj=io.BytesIO(tar_data), mode="r") as tar:
for member in tar.getmembers():
members = tar.getmembers()
logger.info(f"Tar archive contains {len(members)} members")
for member in members:
logger.debug(
f"Processing tar member: {member.name!r} (type={member.type!r}, "
f"size={member.size})"
)
if member.name == ".":
continue

Expand All @@ -223,12 +257,17 @@ def _extract_workspace_snapshot(self, pod_name: str) -> tuple[WorkspaceEntry, ..
file_obj = tar.extractfile(member)
if file_obj:
content = file_obj.read()
logger.debug(f"Extracted file {clean_path}: {len(content)} bytes")
entries.append(
WorkspaceEntry(path=clean_path, kind="file", content=content)
)
else:
logger.warning(f"Failed to extract file content for {clean_path}")

logger.info(f"Extracted {len(entries)} workspace entries")
return tuple(entries)
except Exception:
except Exception as e:
logger.error(f"Failed to extract workspace snapshot: {e}", exc_info=True)
return tuple()

def _cleanup_pod(self, pod_name: str) -> None:
Expand Down
177 changes: 177 additions & 0 deletions code-interpreter/tests/e2e/test_basic_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,180 @@ def test_execute_edits_passed_file() -> None:
assert returned_content == expected_content, (
f"Content mismatch. Expected: {expected_content!r}, Got: {returned_content!r}"
)


def test_matplotlib_sine_wave_plot() -> None:
    """Execute a matplotlib script remotely and check the saved PNG round-trips.

    The test checks ``/health``, posts a plotting script to ``/v1/execute``,
    asserts on the reported stdout/stderr/exit status, locates the
    ``sine_wave.png`` entry in the returned ``files`` list, downloads it via
    ``/v1/files/{file_id}`` and verifies the PNG signature and a minimum size.
    """
    # Script sent to the service; it renders headlessly and writes one PNG.
    plot_script = """
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

# Generate data
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Create plot
plt.figure(figsize=(10, 6))
plt.plot(x, y)
plt.title('Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True)

# Save plot
plt.savefig('sine_wave.png')
plt.close()
print("Plot saved successfully")
""".strip()

    request_body: dict[str, Any] = {
        "code": plot_script,
        "stdin": None,
        "timeout_ms": 5000,
        "files": [],
    }

    client_timeout = httpx.Timeout(10.0, connect=5.0)
    with httpx.Client(base_url=BASE_URL, timeout=client_timeout) as client:
        # The service must be reachable and healthy before anything else runs.
        try:
            health = client.get("/health")
        except httpx.TransportError as exc:  # pragma: no cover - network failure path
            pytest.fail(f"Failed to reach Code Interpreter service at {BASE_URL}: {exc!s}")

        assert health.status_code == 200, health.text

        try:
            execution = client.post("/v1/execute", json=request_body)
        except httpx.TransportError as exc:  # pragma: no cover - network failure path
            pytest.fail(f"Failed to reach Code Interpreter service at {BASE_URL}: {exc!s}")

        assert execution.status_code == 200, execution.text

        result = execution.json()

        # The script itself must have run cleanly.
        assert result["stdout"] == "Plot saved successfully\n", f"stdout mismatch: {result}"
        assert result["stderr"] == "", f"stderr should be empty: {result}"
        assert result["exit_code"] == 0, f"exit_code should be 0: {result}"
        assert result["timed_out"] is False, f"should not timeout: {result}"

        # The generated image must be listed among the returned files.
        files = result.get("files")
        assert isinstance(files, list), "files should be a list"

        png_file = next(
            (
                entry
                for entry in files
                if isinstance(entry, dict) and entry.get("path") == "sine_wave.png"
            ),
            None,
        )

        assert png_file is not None, f"sine_wave.png not found in response files: {files}"
        assert png_file["kind"] == "file"

        # A file_id is needed to fetch the stored bytes.
        file_id = png_file.get("file_id")
        assert isinstance(file_id, str), "file_id should be present"

        download = client.get(f"/v1/files/{file_id}")
        assert download.status_code == 200, (
            f"Failed to download file: {download.text}"
        )
        png_bytes = download.content

        # Every PNG begins with this fixed 8-byte signature.
        png_signature = b"\x89PNG\r\n\x1a\n"
        assert png_bytes[:8] == png_signature, "File should be a valid PNG"

        # A rendered plot should be well above a kilobyte.
        assert len(png_bytes) > 1000, f"PNG file too small: {len(png_bytes)} bytes"


def test_create_multiple_files() -> None:
    """Execute a script that writes three text files and verify each is returned.

    The test checks ``/health``, posts the script to ``/v1/execute``, asserts on
    the reported stdout/stderr/exit status, confirms that exactly the three
    expected paths are listed in ``files``, then downloads every entry via
    ``/v1/files/{file_id}`` and compares its UTF-8 content.
    """
    # Script sent to the service; it writes three small text files.
    writer_script = """
# Create multiple files
with open('file1.txt', 'w') as f:
    f.write('Content of file 1')

with open('file2.txt', 'w') as f:
    f.write('Content of file 2')

with open('file3.txt', 'w') as f:
    f.write('Content of file 3')

print("Created 3 files")
""".strip()

    request_body: dict[str, Any] = {
        "code": writer_script,
        "stdin": None,
        "timeout_ms": 2000,
        "files": [],
    }

    # What each returned path is expected to contain.
    expected_contents = {
        "file1.txt": "Content of file 1",
        "file2.txt": "Content of file 2",
        "file3.txt": "Content of file 3",
    }

    client_timeout = httpx.Timeout(5.0, connect=5.0)
    with httpx.Client(base_url=BASE_URL, timeout=client_timeout) as client:
        # The service must be reachable and healthy before anything else runs.
        try:
            health = client.get("/health")
        except httpx.TransportError as exc:  # pragma: no cover - network failure path
            pytest.fail(f"Failed to reach Code Interpreter service at {BASE_URL}: {exc!s}")

        assert health.status_code == 200, health.text

        try:
            execution = client.post("/v1/execute", json=request_body)
        except httpx.TransportError as exc:  # pragma: no cover - network failure path
            pytest.fail(f"Failed to reach Code Interpreter service at {BASE_URL}: {exc!s}")

        assert execution.status_code == 200, execution.text

        result = execution.json()

        # The script itself must have run cleanly.
        assert result["stdout"] == "Created 3 files\n", f"stdout mismatch: {result}"
        assert result["stderr"] == "", f"stderr should be empty: {result}"
        assert result["exit_code"] == 0, f"exit_code should be 0: {result}"
        assert result["timed_out"] is False, f"should not timeout: {result}"

        # Exactly three entries must come back.
        files = result.get("files")
        assert isinstance(files, list), "files should be a list"
        assert len(files) == 3, f"Expected 3 files, got {len(files)}: {files}"

        # Compare as sets so the order of entries does not matter.
        file_paths = {entry["path"] for entry in files}
        expected_paths = {"file1.txt", "file2.txt", "file3.txt"}
        assert file_paths == expected_paths, (
            f"File paths mismatch. Expected: {expected_paths}, Got: {file_paths}"
        )

        for entry in files:
            path = entry["path"]
            assert entry["kind"] == "file", f"{path} should be a file"

            file_id = entry.get("file_id")
            assert isinstance(file_id, str), f"{path} should have a file_id"

            # Fetch the stored bytes and compare them with what the script wrote.
            download = client.get(f"/v1/files/{file_id}")
            assert download.status_code == 200, (
                f"Failed to download {path}: {download.text}"
            )

            content = download.content.decode("utf-8")
            expected_content = expected_contents[path]
            assert content == expected_content, (
                f"Content mismatch for {path}. Expected: {expected_content!r}, Got: {content!r}"
            )