Skip to content

Commit fb3333c

Browse files
authored
fix file generation in kubernetes (#2)
* fix file generation in kubernetes * .
1 parent a5298e6 commit fb3333c

File tree

2 files changed

+225
-9
lines changed

2 files changed

+225
-9
lines changed

code-interpreter/app/services/executor_kubernetes.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import base64
34
import io
45
import logging
56
import tarfile
@@ -180,36 +181,69 @@ def _create_tar_archive(
180181
return tar_buffer.getvalue()
181182

182183
def _extract_workspace_snapshot(self, pod_name: str) -> tuple[WorkspaceEntry, ...]:
183-
"""Extract files from the pod workspace after execution using tar."""
184-
try:
185-
exec_command = ["tar", "-c", "--exclude=__main__.py", "-C", "/workspace", "."]
184+
"""Extract files from the pod workspace after execution using tar.
186185
186+
Uses base64 encoding to safely transmit binary tar data through the
187+
text-based Kubernetes WebSocket stream.
188+
"""
189+
try:
190+
# Use base64 to encode the tar output so it can safely pass through
191+
# the text-based WebSocket stream without corruption
192+
exec_command = [
193+
"sh",
194+
"-c",
195+
"tar -c --exclude=__main__.py -C /workspace . | base64",
196+
]
197+
198+
logger.info(f"Starting tar extraction from pod {pod_name}")
187199
resp = stream.stream(
188200
self.v1.connect_get_namespaced_pod_exec,
189201
pod_name,
190202
self.namespace,
191203
command=exec_command,
192-
stderr=False,
204+
stderr=True,
193205
stdin=False,
194206
stdout=True,
195207
tty=False,
196208
_preload_content=False,
197209
)
198210

199-
tar_data = b""
211+
base64_data = ""
212+
stderr_data = ""
213+
200214
while resp.is_open():
201215
resp.update(timeout=1)
216+
202217
if resp.peek_stdout():
203-
tar_data += resp.read_stdout().encode("latin-1")
218+
base64_data += resp.read_stdout()
219+
220+
if resp.peek_stderr():
221+
stderr_data += resp.read_stderr()
204222

205223
resp.close()
206224

207-
if not tar_data:
225+
logger.info(f"Tar extraction complete. Received {len(base64_data)} base64 chars")
226+
if stderr_data:
227+
logger.warning(f"Tar extraction stderr: {stderr_data}")
228+
229+
if not base64_data:
230+
logger.warning("No tar data received from workspace snapshot")
208231
return tuple()
209232

233+
# Decode base64 to get the original tar binary data
234+
tar_data = base64.b64decode(base64_data)
235+
logger.info(f"Decoded to {len(tar_data)} bytes of tar data")
236+
210237
entries = []
238+
logger.info("Parsing tar archive")
211239
with tarfile.open(fileobj=io.BytesIO(tar_data), mode="r") as tar:
212-
for member in tar.getmembers():
240+
members = tar.getmembers()
241+
logger.info(f"Tar archive contains {len(members)} members")
242+
for member in members:
243+
logger.debug(
244+
f"Processing tar member: {member.name!r} (type={member.type!r}, "
245+
f"size={member.size})"
246+
)
213247
if member.name == ".":
214248
continue
215249

@@ -223,12 +257,17 @@ def _extract_workspace_snapshot(self, pod_name: str) -> tuple[WorkspaceEntry, ..
223257
file_obj = tar.extractfile(member)
224258
if file_obj:
225259
content = file_obj.read()
260+
logger.debug(f"Extracted file {clean_path}: {len(content)} bytes")
226261
entries.append(
227262
WorkspaceEntry(path=clean_path, kind="file", content=content)
228263
)
264+
else:
265+
logger.warning(f"Failed to extract file content for {clean_path}")
229266

267+
logger.info(f"Extracted {len(entries)} workspace entries")
230268
return tuple(entries)
231-
except Exception:
269+
except Exception as e:
270+
logger.error(f"Failed to extract workspace snapshot: {e}", exc_info=True)
232271
return tuple()
233272

234273
def _cleanup_pod(self, pod_name: str) -> None:

code-interpreter/tests/e2e/test_basic_flow.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,180 @@ def test_execute_edits_passed_file() -> None:
140140
assert returned_content == expected_content, (
141141
f"Content mismatch. Expected: {expected_content!r}, Got: {returned_content!r}"
142142
)
143+
144+
145+
def test_matplotlib_sine_wave_plot() -> None:
    """Run a matplotlib script through the service and check the returned PNG.

    The test checks ``/health``, POSTs a script that saves ``sine_wave.png``
    to ``/v1/execute``, asserts a clean run (stdout, stderr, exit code, no
    timeout), then downloads the reported file via ``/v1/files/{file_id}``
    and verifies the PNG signature and a minimum size.
    """
    # Script executed remotely; it must stay at column 0 because it is sent as-is.
    script = """
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

# Generate data
x = np.linspace(0, 10, 100)
y = np.sin(x)

# Create plot
plt.figure(figsize=(10, 6))
plt.plot(x, y)
plt.title('Sine Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True)

# Save plot
plt.savefig('sine_wave.png')
plt.close()
print("Plot saved successfully")
""".strip()

    execute_payload: dict[str, Any] = {
        "code": script,
        "stdin": None,
        "timeout_ms": 5000,
        "files": [],
    }

    client_timeout = httpx.Timeout(10.0, connect=5.0)

    with httpx.Client(base_url=BASE_URL, timeout=client_timeout) as client:
        # The service has to be reachable and healthy before anything else.
        try:
            health = client.get("/health")
        except httpx.TransportError as exc:  # pragma: no cover - network failure path
            pytest.fail(f"Failed to reach Code Interpreter service at {BASE_URL}: {exc!s}")

        assert health.status_code == 200, health.text

        try:
            execution = client.post("/v1/execute", json=execute_payload)
        except httpx.TransportError as exc:  # pragma: no cover - network failure path
            pytest.fail(f"Failed to reach Code Interpreter service at {BASE_URL}: {exc!s}")

        assert execution.status_code == 200, execution.text

        outcome = execution.json()

        # The script must have run cleanly to completion.
        assert outcome["stdout"] == "Plot saved successfully\n", f"stdout mismatch: {outcome}"
        assert outcome["stderr"] == "", f"stderr should be empty: {outcome}"
        assert outcome["exit_code"] == 0, f"exit_code should be 0: {outcome}"
        assert outcome["timed_out"] is False, f"should not timeout: {outcome}"

        # The generated image must be listed among the returned workspace files.
        files = outcome.get("files")
        assert isinstance(files, list), "files should be a list"

        png_file = next(
            (
                entry
                for entry in files
                if isinstance(entry, dict) and entry.get("path") == "sine_wave.png"
            ),
            None,
        )

        assert png_file is not None, f"sine_wave.png not found in response files: {files}"
        assert png_file["kind"] == "file"

        file_id = png_file.get("file_id")
        assert isinstance(file_id, str), "file_id should be present"

        # Fetch the stored bytes and make sure they really are a PNG image.
        download = client.get(f"/v1/files/{file_id}")
        assert download.status_code == 200, f"Failed to download file: {download.text}"
        png_bytes = download.content

        # Every PNG starts with this fixed 8-byte signature.
        assert png_bytes[:8] == b"\x89PNG\r\n\x1a\n", "File should be a valid PNG"

        # A rendered plot should be well above a trivially small size.
        assert len(png_bytes) > 1000, f"PNG file too small: {len(png_bytes)} bytes"
233+
234+
235+
def test_create_multiple_files() -> None:
    """Run a script that writes three text files and check each one is returned.

    The test checks ``/health``, POSTs the script to ``/v1/execute``, asserts
    a clean run, then verifies the response lists exactly ``file1.txt``,
    ``file2.txt`` and ``file3.txt`` and that each download from
    ``/v1/files/{file_id}`` has the expected UTF-8 content.
    """
    # Script executed remotely; it must stay at column 0 because it is sent as-is.
    script = """
# Create multiple files
with open('file1.txt', 'w') as f:
    f.write('Content of file 1')

with open('file2.txt', 'w') as f:
    f.write('Content of file 2')

with open('file3.txt', 'w') as f:
    f.write('Content of file 3')

print("Created 3 files")
""".strip()

    execute_payload: dict[str, Any] = {
        "code": script,
        "stdin": None,
        "timeout_ms": 2000,
        "files": [],
    }

    client_timeout = httpx.Timeout(5.0, connect=5.0)

    with httpx.Client(base_url=BASE_URL, timeout=client_timeout) as client:
        # The service has to be reachable and healthy before anything else.
        try:
            health = client.get("/health")
        except httpx.TransportError as exc:  # pragma: no cover - network failure path
            pytest.fail(f"Failed to reach Code Interpreter service at {BASE_URL}: {exc!s}")

        assert health.status_code == 200, health.text

        try:
            execution = client.post("/v1/execute", json=execute_payload)
        except httpx.TransportError as exc:  # pragma: no cover - network failure path
            pytest.fail(f"Failed to reach Code Interpreter service at {BASE_URL}: {exc!s}")

        assert execution.status_code == 200, execution.text

        outcome = execution.json()

        # The script must have run cleanly to completion.
        assert outcome["stdout"] == "Created 3 files\n", f"stdout mismatch: {outcome}"
        assert outcome["stderr"] == "", f"stderr should be empty: {outcome}"
        assert outcome["exit_code"] == 0, f"exit_code should be 0: {outcome}"
        assert outcome["timed_out"] is False, f"should not timeout: {outcome}"

        # Exactly the three files written by the script should come back.
        files = outcome.get("files")
        assert isinstance(files, list), "files should be a list"
        assert len(files) == 3, f"Expected 3 files, got {len(files)}: {files}"

        file_paths = {entry["path"] for entry in files}
        expected_paths = {"file1.txt", "file2.txt", "file3.txt"}
        assert file_paths == expected_paths, (
            f"File paths mismatch. Expected: {expected_paths}, Got: {file_paths}"
        )

        expected_contents = {
            "file1.txt": "Content of file 1",
            "file2.txt": "Content of file 2",
            "file3.txt": "Content of file 3",
        }

        # Download every returned entry and compare it with what the script wrote.
        for entry in files:
            path = entry["path"]
            assert entry["kind"] == "file", f"{path} should be a file"

            file_id = entry.get("file_id")
            assert isinstance(file_id, str), f"{path} should have a file_id"

            download = client.get(f"/v1/files/{file_id}")
            assert download.status_code == 200, f"Failed to download {path}: {download.text}"

            content = download.content.decode("utf-8")
            expected_content = expected_contents[path]
            assert content == expected_content, (
                f"Content mismatch for {path}. Expected: {expected_content!r}, Got: {content!r}"
            )

0 commit comments

Comments
 (0)