|
3 | 3 |
|
4 | 4 | import asyncio |
5 | 5 | import base64 |
| 6 | +import json |
6 | 7 | import os |
7 | 8 | import subprocess |
8 | 9 | import uuid |
@@ -142,11 +143,49 @@ def read_text_from_file(file, save_file_name): |
142 | 143 | return file_content |
143 | 144 |
|
144 | 145 |
|
def align_generator(self, gen, **kwargs):
    """Re-emit a streamed response as ``data: ...`` event chunks.

    Each item of ``gen`` is decoded as UTF-8 and the span from the first
    ``{`` to the last ``}`` is parsed as JSON.  Depending on the payload
    shape, the text it carries is yielded as
    ``"data: <repr of the UTF-8 bytes>\\n\\n"``:

    * ``{"ops": [{"op": ..., "value": <str>}, ...]}`` -- yields ``value`` of
      the first op; nothing is yielded when ``value`` is missing or is not a
      string.
    * ``{"choices": [{"delta": {"content": ...}, "finish_reason": ...}]}``
      -- yields ``content`` of the first choice unless ``finish_reason`` is
      ``"eos_token"`` or the delta has no ``content``.

    Anything that cannot be parsed or does not match these shapes falls back
    to yielding the extracted JSON text itself (an empty payload when the
    line contains no ``{...}`` span).  A final ``"data: [DONE]\\n\\n"`` is
    always yielded once ``gen`` is exhausted.

    Args:
        self: Unused here; present because the function is installed as a
            method (``ServiceOrchestrator.align_generator``).
        gen: Iterable of ``bytes`` lines from the upstream stream.
        **kwargs: Accepted for interface compatibility; unused.

    Yields:
        str: One ``data: ...\\n\\n`` chunk per emitted piece of text, then the
        ``[DONE]`` terminator.
    """
    # Example of an OpenAI-format line:
    # b'data:{"id":"","object":"text_completion","created":1725530204,"model":"meta-llama/Meta-Llama-3-8B-Instruct","system_fingerprint":"2.0.1-native","choices":[{"index":0,"delta":{"role":"assistant","content":"?"},"logprobs":null,"finish_reason":null}]}\n\n'
    for raw_line in gen:
        line = raw_line.decode("utf-8")

        # Isolate the outermost "{...}" span.  str.find/rfind never raise, so
        # no exception handling is needed; a missing or mis-ordered brace pair
        # simply leaves nothing to parse.
        start = line.find("{")
        end = line.rfind("}") + 1
        json_str = line[start:end] if start != -1 and end > start else ""

        try:
            # Empty or malformed chunks do occur; json.loads raising on them
            # is what routes them to the fallback below.
            json_data = json.loads(json_str)
            if "ops" in json_data and "op" in json_data["ops"][0]:
                first_op = json_data["ops"][0]
                if "value" in first_op and isinstance(first_op["value"], str):
                    yield f"data: {repr(first_op['value'].encode('utf-8'))}\n\n"
            else:
                choice = json_data["choices"][0]
                if choice["finish_reason"] != "eos_token" and "content" in choice["delta"]:
                    yield f"data: {repr(choice['delta']['content'].encode('utf-8'))}\n\n"
        except Exception:
            # Deliberate best-effort: whatever goes wrong (bad JSON, missing
            # keys, unexpected types), forward the raw text instead of
            # breaking the stream.
            yield f"data: {repr(json_str.encode('utf-8'))}\n\n"
    yield "data: [DONE]\n\n"
| 181 | + |
| 182 | + |
145 | 183 | class DocSumService: |
146 | 184 | def __init__(self, host="0.0.0.0", port=8000): |
147 | 185 | self.host = host |
148 | 186 | self.port = port |
149 | 187 | ServiceOrchestrator.align_inputs = align_inputs |
| 188 | + ServiceOrchestrator.align_generator = align_generator |
150 | 189 | self.megaservice = ServiceOrchestrator() |
151 | 190 | self.megaservice_text_only = ServiceOrchestrator() |
152 | 191 | self.endpoint = str(MegaServiceEndpoint.DOC_SUMMARY) |
|
0 commit comments