Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/friendly-fireants-begin.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@e2b/code-interpreter-python': patch
---

Await async callbacks
38 changes: 19 additions & 19 deletions python/e2b_code_interpreter/code_interpreter_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
Context,
Result,
aextract_exception,
parse_output,
OutputHandler,
OutputHandlerWithAsync,
async_parse_output,
OutputMessage,
)
from e2b_code_interpreter.exceptions import (
Expand Down Expand Up @@ -69,10 +69,10 @@ async def run_code(
self,
code: str,
language: Union[Literal["python"], None] = None,
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
on_result: Optional[OutputHandler[Result]] = None,
on_error: Optional[OutputHandler[ExecutionError]] = None,
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
envs: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
request_timeout: Optional[float] = None,
Expand Down Expand Up @@ -103,10 +103,10 @@ async def run_code(
self,
code: str,
language: Optional[str] = None,
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
on_result: Optional[OutputHandler[Result]] = None,
on_error: Optional[OutputHandler[ExecutionError]] = None,
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
envs: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
request_timeout: Optional[float] = None,
Expand Down Expand Up @@ -138,10 +138,10 @@ async def run_code(
self,
code: str,
context: Optional[Context] = None,
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
on_result: Optional[OutputHandler[Result]] = None,
on_error: Optional[OutputHandler[ExecutionError]] = None,
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
envs: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
request_timeout: Optional[float] = None,
Expand Down Expand Up @@ -172,10 +172,10 @@ async def run_code(
code: str,
language: Optional[str] = None,
context: Optional[Context] = None,
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
on_result: Optional[OutputHandler[Result]] = None,
on_error: Optional[OutputHandler[ExecutionError]] = None,
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
envs: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
request_timeout: Optional[float] = None,
Expand Down Expand Up @@ -215,7 +215,7 @@ async def run_code(
execution = Execution()

async for line in response.aiter_lines():
parse_output(
await async_parse_output(
execution,
line,
on_stdout=on_stdout,
Expand Down
47 changes: 45 additions & 2 deletions python/e2b_code_interpreter/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import json
import logging

Expand All @@ -20,8 +21,10 @@
from .charts import Chart, _deserialize_chart

T = TypeVar("T")
OutputHandler = Union[
Callable[[T], Any],
OutputHandler = Union[Callable[[T], Any],]

OutputHandlerWithAsync = Union[
OutputHandler[T],
Callable[[T], Awaitable[Any]],
]

Expand Down Expand Up @@ -446,6 +449,46 @@ def parse_output(
execution.execution_count = data["execution_count"]


async def async_parse_output(
Comment thread
mishushakov marked this conversation as resolved.
execution: Execution,
output: str,
on_stdout: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_stderr: Optional[OutputHandlerWithAsync[OutputMessage]] = None,
on_result: Optional[OutputHandlerWithAsync[Result]] = None,
on_error: Optional[OutputHandlerWithAsync[ExecutionError]] = None,
):
data = json.loads(output)
data_type = data.pop("type")

if data_type == "result":
result = Result(**data)
execution.results.append(result)
if on_result:
cb = on_result(result)
if inspect.isawaitable(cb):
await cb
elif data_type == "stdout":
execution.logs.stdout.append(data["text"])
if on_stdout:
cb = on_stdout(OutputMessage(data["text"], data["timestamp"], False))
if inspect.isawaitable(cb):
await cb
elif data_type == "stderr":
execution.logs.stderr.append(data["text"])
if on_stderr:
cb = on_stderr(OutputMessage(data["text"], data["timestamp"], True))
if inspect.isawaitable(cb):
await cb
elif data_type == "error":
execution.error = ExecutionError(data["name"], data["value"], data["traceback"])
if on_error:
cb = on_error(execution.error)
if inspect.isawaitable(cb):
await cb
elif data_type == "number_of_executions":
execution.execution_count = data["execution_count"]


@dataclass
class Context:
"""
Expand Down
28 changes: 22 additions & 6 deletions python/tests/async/test_async_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,26 @@
from e2b_code_interpreter.code_interpreter_async import AsyncSandbox


async def test_resuls(async_sandbox: AsyncSandbox):
def async_append_fn(items):
async def async_append(item):
items.append(item)

return async_append


async def test_results(async_sandbox: AsyncSandbox):
results = []

execution = await async_sandbox.run_code(
"x = 1;x", on_result=async_append_fn(results)
)
assert len(results) == 1
assert execution.results[0].text == "1"


async def test_results_sync_callback(async_sandbox: AsyncSandbox):
results = []

execution = await async_sandbox.run_code(
"x = 1;x", on_result=lambda result: results.append(result)
)
Expand All @@ -12,17 +30,15 @@ async def test_resuls(async_sandbox: AsyncSandbox):

async def test_error(async_sandbox: AsyncSandbox):
errors = []
execution = await async_sandbox.run_code(
"xyz", on_error=lambda error: errors.append(error)
)
execution = await async_sandbox.run_code("xyz", on_error=async_append_fn(errors))
assert len(errors) == 1
assert execution.error.name == "NameError"


async def test_stdout(async_sandbox: AsyncSandbox):
stdout = []
execution = await async_sandbox.run_code(
"print('Hello from e2b')", on_stdout=lambda out: stdout.append(out)
"print('Hello from e2b')", on_stdout=async_append_fn(stdout)
)
assert len(stdout) == 1
assert execution.logs.stdout == ["Hello from e2b\n"]
Expand All @@ -32,7 +48,7 @@ async def test_stderr(async_sandbox: AsyncSandbox):
stderr = []
execution = await async_sandbox.run_code(
'import sys;print("This is an error message", file=sys.stderr)',
on_stderr=lambda err: stderr.append(err),
on_stderr=async_append_fn(stderr),
)
assert len(stderr) == 1
assert execution.logs.stderr == ["This is an error message\n"]