|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +import json |
| 19 | + |
| 20 | +import pytest |
| 21 | +from requests import HTTPError, Response |
| 22 | + |
| 23 | +from pyiceberg.catalog.rest.response import _handle_non_200_response |
| 24 | +from pyiceberg.exceptions import ( |
| 25 | + AuthorizationExpiredError, |
| 26 | + BadRequestError, |
| 27 | + ForbiddenError, |
| 28 | + NoSuchTableError, |
| 29 | + OAuthError, |
| 30 | + RESTError, |
| 31 | + ServerError, |
| 32 | + ServiceUnavailableError, |
| 33 | + TooManyRequestsError, |
| 34 | + UnauthorizedError, |
| 35 | +) |
| 36 | + |
| 37 | + |
| 38 | +def _make_http_error(status_code: int, body: str = "", reason: str | None = None) -> HTTPError: |
| 39 | + response = Response() |
| 40 | + response.status_code = status_code |
| 41 | + response._content = body.encode("utf-8") if body else b"" |
| 42 | + if reason is not None: |
| 43 | + response.reason = reason |
| 44 | + return HTTPError(response=response) |
| 45 | + |
| 46 | + |
| 47 | +def _error_body(message: str, error_type: str, code: int) -> str: |
| 48 | + return json.dumps({"error": {"message": message, "type": error_type, "code": code}}) |
| 49 | + |
| 50 | + |
| 51 | +@pytest.mark.parametrize( |
| 52 | + "status_code, expected_exception", |
| 53 | + [ |
| 54 | + (400, BadRequestError), |
| 55 | + (401, UnauthorizedError), |
| 56 | + (403, ForbiddenError), |
| 57 | + (419, AuthorizationExpiredError), |
| 58 | + (422, RESTError), |
| 59 | + (429, TooManyRequestsError), |
| 60 | + (501, NotImplementedError), |
| 61 | + (503, ServiceUnavailableError), |
| 62 | + (500, ServerError), |
| 63 | + (502, ServerError), |
| 64 | + (504, ServerError), |
| 65 | + (999, RESTError), |
| 66 | + ], |
| 67 | +) |
| 68 | +def test_status_code_maps_to_exception(status_code: int, expected_exception: type[Exception]) -> None: |
| 69 | + body = _error_body("something went wrong", "SomeError", status_code) |
| 70 | + exc = _make_http_error(status_code, body=body) |
| 71 | + |
| 72 | + with pytest.raises(expected_exception, match="SomeError: something went wrong"): |
| 73 | + _handle_non_200_response(exc, {}) |
| 74 | + |
| 75 | + |
| 76 | +def test_error_handler_overrides_default_mapping() -> None: |
| 77 | + body = _error_body("Table does not exist: ns.tbl", "NoSuchTableException", 404) |
| 78 | + exc = _make_http_error(404, body=body) |
| 79 | + |
| 80 | + with pytest.raises(NoSuchTableError, match="NoSuchTableException: Table does not exist: ns.tbl"): |
| 81 | + _handle_non_200_response(exc, {404: NoSuchTableError}) |
| 82 | + |
| 83 | + |
| 84 | +@pytest.mark.parametrize( |
| 85 | + "status_code, body, expected_exception", |
| 86 | + [ |
| 87 | + (500, "not json at all", ServerError), |
| 88 | + (400, '{"unexpected": "structure"}', BadRequestError), |
| 89 | + ], |
| 90 | +) |
| 91 | +def test_unparseable_body_falls_back_to_validation_error( |
| 92 | + status_code: int, body: str, expected_exception: type[Exception] |
| 93 | +) -> None: |
| 94 | + exc = _make_http_error(status_code, body=body) |
| 95 | + |
| 96 | + with pytest.raises(expected_exception, match="Received unexpected JSON Payload"): |
| 97 | + _handle_non_200_response(exc, {}) |
| 98 | + |
| 99 | + |
| 100 | +def test_empty_body_bypasses_pydantic() -> None: |
| 101 | + exc = _make_http_error(403, body="", reason="Forbidden") |
| 102 | + |
| 103 | + with pytest.raises(ForbiddenError, match="ForbiddenError: RestError: Forbidden"): |
| 104 | + _handle_non_200_response(exc, {}) |
| 105 | + |
| 106 | + |
| 107 | +def test_empty_body_falls_back_to_http_status_phrase() -> None: |
| 108 | + exc = _make_http_error(503, body="") |
| 109 | + exc.response.reason = None |
| 110 | + |
| 111 | + with pytest.raises(ServiceUnavailableError, match="ServiceUnavailableError: RestError: Service Unavailable"): |
| 112 | + _handle_non_200_response(exc, {}) |
| 113 | + |
| 114 | + |
| 115 | +def test_oauth_error_with_description() -> None: |
| 116 | + body = json.dumps( |
| 117 | + { |
| 118 | + "error": "invalid_client", |
| 119 | + "error_description": "Client authentication failed", |
| 120 | + } |
| 121 | + ) |
| 122 | + exc = _make_http_error(401, body=body) |
| 123 | + |
| 124 | + with pytest.raises(OAuthError, match="invalid_client: Client authentication failed"): |
| 125 | + _handle_non_200_response(exc, {401: OAuthError}) |
| 126 | + |
| 127 | + |
| 128 | +def test_oauth_error_with_uri() -> None: |
| 129 | + body = json.dumps( |
| 130 | + { |
| 131 | + "error": "invalid_scope", |
| 132 | + "error_description": "scope not allowed", |
| 133 | + "error_uri": "https://example.com/help", |
| 134 | + } |
| 135 | + ) |
| 136 | + exc = _make_http_error(400, body=body) |
| 137 | + |
| 138 | + with pytest.raises(OAuthError, match=r"invalid_scope: scope not allowed \(https://example.com/help\)"): |
| 139 | + _handle_non_200_response(exc, {400: OAuthError}) |
| 140 | + |
| 141 | + |
| 142 | +def test_oauth_error_without_description() -> None: |
| 143 | + body = json.dumps({"error": "invalid_grant"}) |
| 144 | + exc = _make_http_error(401, body=body) |
| 145 | + |
| 146 | + with pytest.raises(OAuthError, match="^invalid_grant$"): |
| 147 | + _handle_non_200_response(exc, {401: OAuthError}) |
| 148 | + |
| 149 | + |
| 150 | +def test_none_response_raises_value_error() -> None: |
| 151 | + exc = HTTPError() |
| 152 | + exc.response = None |
| 153 | + |
| 154 | + with pytest.raises(ValueError, match="Did not receive a response"): |
| 155 | + _handle_non_200_response(exc, {}) |
0 commit comments