Skip to content

Commit 927017f

Browse files
Gayathri Srividya RajavarapuGayathri Srividya Rajavarapu
authored andcommitted
fix: support REST auth configuration from environment variables
1 parent 5da8186 commit 927017f

2 files changed

Lines changed: 137 additions & 1 deletion

File tree

pyiceberg/catalog/rest/__init__.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
import json
1920
from collections import deque
2021
from enum import Enum
2122
from typing import (
@@ -435,7 +436,32 @@ def _create_session(self) -> Session:
435436
elif ssl_client_cert := ssl_client.get(CERT):
436437
session.cert = ssl_client_cert
437438

438-
if auth_config := self.properties.get(AUTH):
439+
raw_auth = self.properties.get(AUTH)
440+
if isinstance(raw_auth, str):
441+
try:
442+
auth_config: dict[str, Any] | None = json.loads(raw_auth)
443+
except json.JSONDecodeError as e:
444+
raise ValueError("Failed to parse auth configuration as JSON") from e
445+
elif raw_auth is not None:
446+
auth_config = raw_auth
447+
elif auth_type := self.properties.get(f"{AUTH}.type"):
448+
type_prefix = f"{AUTH}.{auth_type}."
449+
auth_config = {
450+
"type": auth_type,
451+
"impl": self.properties.get(f"{AUTH}.impl"),
452+
auth_type: {
453+
key[len(type_prefix) :].replace("-", "_"): value
454+
for key, value in self.properties.items()
455+
if key.startswith(type_prefix)
456+
},
457+
}
458+
else:
459+
auth_config = None
460+
461+
if auth_config is not None and not isinstance(auth_config, dict):
462+
raise ValueError("auth configuration must be a dictionary")
463+
464+
if auth_config:
439465
auth_type = auth_config.get("type")
440466
if auth_type is None:
441467
raise ValueError("auth.type must be defined")

tests/catalog/test_rest.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import base64
21+
import json
2122
import os
2223
from collections.abc import Callable
2324
from typing import Any, cast
@@ -2470,6 +2471,115 @@ def test_rest_catalog_oauth2_non_200_token_response(requests_mock: Mocker) -> No
24702471
RestCatalog("rest", **catalog_properties) # type: ignore
24712472

24722473

2474+
def _rest_catalog_properties_from_environment() -> RecursiveDict:
2475+
env_config = Config._from_environment_variables({})
2476+
catalogs = cast(RecursiveDict, env_config["catalog"])
2477+
return cast(RecursiveDict, catalogs["rest"])
2478+
2479+
2480+
@mock.patch.dict(
2481+
os.environ,
2482+
{
2483+
"PYICEBERG_CATALOG__REST__URI": TEST_URI,
2484+
"PYICEBERG_CATALOG__REST__AUTH": json.dumps({"type": "basic", "basic": {"username": "one", "password": "two"}}),
2485+
},
2486+
clear=True,
2487+
)
2488+
def test_rest_catalog_with_basic_auth_json_environment_variable(rest_mock: Mocker) -> None:
2489+
rest_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200)
2490+
2491+
RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore
2492+
2493+
encoded_user_pass = base64.b64encode(b"one:two").decode()
2494+
assert rest_mock.last_request.headers["Authorization"] == f"Basic {encoded_user_pass}"
2495+
2496+
2497+
@mock.patch.dict(
2498+
os.environ,
2499+
{
2500+
"PYICEBERG_CATALOG__REST__URI": TEST_URI,
2501+
"PYICEBERG_CATALOG__REST__AUTH": json.dumps(
2502+
{
2503+
"type": "oauth2",
2504+
"oauth2": {
2505+
"client_id": "some_client_id",
2506+
"client_secret": "some_client_secret",
2507+
"token_url": f"{TEST_URI}oauth2/token",
2508+
},
2509+
}
2510+
),
2511+
},
2512+
clear=True,
2513+
)
2514+
def test_rest_catalog_with_oauth2_auth_json_environment_variable(requests_mock: Mocker) -> None:
2515+
requests_mock.post(
2516+
f"{TEST_URI}oauth2/token",
2517+
json={"access_token": TEST_TOKEN, "token_type": "Bearer", "expires_in": 3600},
2518+
status_code=200,
2519+
)
2520+
requests_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200)
2521+
2522+
catalog = RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore
2523+
2524+
assert catalog.uri == TEST_URI
2525+
2526+
2527+
@mock.patch.dict(
2528+
os.environ,
2529+
{
2530+
"PYICEBERG_CATALOG__REST__URI": TEST_URI,
2531+
"PYICEBERG_CATALOG__REST__AUTH": "not-valid-json",
2532+
},
2533+
clear=True,
2534+
)
2535+
def test_rest_catalog_with_invalid_json_auth_environment_variable() -> None:
2536+
with pytest.raises(ValueError, match="Failed to parse auth configuration as JSON"):
2537+
RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore
2538+
2539+
2540+
@mock.patch.dict(
2541+
os.environ,
2542+
{
2543+
"PYICEBERG_CATALOG__REST__URI": TEST_URI,
2544+
"PYICEBERG_CATALOG__REST__AUTH__TYPE": "basic",
2545+
"PYICEBERG_CATALOG__REST__AUTH__BASIC__USERNAME": "one",
2546+
"PYICEBERG_CATALOG__REST__AUTH__BASIC__PASSWORD": "two",
2547+
},
2548+
clear=True,
2549+
)
2550+
def test_rest_catalog_with_basic_auth_flat_environment_variables(rest_mock: Mocker) -> None:
2551+
rest_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200)
2552+
2553+
RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore
2554+
2555+
encoded_user_pass = base64.b64encode(b"one:two").decode()
2556+
assert rest_mock.last_request.headers["Authorization"] == f"Basic {encoded_user_pass}"
2557+
2558+
2559+
@mock.patch.dict(
2560+
os.environ,
2561+
{
2562+
"PYICEBERG_CATALOG__REST__URI": TEST_URI,
2563+
"PYICEBERG_CATALOG__REST__AUTH__TYPE": "oauth2",
2564+
"PYICEBERG_CATALOG__REST__AUTH__OAUTH2__CLIENT_ID": "some_client_id",
2565+
"PYICEBERG_CATALOG__REST__AUTH__OAUTH2__CLIENT_SECRET": "some_client_secret",
2566+
"PYICEBERG_CATALOG__REST__AUTH__OAUTH2__TOKEN_URL": f"{TEST_URI}oauth2/token",
2567+
},
2568+
clear=True,
2569+
)
2570+
def test_rest_catalog_with_oauth2_auth_flat_environment_variables(requests_mock: Mocker) -> None:
2571+
requests_mock.post(
2572+
f"{TEST_URI}oauth2/token",
2573+
json={"access_token": TEST_TOKEN, "token_type": "Bearer", "expires_in": 3600},
2574+
status_code=200,
2575+
)
2576+
requests_mock.get(f"{TEST_URI}v1/config", json={"defaults": {}, "overrides": {}}, status_code=200)
2577+
2578+
catalog = RestCatalog("rest", **_rest_catalog_properties_from_environment()) # type: ignore
2579+
2580+
assert catalog.uri == TEST_URI
2581+
2582+
24732583
EXAMPLE_ENV = {"PYICEBERG_CATALOG__PRODUCTION__URI": TEST_URI}
24742584

24752585

0 commit comments

Comments
 (0)