Skip to content

Commit 92ad051

Browse files
committed
add tests
1 parent 5954550 commit 92ad051

File tree

3 files changed

+118
-8
lines changed

3 files changed

+118
-8
lines changed

pyiceberg/catalog/rest/__init__.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class IdentifierKind(Enum):
135135
OAUTH2_SERVER_URI = "oauth2-server-uri"
136136
SNAPSHOT_LOADING_MODE = "snapshot-loading-mode"
137137
AUTH = "auth"
138+
CUSTOM = "custom"
138139

139140
NAMESPACE_SEPARATOR = b"\x1f".decode(UTF8)
140141

@@ -249,15 +250,19 @@ def _create_session(self) -> Session:
249250
session.cert = ssl_client_cert
250251

251252
if auth_config := self.properties.get(AUTH):
252-
# set up auth_manager based on the properties
253253
auth_type = auth_config.get("type")
254254
if auth_type is None:
255255
raise ValueError("auth.type must be defined")
256256
auth_type_config = auth_config.get(auth_type, {})
257-
if auth_impl := auth_config.get("impl"):
258-
session.auth = AuthManagerAdapter(AuthManagerFactory.create(auth_impl, auth_type_config))
259-
else:
260-
session.auth = AuthManagerAdapter(AuthManagerFactory.create(auth_type, auth_type_config))
257+
auth_impl = auth_config.get("impl")
258+
259+
if auth_type is CUSTOM and not auth_impl:
260+
raise ValueError("auth.impl must be specified when using custom auth.type")
261+
262+
if auth_type is not CUSTOM and auth_impl:
263+
raise ValueError("auth.impl can only be specified when using custom auth.type")
264+
265+
session.auth = AuthManagerAdapter(AuthManagerFactory.create(auth_impl or auth_type, auth_type_config))
261266
else:
262267
session.auth = AuthManagerAdapter(self._create_legacy_oauth2_auth_manager(session))
263268

pyiceberg/catalog/rest/auth.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,23 @@ def get_token(self) -> str:
189189
class OAuth2AuthManager(AuthManager):
190190
"""Auth Manager implementation that supports OAuth2 as defined in IETF RFC6749."""
191191

192-
def __init__(self, token_provider: OAuth2TokenProvider):
193-
self.token_provider = token_provider
192+
def __init__(
193+
self,
194+
client_id: str,
195+
client_secret: str,
196+
token_url: str,
197+
scope: Optional[str] = None,
198+
refresh_margin: int = 60,
199+
expires_in: Optional[int] = None,
200+
):
201+
self.token_provider = OAuth2TokenProvider(
202+
client_id,
203+
client_secret,
204+
token_url,
205+
scope,
206+
refresh_margin,
207+
expires_in,
208+
)
194209

195210
def auth_header(self) -> str:
196211
return f"Bearer {self.token_provider.get_token()}"
@@ -274,3 +289,4 @@ def create(cls, class_or_name: str, config: Dict[str, Any]) -> AuthManager:
274289
AuthManagerFactory.register("noop", NoopAuthManager)
275290
AuthManagerFactory.register("basic", BasicAuthManager)
276291
AuthManagerFactory.register("legacyoauth2", LegacyOAuth2AuthManager)
292+
AuthManagerFactory.register("oauth2", OAuth2AuthManager)

tests/catalog/test_rest.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1536,7 +1536,7 @@ def test_rest_catalog_with_basic_auth_type() -> None:
15361536
assert "BasicAuthManager.__init__() missing 1 required positional argument: 'password'" in str(e.value)
15371537

15381538

1539-
def test_rest_catalog_with_auth_impl() -> None:
1539+
def test_rest_catalog_with_custom_auth_type() -> None:
15401540
# Given
15411541
catalog_properties = {
15421542
"uri": TEST_URI,
@@ -1555,6 +1555,95 @@ def test_rest_catalog_with_auth_impl() -> None:
15551555
assert "Could not load AuthManager class for 'dummy.nonexistent.package'" in str(e.value)
15561556

15571557

1558+
def test_rest_catalog_with_custom_auth_type_no_impl() -> None:
1559+
# Given
1560+
catalog_properties = {
1561+
"uri": TEST_URI,
1562+
"auth": {
1563+
"type": "custom",
1564+
"custom": {
1565+
"property1": "one",
1566+
"property2": "two",
1567+
},
1568+
},
1569+
}
1570+
with pytest.raises(ValueError) as e:
1571+
# Missing namespace
1572+
RestCatalog("rest", **catalog_properties) # type: ignore
1573+
assert "auth.impl must be specified when using custom auth.type" in str(e.value)
1574+
1575+
1576+
def test_rest_catalog_with_non_custom_auth_type_impl() -> None:
1577+
# Given
1578+
catalog_properties = {
1579+
"uri": TEST_URI,
1580+
"auth": {
1581+
"type": "oauth2",
1582+
"impl": "oauth2.package",
1583+
"oauth2": {
1584+
"property1": "one",
1585+
"property2": "two",
1586+
},
1587+
},
1588+
}
1589+
with pytest.raises(ValueError) as e:
1590+
# Missing namespace
1591+
RestCatalog("rest", **catalog_properties) # type: ignore
1592+
assert "auth.impl can only be specified when using custom auth.type" in str(e.value)
1593+
1594+
1595+
def test_rest_catalog_with_unsupported_auth_type() -> None:
1596+
# Given
1597+
catalog_properties = {
1598+
"uri": TEST_URI,
1599+
"auth": {
1600+
"type": "unsupported",
1601+
"unsupported": {
1602+
"property1": "one",
1603+
"property2": "two",
1604+
},
1605+
},
1606+
}
1607+
with pytest.raises(ValueError) as e:
1608+
# Missing namespace
1609+
RestCatalog("rest", **catalog_properties) # type: ignore
1610+
assert "Could not load AuthManager class for 'unsupported'" in str(e.value)
1611+
1612+
1613+
def test_rest_catalog_with_oauth2_auth_type(requests_mock: Mocker) -> None:
1614+
requests_mock.post(
1615+
f"{TEST_URI}oauth2/token",
1616+
json={
1617+
"access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
1618+
"token_type": "Bearer",
1619+
"expires_in": 3600,
1620+
"refresh_token": "IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk",
1621+
"scope": "read",
1622+
},
1623+
status_code=200,
1624+
)
1625+
requests_mock.get(
1626+
f"{TEST_URI}v1/config",
1627+
json={"defaults": {}, "overrides": {}},
1628+
status_code=200,
1629+
)
1630+
# Given
1631+
catalog_properties = {
1632+
"uri": TEST_URI,
1633+
"auth": {
1634+
"type": "oauth2",
1635+
"oauth2": {
1636+
"client_id": "some_client_id",
1637+
"client_secret": "some_client_secret",
1638+
"token_url": f"{TEST_URI}oauth2/token",
1639+
"scope": "read",
1640+
},
1641+
},
1642+
}
1643+
catalog = RestCatalog("rest", **catalog_properties) # type: ignore
1644+
assert catalog.uri == TEST_URI
1645+
1646+
15581647
EXAMPLE_ENV = {"PYICEBERG_CATALOG__PRODUCTION__URI": TEST_URI}
15591648

15601649

0 commit comments

Comments
 (0)