|
1 | 1 | # This file was auto-generated by Fern from our API Definition. |
2 | 2 |
|
| 3 | +import codecs |
3 | 4 | import re |
4 | 5 | from contextlib import asynccontextmanager, contextmanager |
5 | | -from typing import Any, AsyncGenerator, AsyncIterator, Iterator, cast |
| 6 | +from typing import Any, AsyncGenerator, AsyncIterator, Iterator |
6 | 7 |
|
7 | 8 | import httpx |
8 | 9 | from ._decoders import SSEDecoder |
@@ -45,46 +46,81 @@ def _get_charset(self) -> str: |
45 | 46 | def response(self) -> httpx.Response: |
46 | 47 | return self._response |
47 | 48 |
|
| 49 | + @staticmethod |
| 50 | + def _normalize_sse_line_endings(buf: str) -> str: |
| 51 | + """Normalize line endings per the SSE spec (\\r\\n → \\n, bare \\r → \\n). |
| 52 | +
|
| 53 | + A trailing \\r is preserved because it may pair with a leading \\n in |
| 54 | + the next chunk to form a single \\r\\n terminator. |
| 55 | + """ |
| 56 | + buf = buf.replace("\r\n", "\n") |
| 57 | + if buf.endswith("\r"): |
| 58 | + return buf[:-1].replace("\r", "\n") + "\r" |
| 59 | + return buf.replace("\r", "\n") |
| 60 | + |
48 | 61 | def iter_sse(self) -> Iterator[ServerSentEvent]: |
49 | 62 | self._check_content_type() |
50 | 63 | decoder = SSEDecoder() |
51 | 64 | charset = self._get_charset() |
| 65 | + text_decoder = codecs.getincrementaldecoder(charset)(errors="replace") |
52 | 66 |
|
53 | | - buffer = "" |
| 67 | + buf = "" |
54 | 68 | for chunk in self._response.iter_bytes(): |
55 | | - # Decode chunk using detected charset |
56 | | - text_chunk = chunk.decode(charset, errors="replace") |
57 | | - buffer += text_chunk |
58 | | - |
59 | | - # Process complete lines |
60 | | - while "\n" in buffer: |
61 | | - line, buffer = buffer.split("\n", 1) |
62 | | - line = line.rstrip("\r") |
| 69 | + buf += text_decoder.decode(chunk) |
| 70 | + buf = self._normalize_sse_line_endings(buf) |
| 71 | + |
| 72 | + while "\n" in buf: |
| 73 | + line, buf = buf.split("\n", 1) |
63 | 74 | sse = decoder.decode(line) |
64 | | - # when we reach a "\n\n" => line = '' |
65 | | - # => decoder will attempt to return an SSE Event |
66 | 75 | if sse is not None: |
67 | 76 | yield sse |
68 | 77 |
|
69 | | - # Process any remaining data in buffer |
70 | | - if buffer.strip(): |
71 | | - line = buffer.rstrip("\r") |
| 78 | + # Flush any remaining bytes from the incremental decoder |
| 79 | + buf += text_decoder.decode(b"", final=True) |
| 80 | + buf = buf.replace("\r\n", "\n").replace("\r", "\n") |
| 81 | + |
| 82 | + while "\n" in buf: |
| 83 | + line, buf = buf.split("\n", 1) |
72 | 84 | sse = decoder.decode(line) |
73 | 85 | if sse is not None: |
74 | 86 | yield sse |
75 | 87 |
|
| 88 | + if buf.strip(): |
| 89 | + sse = decoder.decode(buf) |
| 90 | + if sse is not None: |
| 91 | + yield sse |
| 92 | + |
76 | 93 | async def aiter_sse(self) -> AsyncGenerator[ServerSentEvent, None]: |
77 | 94 | self._check_content_type() |
78 | 95 | decoder = SSEDecoder() |
79 | | - lines = cast(AsyncGenerator[str, None], self._response.aiter_lines()) |
80 | | - try: |
81 | | - async for line in lines: |
82 | | - line = line.rstrip("\n") |
| 96 | + charset = self._get_charset() |
| 97 | + text_decoder = codecs.getincrementaldecoder(charset)(errors="replace") |
| 98 | + |
| 99 | + buf = "" |
| 100 | + async for chunk in self._response.aiter_bytes(): |
| 101 | + buf += text_decoder.decode(chunk) |
| 102 | + buf = self._normalize_sse_line_endings(buf) |
| 103 | + |
| 104 | + while "\n" in buf: |
| 105 | + line, buf = buf.split("\n", 1) |
83 | 106 | sse = decoder.decode(line) |
84 | 107 | if sse is not None: |
85 | 108 | yield sse |
86 | | - finally: |
87 | | - await lines.aclose() |
| 109 | + |
| 110 | + # Flush any remaining bytes from the incremental decoder |
| 111 | + buf += text_decoder.decode(b"", final=True) |
| 112 | + buf = buf.replace("\r\n", "\n").replace("\r", "\n") |
| 113 | + |
| 114 | + while "\n" in buf: |
| 115 | + line, buf = buf.split("\n", 1) |
| 116 | + sse = decoder.decode(line) |
| 117 | + if sse is not None: |
| 118 | + yield sse |
| 119 | + |
| 120 | + if buf.strip(): |
| 121 | + sse = decoder.decode(buf) |
| 122 | + if sse is not None: |
| 123 | + yield sse |
88 | 124 |
|
89 | 125 |
|
90 | 126 | @contextmanager |
|
0 commit comments