|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import asyncio |
| 16 | +import gzip |
16 | 17 | from typing import Any, cast |
17 | 18 |
|
18 | 19 | import httpx |
|
29 | 30 |
|
30 | 31 | class _FakeStreamingResponse: |
31 | 32 | def __init__( |
32 | | - self, status_code: int = 200, headers: dict | None = None, chunks: list[bytes] | None = None |
| 33 | + self, |
| 34 | + status_code: int = 200, |
| 35 | + headers: dict | None = None, |
| 36 | + chunks: list[bytes] | None = None, |
| 37 | + raw_chunks: list[bytes] | None = None, |
33 | 38 | ): |
34 | 39 | self.status_code = status_code |
35 | 40 | self.headers = httpx.Headers(headers or {}) |
36 | 41 | self._chunks = chunks or [] |
| 42 | + self._raw_chunks = raw_chunks if raw_chunks is not None else self._chunks |
37 | 43 | self.aclose_called = False |
| 44 | + self.aiter_bytes_called = False |
| 45 | + self.aiter_raw_called = False |
38 | 46 |
|
39 | 47 | async def aiter_bytes(self): |
| 48 | + self.aiter_bytes_called = True |
40 | 49 | for chunk in self._chunks: |
41 | 50 | yield chunk |
42 | 51 |
|
| 52 | + async def aiter_raw(self): |
| 53 | + self.aiter_raw_called = True |
| 54 | + for chunk in self._raw_chunks: |
| 55 | + yield chunk |
| 56 | + |
43 | 57 | async def aclose(self): |
44 | 58 | self.aclose_called = True |
45 | 59 |
|
@@ -425,6 +439,46 @@ def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> |
425 | 439 | assert response.headers.get("x-hop-temp") is None |
426 | 440 |
|
427 | 441 |
|
| 442 | +def test_proxy_preserves_compressed_response_body( |
| 443 | + client: TestClient, |
| 444 | + auth_headers: dict, |
| 445 | + monkeypatch, |
| 446 | +) -> None: |
| 447 | + class StubService: |
| 448 | + @staticmethod |
| 449 | + def get_endpoint(sandbox_id: str, port: int, resolve_internal: bool = False) -> Endpoint: |
| 450 | + assert resolve_internal is True |
| 451 | + return Endpoint(endpoint="10.57.1.91:40109") |
| 452 | + |
| 453 | + monkeypatch.setattr(lifecycle, "sandbox_service", StubService()) |
| 454 | + |
| 455 | + decoded_body = b"<html>vnc</html>" |
| 456 | + encoded_body = gzip.compress(decoded_body) |
| 457 | + fake_client = _FakeAsyncClient() |
| 458 | + fake_client.response = _FakeStreamingResponse( |
| 459 | + status_code=200, |
| 460 | + headers={ |
| 461 | + "content-type": "text/html", |
| 462 | + "content-encoding": "gzip", |
| 463 | + }, |
| 464 | + chunks=[decoded_body], |
| 465 | + raw_chunks=[encoded_body], |
| 466 | + ) |
| 467 | + _set_http_client(client, fake_client) |
| 468 | + |
| 469 | + response = client.get( |
| 470 | + "/v1/sandboxes/sbx-123/proxy/8080/vnc/index.html", |
| 471 | + headers={**auth_headers, "Accept-Encoding": "gzip"}, |
| 472 | + ) |
| 473 | + |
| 474 | + assert response.status_code == 200 |
| 475 | + assert response.headers.get("content-encoding") == "gzip" |
| 476 | + assert response.content == decoded_body |
| 477 | + assert fake_client.response.aiter_raw_called is True |
| 478 | + assert fake_client.response.aiter_bytes_called is False |
| 479 | + assert fake_client.response.aclose_called is True |
| 480 | + |
| 481 | + |
428 | 482 | def test_proxy_rejects_websocket_upgrade( |
429 | 483 | client: TestClient, |
430 | 484 | auth_headers: dict, |
|
0 commit comments