|
17 | 17 |
|
18 | 18 | import logging |
19 | 19 | import grpc |
20 | | -import time |
21 | | -import uuid |
22 | 20 |
|
23 | | -from typing import Any, AsyncIterator, Optional, Tuple, cast |
| 21 | +from typing import Optional |
24 | 22 | from jetstream.core import orchestrator |
25 | | -from jetstream.core.lora import adapter_tensorstore |
26 | 23 | from jetstream.core.proto import multi_lora_decoding_pb2_grpc |
27 | 24 | from jetstream.core.proto import multi_lora_decoding_pb2 |
28 | | -from jetstream.core.utils import async_multifuture |
29 | | -from jetstream.core.utils.return_sample import ReturnSample |
30 | | -from jetstream.engine import engine_api, tokenizer_api, token_utils |
31 | 25 |
|
32 | 26 |
|
33 | 27 | class MultiLoraManager(multi_lora_decoding_pb2_grpc.v1Servicer): |
@@ -105,155 +99,3 @@ def unload_lora_adapter( |
105 | 99 | logging.info(f"Loading of adapter_id={request.adapter_id} failed with error: {str(e)}") |
106 | 100 | return multi_lora_decoding_pb2.UnloadAdapterResponse(success=False, error_message=str(e)) |
107 | 101 |
|
108 | | - |
109 | | - def _get_prefill_content( |
110 | | - self, request: multi_lora_decoding_pb2.CompletionRequest |
111 | | - ) -> Tuple[str | list[int], bool]: |
112 | | - which_content = request.WhichOneof("content") |
113 | | - content = getattr(request, which_content) |
114 | | - if which_content == "text_content": |
115 | | - return cast(multi_lora_decoding_pb2.CompletionRequest.TextContent, content).text, False |
116 | | - else: |
117 | | - return ( |
118 | | - list( |
119 | | - cast(multi_lora_decoding_pb2.CompletionRequest.TokenContent, content).token_ids |
120 | | - ), |
121 | | - True, |
122 | | - ) |
123 | | - |
124 | | - def process_client_side_tokenization_response(self, response: Any): |
125 | | - samples = [] |
126 | | - for sample in response: |
127 | | - samples.append( |
128 | | - multi_lora_decoding_pb2.CompletionResponse.StreamContent.Sample( |
129 | | - token_ids=sample.token_ids, |
130 | | - ) |
131 | | - ) |
132 | | - return multi_lora_decoding_pb2.CompletionResponse( |
133 | | - stream_content=multi_lora_decoding_pb2.CompletionResponse.StreamContent( |
134 | | - samples=samples |
135 | | - ) |
136 | | - ) |
137 | | - |
138 | | - def should_buffer_response(self, response: Any) -> bool: |
139 | | - for item in response: |
140 | | - if item.text and token_utils.is_byte_token(item.text[-1]): |
141 | | - # If any sample ends in bytes, this means we might still need to |
142 | | - # decode more bytes to compose the string. |
143 | | - return True |
144 | | - |
145 | | - def process_server_side_tokenization_response( |
146 | | - self, response: Any, buffered_response_list |
147 | | - ): |
148 | | - # Flush the buffered responses to each sample of current response. |
149 | | - current_response_with_flushed_buffer = list( |
150 | | - zip(*buffered_response_list, response) |
151 | | - ) |
152 | | - # Empty buffer: [[s0_cur], [s1_cur], ...] |
153 | | - # Has buffer: |
154 | | - # [[s0_b0, s0_b1, ..., s0_cur], [s1_b0, s1_b1, ..., s1_cur], ...] |
155 | | - current_response_with_flushed_buffer = cast( |
156 | | - list[list[ReturnSample]], current_response_with_flushed_buffer |
157 | | - ) |
158 | | - # Form correct sample(s) and return as StreamContent for this iteration. |
159 | | - samples = [] |
160 | | - for sample in current_response_with_flushed_buffer: |
161 | | - text = [] |
162 | | - token_ids = [] |
163 | | - for resp in sample: |
164 | | - text.extend(resp.text) |
165 | | - token_ids.extend(resp.token_ids) |
166 | | - samples.append( |
167 | | - multi_lora_decoding_pb2.CompletionResponse.StreamContent.Sample( |
168 | | - text=token_utils.text_tokens_to_str(text), |
169 | | - token_ids=token_ids, |
170 | | - ) |
171 | | - ) |
172 | | - return multi_lora_decoding_pb2.CompletionResponse( |
173 | | - stream_content=multi_lora_decoding_pb2.CompletionResponse.StreamContent( |
174 | | - samples=samples |
175 | | - ) |
176 | | - ) |
177 | | - |
178 | | - async def completions( # pylint: disable=invalid-overridden-method |
179 | | - self, |
180 | | - request: multi_lora_decoding_pb2.CompletionRequest, |
181 | | - context: Optional[grpc.aio.ServicerContext] = None, |
182 | | - ) -> AsyncIterator[multi_lora_decoding_pb2.CompletionResponse]: |
183 | | - |
184 | | - """Decode.""" |
185 | | - if context is None: |
186 | | - logging.warning( |
187 | | - "LLM orchestrator is being used in offline test mode, and will not" |
188 | | - " respond to gRPC queries - only direct function calls." |
189 | | - ) |
190 | | - is_client_side_tokenization = False |
191 | | - return_channel = async_multifuture.AsyncMultifuture() |
192 | | - if context: |
193 | | - context.add_done_callback(return_channel.cancel) |
194 | | - |
195 | | - prefill_content, is_client_side_tokenization = self._get_prefill_content( |
196 | | - request |
197 | | - ) |
198 | | - |
199 | | - # Wrap request as an ActiveRequest. |
200 | | - active_request = orchestrator.ActiveRequest( |
201 | | - request_id=uuid.uuid4(), |
202 | | - max_tokens=request.max_tokens, |
203 | | - prefill_content=prefill_content, |
204 | | - is_client_side_tokenization=is_client_side_tokenization, |
205 | | - return_channel=return_channel, |
206 | | - adapter_id=request.adapter_id, |
207 | | - metadata=orchestrator.ActiveRequestMetadata( |
208 | | - start_time=request.metadata.start_time, |
209 | | - prefill_enqueue_time=time.perf_counter(), |
210 | | - ), |
211 | | - ) |
212 | | - # The first stage is being prefilled, all other stages are handled |
213 | | - # inside the driver (transfer, generate*N, detokenize). |
214 | | - try: |
215 | | - self._driver.place_request_on_prefill_queue(active_request) |
216 | | - except queue.Full: |
217 | | - # Safely abort the gRPC server thread with a retriable error. |
218 | | - await _abort_or_raise( |
219 | | - context=context, |
220 | | - code=grpc.StatusCode.RESOURCE_EXHAUSTED, |
221 | | - details=( |
222 | | - "The driver prefill queue is full and more requests cannot be" |
223 | | - " handled. You may retry this request." |
224 | | - ), |
225 | | - ) |
226 | | - logging.info( |
227 | | - "Placed request on the prefill queue.", |
228 | | - ) |
229 | | - # When an active request is created a queue is instantiated. New tokens |
230 | | - # are placed there during the decoding loop, we pop from that queue by |
231 | | - # using the .next method on the active request. |
232 | | - # Yielding allows for the response to be a streaming grpc call - which |
233 | | - # can be called via iterating over a for loop on the client side. |
234 | | - # The DecodeResponse stream should consume all generated tokens in |
235 | | - # return_channel when complete signal is received (AsyncMultifuture |
236 | | - # promises this). |
237 | | - buffered_response_list = [] |
238 | | - async for response in active_request.return_channel: |
239 | | - response = cast(list[ReturnSample], response) |
240 | | - if is_client_side_tokenization: |
241 | | - # If is_client_side_tokenization, the client should request with token |
242 | | - # ids, and the JetStream server will return token ids as response. |
243 | | - # The client should take care of tokenization and detokenization. |
244 | | - yield self.process_client_side_tokenization_response(response) |
245 | | - else: |
246 | | - # Buffer response mechanism is used to handle streaming |
247 | | - # detokenization with special character (For some edge cases with |
248 | | - # SentencePiece tokenizer, it requires to decode a complete sequence |
249 | | - # instead of a single token). |
250 | | - if self.should_buffer_response(response): |
251 | | - buffered_response_list.append(response) |
252 | | - continue |
253 | | - yield self.process_server_side_tokenization_response( |
254 | | - response, buffered_response_list |
255 | | - ) |
256 | | - # Reset buffer after flushed. |
257 | | - buffered_response_list = [] |
258 | | - |
259 | | - |
0 commit comments