Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
import textwrap
import types
import typing

# types.UnionType was added in Python 3.10 to represent PEP 604 `X | Y`
# syntax. Burr supports Python >= 3.9, so on 3.9 we fall back to a sentinel
# type that keeps Union annotations well-formed. No PEP 604 union can ever
# exist on 3.9, so isinstance() checks against it simply never match.
if sys.version_info >= (3, 10):
_UnionType = types.UnionType
else: # pragma: no cover - exercised on Python 3.9 CI only
class _UnionType: # type: ignore[no-redef]
"""Placeholder for ``types.UnionType`` on Python < 3.10."""
from collections.abc import AsyncIterator
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -1511,7 +1521,7 @@ def pydantic(
writes: List[str],
state_input_type: Type["BaseModel"],
state_output_type: Type["BaseModel"],
stream_type: Union[Type["BaseModel"], Type[dict]],
stream_type: Union[Type["BaseModel"], Type[dict], _UnionType],
tags: Optional[List[str]] = None,
) -> Callable:
"""Creates a streaming action that uses pydantic models.
Expand Down
9 changes: 6 additions & 3 deletions burr/integrations/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from burr.core.action import (
FunctionBasedAction,
FunctionBasedStreamingAction,
_UnionType,
bind,
derive_inputs_from_fn,
)
Expand Down Expand Up @@ -269,7 +270,7 @@ async def async_action_function(state: State, **kwargs) -> State:
return decorator


PartialType = Union[Type[pydantic.BaseModel], Type[dict]]
PartialType = Union[Type[pydantic.BaseModel], Type[dict], _UnionType]

PydanticStreamingActionFunctionSync = Callable[
..., Generator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None, None]
Expand All @@ -290,11 +291,13 @@ async def async_action_function(state: State, **kwargs) -> State:

def _validate_and_extract_signature_types_streaming(
fn: PydanticStreamingActionFunction,
stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict]]],
stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict], _UnionType]],
state_input_type: Optional[Type[pydantic.BaseModel]] = None,
state_output_type: Optional[Type[pydantic.BaseModel]] = None,
) -> Tuple[
Type[pydantic.BaseModel], Type[pydantic.BaseModel], Union[Type[dict], Type[pydantic.BaseModel]]
Type[pydantic.BaseModel],
Type[pydantic.BaseModel],
Union[Type[dict], Type[pydantic.BaseModel], _UnionType],
]:
if stream_type is None:
# TODO -- derive from the signature
Expand Down
47 changes: 47 additions & 0 deletions tests/integrations/test_burr_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import asyncio
import sys
import warnings
from typing import AsyncGenerator, Generator, List, Optional, Tuple

Expand Down Expand Up @@ -449,6 +450,52 @@ def act(
assert final_state.data.times_called == 1


@pytest.mark.skipif(
sys.version_info < (3, 10), reason="PEP 604 union syntax requires Python 3.10+"
)
def test_streaming_pydantic_action_union_stream_type():
class IntermediateUnionModelOne(BaseModel):
result: int

class IntermediateUnionModelTwo(BaseModel):
message: str

@pydantic_streaming_action(
reads=["count", "times_called"],
writes=["count", "times_called"],
stream_type=IntermediateUnionModelOne | IntermediateUnionModelTwo,
state_input_type=AppStateModel,
state_output_type=AppStateModel,
)
def act(
state: AppStateModel, total_count: int
) -> Generator[Tuple[IntermediateUnionModelOne, Optional[AppStateModel]], None, None]:
initial_value = state.count
for i in range(initial_value, initial_value + total_count):
yield IntermediateUnionModelOne(result=i), None
state.count = i
state.times_called += 1
yield IntermediateUnionModelOne(result=state.count), state

assert hasattr(act, "bind") # has to have bind
assert (action_function := getattr(act, FunctionBasedAction.ACTION_FUNCTION, None)) is not None
assert action_function.inputs == (["total_count"], [])
gen = action_function.fn(
State(dict(count=1, times_called=0), typing_system=PydanticTypingSystem(AppStateModel)),
total_count=5,
)
result = list(gen)
assert len(result) == 6
assert [item[0].result for item in result] == [1, 2, 3, 4, 5, 5]
assert all([isinstance(item[0], IntermediateUnionModelOne) for item in result])
assert all([item[1] is None for item in result[:-1]])
assert isinstance(final_state := result[-1][1], State)
assert final_state["count"] == 5
assert final_state["times_called"] == 1
assert final_state.data.count == 5
assert final_state.data.times_called == 1


async def test_streaming_pydantic_action_same_io_async():
@pydantic_streaming_action(
reads=["count", "times_called"],
Expand Down
Loading