Skip to content

Commit b26291d

Browse files
dkolaseob
andauthored
[SHIP-834] [SHIP-818] Streaming Generators (#548)
This PR includes: - Streaming Generator interface - Additional methods and status on `Block` for streaming Depends on: https://github.com/nludb/nludb/pull/631 (tests will fail until deploy of above) --------- Co-authored-by: Ted Benson <edward.benson@gmail.com>
1 parent d0e1e96 commit b26291d

17 files changed

Lines changed: 242 additions & 6 deletions

src/steamship/base/client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,13 +315,15 @@ def _headers( # noqa: C901
315315
return headers
316316

317317
@staticmethod
318-
def _prepare_data(payload: Union[Request, dict]):
318+
def _prepare_data(payload: Union[Request, dict, bytes]):
319319
if payload is None:
320320
data = {}
321321
elif isinstance(payload, dict):
322322
data = payload
323323
elif isinstance(payload, BaseModel):
324324
data = payload.dict(by_alias=True)
325+
elif isinstance(payload, bytes):
326+
data = payload
325327
else:
326328
raise RuntimeError(f"Unable to parse payload of type {type(payload)}")
327329

@@ -407,7 +409,7 @@ def call( # noqa: C901
407409
self,
408410
verb: Verb,
409411
operation: str,
410-
payload: Union[Request, dict] = None,
412+
payload: Union[Request, dict, bytes] = None,
411413
file: Any = None,
412414
expect: Type[T] = None,
413415
debug: bool = False,
@@ -464,7 +466,10 @@ def call( # noqa: C901
464466
files = self._prepare_multipart_data(data, file)
465467
resp = self._session.post(url, files=files, headers=headers, timeout=timeout_s)
466468
else:
467-
resp = self._session.post(url, json=data, headers=headers, timeout=timeout_s)
469+
if isinstance(data, bytes):
470+
resp = self._session.post(url, data=data, headers=headers, timeout=timeout_s)
471+
else:
472+
resp = self._session.post(url, json=data, headers=headers, timeout=timeout_s)
468473
elif verb == Verb.GET:
469474
resp = self._session.get(url, params=data, headers=headers, timeout=timeout_s)
470475
else:
@@ -555,7 +560,7 @@ def call( # noqa: C901
555560
def post(
556561
self,
557562
operation: str,
558-
payload: Union[Request, dict, BaseModel] = None,
563+
payload: Union[Request, dict, BaseModel, bytes] = None,
559564
file: Any = None,
560565
expect: Any = None,
561566
debug: bool = False,

src/steamship/cli/deploy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def _create_version(self, client: Steamship, manifest: Manifest, thing_id: str):
241241
handle=manifest.version,
242242
filename=f"build/archives/{manifest.handle}_v{manifest.version}.zip",
243243
plugin_id=thing_id,
244+
streaming=manifest.plugin.streaming,
244245
)
245246

246247
def create_object(self, client: Steamship, manifest: Manifest):

src/steamship/cli/manifest_init_wizard.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ def manifest_init_wizard(client: Steamship):
5858
trainable = click.confirm("Is the plugin trainable?", default=False)
5959
else:
6060
trainable = False
61-
plugin_detail = PluginConfig(isTrainable=trainable, type=plugin_type)
61+
if plugin_type == "generator":
62+
streaming = click.confirm("Will the plugin stream its results?", default=False)
63+
else:
64+
streaming = False
65+
plugin_detail = PluginConfig(isTrainable=trainable, type=plugin_type, streaming=streaming)
6266

6367
public = click.confirm(f"Do you want this {deployable_type} to be public?", default=True)
6468

@@ -71,7 +75,7 @@ def manifest_init_wizard(client: Steamship):
7175
if public:
7276
tagline = click.prompt(f"Want to give the {deployable_type} a tagline?", default="")
7377
author_github = click.prompt(
74-
"If you'd like this associated with your github account, please your github username",
78+
"If you'd like this associated with your github account, please input your github username",
7579
default="",
7680
)
7781

src/steamship/data/block.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ class BlockUploadType(str, Enum):
2727
NONE = "none" # No upload; plain text only.
2828

2929

30+
class StreamState(str, Enum):
31+
STARTED = "started" # A producer has begun streaming to this block.
32+
COMPLETE = "complete" # The producer has finished streaming to this block successfully.
33+
ABORTED = "aborted" # The producer finished streaming to the block, but there was an error.
34+
35+
3036
class Block(CamelModel):
3137
"""A Block is a chunk of content within a File. It can be plain text content, image content,
3238
video content, etc. If the content is not text, the text value may be the empty string
@@ -41,6 +47,7 @@ class Block(CamelModel):
4147
index_in_file: Optional[int] = Field(alias="index")
4248
mime_type: Optional[MimeTypes]
4349
public_data: bool = False
50+
stream_state: Optional[StreamState] = None
4451

4552
url: Optional[
4653
str
@@ -88,6 +95,7 @@ def create(
8895
url: Optional[str] = None,
8996
mime_type: Optional[MimeTypes] = None,
9097
public_data: bool = False,
98+
streaming: Optional[bool] = None,
9199
) -> Block:
92100
"""
93101
Create a new Block within a File specified by file_id.
@@ -117,6 +125,7 @@ def create(
117125
"mimeType": mime_type,
118126
"uploadType": upload_type,
119127
"publicData": public_data,
128+
"streaming": streaming,
120129
}
121130

122131
file_data = (
@@ -340,6 +349,24 @@ def as_llm_input(self, exclude_block_wrapper: Optional[bool] = False) -> str:
340349
return f"{identifier}"
341350
return f"Block({identifier})"
342351

352+
def finish_stream(self):
353+
self.client.post(
354+
f"block/{self.id}/finishStream",
355+
payload={},
356+
)
357+
358+
def append_stream(self, bytes: bytes):
359+
self.client.post(
360+
f"block/{self.id}/appendStream",
361+
payload=bytes,
362+
)
363+
364+
def abort_stream(self):
365+
self.client.post(
366+
f"block/{self.id}/abortStream",
367+
payload={},
368+
)
369+
343370

344371
class BlockQueryResponse(Response):
345372
blocks: List[Block]

src/steamship/data/file.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def generate(
282282
options: Optional[dict] = None,
283283
wait_on_tasks: List[Task] = None,
284284
make_output_public: bool = False,
285+
streaming: Optional[bool] = False,
285286
) -> Task[GenerateResponse]:
286287
"""Generate new content from this file. Assumes this file as context for input and output. May specify start and end blocks."""
287288
from steamship.data.operations.generator import GenerateRequest, GenerateResponse
@@ -301,6 +302,7 @@ def generate(
301302
output_file_id=output_file_id,
302303
options=options,
303304
make_output_public=make_output_public,
305+
streaming=streaming,
304306
)
305307
return self.client.post(
306308
"plugin/instance/generate", req, expect=GenerateResponse, wait_on_tasks=wait_on_tasks

src/steamship/data/manifest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class PluginConfig(BaseModel):
5959
isTrainable: Optional[bool] = False # noqa: N815
6060
transport: str = "jsonOverHttp"
6161
type: str # Does not use PluginType due to circular import
62+
streaming: bool = False
6263

6364

6465
class Manifest(BaseModel):

src/steamship/data/operations/generator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ class GenerateRequest(Request):
6060
# Arbitrary runtime options which may be passed to a generator
6161
options: Optional[dict]
6262

63+
# Whether we wish to have the output blocks streamed back to us.
64+
# If the blocks are streamed, they will be returned with a streamState=started,
65+
# and the content can be fetched in a streaming manner by
66+
# fetching Block.raw()
67+
# Default behavior if not provided is streaming=false
68+
streaming: Optional[bool] = None
69+
6370

6471
class GenerateResponse(Response):
6572
blocks: List[Block]

src/steamship/data/plugin/plugin_instance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def generate(
130130
output_file_id: Optional[str] = None,
131131
make_output_public: Optional[bool] = None,
132132
options: Optional[dict] = None,
133+
streaming: Optional[bool] = None,
133134
) -> Task[GenerateResponse]:
134135
"""See GenerateRequest for description of parameter options"""
135136
req = GenerateRequest(
@@ -146,6 +147,7 @@ def generate(
146147
output_file_id=output_file_id,
147148
make_output_public=make_output_public,
148149
options=options,
150+
streaming=streaming,
149151
)
150152
return self.client.post("plugin/instance/generate", req, expect=GenerateResponse)
151153

src/steamship/data/plugin/plugin_version.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class CreatePluginVersionRequest(Request):
2323
type: str = "file"
2424
# Note: this is a Dict[str, Any] but should be transmitted to the Engine as a JSON string
2525
config_template: str = None
26+
streaming: Optional[bool] = None
2627

2728

2829
class ListPluginVersionsRequest(Request):
@@ -65,6 +66,7 @@ def create(
6566
is_public: Optional[bool] = None,
6667
is_default: Optional[bool] = None,
6768
config_template: Optional[Dict[str, Any]] = None,
69+
streaming: Optional[bool] = None,
6870
) -> PluginVersion:
6971

7072
if filename is None and filebytes is None:
@@ -85,6 +87,7 @@ def create(
8587
is_public=is_public,
8688
is_default=is_default,
8789
config_template=json.dumps(config_template or {}),
90+
streaming=streaming,
8891
)
8992

9093
task = client.post(
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from typing import List
2+
3+
from steamship import Block
4+
from steamship.plugin.inputs.raw_block_and_tag_plugin_input import RawBlockAndTagPluginInput
5+
6+
7+
class RawBlockAndTagPluginInputWithPreallocatedBlocks(RawBlockAndTagPluginInput):
8+
output_blocks: List[Block]

0 commit comments

Comments
 (0)