forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathiterators.py
More file actions
197 lines (162 loc) · 7.63 KB
/
iterators.py
File metadata and controls
197 lines (162 loc) · 7.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Implements iterators for deserializing data returned from an inference streaming endpoint."""
from __future__ import absolute_import
from abc import ABC, abstractmethod
import io
from sagemaker.core.exceptions import ModelStreamError, InternalStreamFailure
from sagemaker.core.common_utils import _MAX_BUFFER_SIZE
def handle_stream_errors(chunk):
"""Handle API Response errors within `invoke_endpoint_with_response_stream` API if any.
Args:
chunk (dict): A chunk of response received as part of `botocore.eventstream.EventStream`
response object.
Raises:
ModelStreamError: If `ModelStreamError` error is detected in a chunk of
`botocore.eventstream.EventStream` response object.
InternalStreamFailure: If `InternalStreamFailure` error is detected in a chunk of
`botocore.eventstream.EventStream` response object.
"""
if "ModelStreamError" in chunk:
raise ModelStreamError(
chunk["ModelStreamError"]["Message"], code=chunk["ModelStreamError"]["ErrorCode"]
)
if "InternalStreamFailure" in chunk:
raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"])
class BaseIterator(ABC):
"""Abstract base class for Inference Streaming iterators.
Provides a skeleton for customization requiring the overriding of iterator methods
__iter__ and __next__.
Tenets of iterator class for Streaming Inference API Response
(https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/
sagemaker-runtime/client/invoke_endpoint_with_response_stream.html):
1. Needs to accept an botocore.eventstream.EventStream response.
2. Needs to implement logic in __next__ to:
2.1. Concatenate and provide next chunk of response from botocore.eventstream.EventStream.
While doing so parse the response_chunk["PayloadPart"]["Bytes"].
2.2. If PayloadPart not in EventStream response, handle Errors
[Recommended to use `iterators.handle_stream_errors` method].
"""
def __init__(self, event_stream):
"""Initialises a Iterator object to help parse the byte event stream input.
Args:
event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
"""
self.event_stream = event_stream
@abstractmethod
def __iter__(self):
"""Abstract method, returns an iterator object itself"""
return self
@abstractmethod
def __next__(self):
"""Abstract method, is responsible for returning the next element in the iteration"""
class ByteIterator(BaseIterator):
"""A helper class for parsing the byte Event Stream input to provide Byte iteration."""
def __init__(self, event_stream):
"""Initialises a BytesIterator Iterator object
Args:
event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
"""
super().__init__(event_stream)
self.byte_iterator = iter(event_stream)
def __iter__(self):
"""Returns an iterator object itself, which allows the object to be iterated.
Returns:
iter : object
An iterator object representing the iterable.
"""
return self
def __next__(self):
"""Returns the next chunk of Byte directly."""
# Even with "while True" loop the function still behaves like a generator
# and sends the next new byte chunk.
while True:
chunk = next(self.byte_iterator)
if "PayloadPart" not in chunk:
# handle API response errors and force terminate.
handle_stream_errors(chunk)
# print and move on to next response byte
print("Unknown event type:" + chunk)
continue
return chunk["PayloadPart"]["Bytes"]
class LineIterator(BaseIterator):
"""A helper class for parsing the byte Event Stream input to provide Line iteration."""
def __init__(self, event_stream):
"""Initialises a LineIterator Iterator object
Args:
event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
"""
super().__init__(event_stream)
self.byte_iterator = iter(self.event_stream)
self.buffer = io.BytesIO()
self.read_pos = 0
def __iter__(self):
"""Returns an iterator object itself, which allows the object to be iterated.
Returns:
iter : object
An iterator object representing the iterable.
"""
return self
def __next__(self):
r"""Returns the next Line for an Line iterable.
The output of the event stream will be in the following format:
```
b'{"outputs": [" a"]}\n'
b'{"outputs": [" challenging"]}\n'
b'{"outputs": [" problem"]}\n'
...
```
While usually each PayloadPart event from the event stream will contain a byte array
with a full json, this is not guaranteed and some of the json objects may be split across
PayloadPart events. For example:
```
{'PayloadPart': {'Bytes': b'{"outputs": '}}
{'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
```
This class accounts for this by concatenating bytes written via the 'write' function
and then exposing a method which will return lines (ending with a '\n' character) within
the buffer via the 'scan_lines' function. It maintains the position of the last read
position to ensure that previous bytes are not exposed again.
Returns:
str: Read and return one line from the event stream.
"""
# Even with "while True" loop the function still behaves like a generator
# and sends the next new concatenated line
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
self.read_pos += len(line)
return line[:-1]
try:
chunk = next(self.byte_iterator)
except StopIteration:
if self.read_pos < self.buffer.getbuffer().nbytes:
continue
raise
if "PayloadPart" not in chunk:
# handle API response errors and force terminate.
handle_stream_errors(chunk)
# print and move on to next response byte
print("Unknown event type:" + chunk)
continue
# Check buffer size before writing to prevent unbounded memory consumption
chunk_size = len(chunk["PayloadPart"]["Bytes"])
current_size = self.buffer.getbuffer().nbytes
if current_size + chunk_size > _MAX_BUFFER_SIZE:
raise RuntimeError(
f"Line buffer exceeded maximum size of {_MAX_BUFFER_SIZE} bytes. "
f"No newline found in stream."
)
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])