Skip to content

Commit a5139d8

Browse files
authored
Merge pull request #1661 from weaviate/auth/allow-string-api-key-in-connect-helpers
Permit parsing API key as string in all `connect_to_x` helpers
2 parents 4c3b939 + 10ad064 commit a5139d8

3 files changed

Lines changed: 60 additions & 20 deletions

File tree

integration/test_auth.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ def test_authentication_with_bearer_token_no_refresh() -> None:
226226
assert str(recwarn.list[0].message).startswith("Auth002")
227227

228228

229+
def test_api_key_string() -> None:
230+
assert is_auth_enabled(f"localhost:{WCS_PORT}")
231+
with weaviate.connect_to_local(port=WCS_PORT, auth_credentials="my-secret-key") as client:
232+
client.collections.list_all()
233+
234+
229235
def test_api_key() -> None:
230236
assert is_auth_enabled(f"localhost:{WCS_PORT}")
231237
with weaviate.connect_to_local(

mock_tests/test_auth.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import time
44
import warnings
5+
from typing import Union
56

67
import grpc
78
import pytest
@@ -412,8 +413,18 @@ def handler(request: Request):
412413
assert issubclass(w[0].category, UserWarning)
413414

414415

416+
@pytest.mark.parametrize(
417+
"api_key",
418+
[
419+
"Super-secret-key",
420+
weaviate.auth.AuthApiKey(api_key="Super-secret-key"),
421+
],
422+
)
415423
def test_with_simple_auth_no_oidc_via_api_key(
416-
weaviate_mock: HTTPServer, start_grpc_server: grpc.Server, recwarn
424+
weaviate_mock: HTTPServer,
425+
start_grpc_server: grpc.Server,
426+
recwarn,
427+
api_key: Union[str, weaviate.auth.AuthApiKey],
417428
) -> None:
418429
weaviate_mock.expect_request(
419430
"/v1/schema", headers={"Authorization": "Bearer " + "Super-secret-key"}
@@ -423,7 +434,7 @@ def test_with_simple_auth_no_oidc_via_api_key(
423434
host=MOCK_IP,
424435
port=MOCK_PORT,
425436
grpc_port=MOCK_PORT_GRPC,
426-
auth_credentials=weaviate.auth.AuthApiKey(api_key="Super-secret-key"),
437+
auth_credentials=api_key,
427438
)
428439
client.collections.list_all()
429440

weaviate/connect/helpers.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
"""Helper functions for creating new WeaviateClient or WeaviateAsyncClient instances in common scenarios."""
22

3-
from typing import Dict, Optional, Tuple
3+
from typing import Dict, Optional, Tuple, Union
44
from urllib.parse import urlparse
55

66
from deprecation import deprecated as docstring_deprecated
77
from typing_extensions import deprecated as typing_deprecated
88

9-
from weaviate.auth import AuthCredentials
9+
from weaviate.auth import (
10+
Auth,
11+
AuthCredentials,
12+
_APIKey,
13+
_BearerToken,
14+
_ClientCredentials,
15+
_ClientPassword,
16+
)
1017
from weaviate.client import WeaviateAsyncClient, WeaviateClient
1118
from weaviate.config import AdditionalConfig
1219
from weaviate.connect.base import ConnectionParams, ProtocolParams
@@ -27,9 +34,25 @@ def __parse_weaviate_cloud_cluster_url(cluster_url: str) -> Tuple[str, str]:
2734
return cluster_url, grpc_host
2835

2936

37+
def __parse_auth_credentials(creds: Union[str, AuthCredentials, None]) -> Optional[AuthCredentials]:
38+
if isinstance(creds, str):
39+
# If the credentials are a string, assume it's an API key.
40+
return Auth.api_key(creds)
41+
elif isinstance(
42+
creds, (_BearerToken, _ClientPassword, _ClientCredentials, _APIKey)
43+
): # use AuthCredentials after python 3.9 has been removed
44+
# If the credentials are already an AuthCredentials object, return it as is.
45+
return creds
46+
elif creds is None:
47+
# If no credentials are provided, return None.
48+
return None
49+
else:
50+
raise ValueError("Invalid auth credentials provided.")
51+
52+
3053
def connect_to_weaviate_cloud(
3154
cluster_url: str,
32-
auth_credentials: Optional[AuthCredentials],
55+
auth_credentials: Union[str, AuthCredentials],
3356
headers: Optional[Dict[str, str]] = None,
3457
additional_config: Optional[AdditionalConfig] = None,
3558
skip_init_checks: bool = False,
@@ -42,7 +65,7 @@ def connect_to_weaviate_cloud(
4265
4366
Args:
4467
cluster_url: The WCD cluster URL or hostname to connect to. Usually in the form: rAnD0mD1g1t5.something.weaviate.cloud
45-
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case use
68+
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case pass a string or use
4669
`weaviate.classes.init.Auth.api_key()`, a bearer token, in which case use `weaviate.classes.init.Auth.bearer_token()`, a client secret,
4770
in which case use `weaviate.classes.init.Auth.client_credentials()` or a username and password, in which case use `weaviate.classes.init.Auth.client_password()`.
4871
headers: Additional headers to include in the requests, e.g. API keys for third-party Cloud vectorization.
@@ -79,7 +102,7 @@ def connect_to_weaviate_cloud(
79102
http=ProtocolParams(host=cluster_url, port=443, secure=True),
80103
grpc=ProtocolParams(host=grpc_host, port=443, secure=True),
81104
),
82-
auth_client_secret=auth_credentials,
105+
auth_client_secret=__parse_auth_credentials(auth_credentials),
83106
additional_headers=headers,
84107
additional_config=additional_config,
85108
skip_init_checks=skip_init_checks,
@@ -98,7 +121,7 @@ def connect_to_weaviate_cloud(
98121
)
99122
def connect_to_wcs(
100123
cluster_url: str,
101-
auth_credentials: Optional[AuthCredentials],
124+
auth_credentials: Union[str, AuthCredentials],
102125
headers: Optional[Dict[str, str]] = None,
103126
additional_config: Optional[AdditionalConfig] = None,
104127
skip_init_checks: bool = False,
@@ -115,7 +138,7 @@ def connect_to_local(
115138
headers: Optional[Dict[str, str]] = None,
116139
additional_config: Optional[AdditionalConfig] = None,
117140
skip_init_checks: bool = False,
118-
auth_credentials: Optional[AuthCredentials] = None,
141+
auth_credentials: Union[str, AuthCredentials, None] = None,
119142
) -> WeaviateClient:
120143
"""Connect to a local Weaviate instance deployed using Docker compose with standard port configurations.
121144
@@ -130,7 +153,7 @@ def connect_to_local(
130153
headers: Additional headers to include in the requests, e.g. API keys for Cloud vectorization.
131154
additional_config: This includes many additional, rarely used config options. use wvc.init.AdditionalConfig() to configure.
132155
skip_init_checks: Whether to skip the initialization checks when connecting to Weaviate.
133-
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case use `weaviate.classes.init.Auth.api_key()`,
156+
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case pass a string or use `weaviate.classes.init.Auth.api_key()`,
134157
a bearer token, in which case use `weaviate.classes.init.Auth.bearer_token()`, a client secret, in which case use `weaviate.classes.init.Auth.client_credentials()`
135158
or a username and password, in which case use `weaviate.classes.init.Auth.client_password()`.
136159
@@ -168,7 +191,7 @@ def connect_to_local(
168191
additional_headers=headers,
169192
additional_config=additional_config,
170193
skip_init_checks=skip_init_checks,
171-
auth_client_secret=auth_credentials,
194+
auth_client_secret=__parse_auth_credentials(auth_credentials),
172195
)
173196
)
174197

@@ -277,7 +300,7 @@ def connect_to_custom(
277300
grpc_secure: Whether to use a secure channel for the underlying gRPC API.
278301
headers: Additional headers to include in the requests, e.g. API keys for Cloud vectorization.
279302
additional_config: This includes many additional, rarely used config options. use wvc.init.AdditionalConfig() to configure.
280-
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case use `weaviate.classes.init.Auth.api_key()`,
303+
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case pass a string or use `weaviate.classes.init.Auth.api_key()`,
281304
a bearer token, in which case use `weaviate.classes.init.Auth.bearer_token()`, a client secret, in which case use `weaviate.classes.init.Auth.client_credentials()`
282305
or a username and password, in which case use `weaviate.classes.init.Auth.client_password()`.
283306
skip_init_checks: Whether to skip the initialization checks when connecting to Weaviate.
@@ -323,7 +346,7 @@ def connect_to_custom(
323346
grpc_port=grpc_port,
324347
grpc_secure=grpc_secure,
325348
),
326-
auth_client_secret=auth_credentials,
349+
auth_client_secret=__parse_auth_credentials(auth_credentials),
327350
additional_headers=headers,
328351
additional_config=additional_config,
329352
skip_init_checks=skip_init_checks,
@@ -355,7 +378,7 @@ def use_async_with_weaviate_cloud(
355378
356379
Args:
357380
cluster_url: The WCD cluster URL or hostname to connect to. Usually in the form: rAnD0mD1g1t5.something.weaviate.cloud
358-
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case use `weaviate.classes.init.Auth.api_key()`,
381+
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case pass a string or use `weaviate.classes.init.Auth.api_key()`,
359382
a bearer token, in which case use `weaviate.classes.init.Auth.bearer_token()`, a client secret, in which case use `weaviate.classes.init.Auth.client_credentials()`
360383
or a username and password, in which case use `weaviate.classes.init.Auth.client_password()`.
361384
headers: Additional headers to include in the requests, e.g. API keys for third-party Cloud vectorization.
@@ -393,7 +416,7 @@ def use_async_with_weaviate_cloud(
393416
http=ProtocolParams(host=cluster_url, port=443, secure=True),
394417
grpc=ProtocolParams(host=grpc_host, port=443, secure=True),
395418
),
396-
auth_client_secret=auth_credentials,
419+
auth_client_secret=__parse_auth_credentials(auth_credentials),
397420
additional_headers=headers,
398421
additional_config=additional_config,
399422
skip_init_checks=skip_init_checks,
@@ -407,7 +430,7 @@ def use_async_with_local(
407430
headers: Optional[Dict[str, str]] = None,
408431
additional_config: Optional[AdditionalConfig] = None,
409432
skip_init_checks: bool = False,
410-
auth_credentials: Optional[AuthCredentials] = None,
433+
auth_credentials: Union[str, AuthCredentials, None] = None,
411434
) -> WeaviateAsyncClient:
412435
"""Create an async client object ready to connect to a local Weaviate instance deployed using Docker compose with standard port configurations.
413436
@@ -422,7 +445,7 @@ def use_async_with_local(
422445
headers: Additional headers to include in the requests, e.g. API keys for Cloud vectorization.
423446
additional_config: This includes many additional, rarely used config options. use wvc.init.AdditionalConfig() to configure.
424447
skip_init_checks: Whether to skip the initialization checks when connecting to Weaviate.
425-
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case use `weaviate.classes.init.Auth.api_key()`,
448+
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case pass a string or use `weaviate.classes.init.Auth.api_key()`,
426449
a bearer token, in which case use `weaviate.classes.init.Auth.bearer_token()`, a client secret, in which case use `weaviate.classes.init.Auth.client_credentials()`
427450
or a username and password, in which case use `weaviate.classes.init.Auth.client_password()`.
428451
@@ -462,7 +485,7 @@ def use_async_with_local(
462485
additional_headers=headers,
463486
additional_config=additional_config,
464487
skip_init_checks=skip_init_checks,
465-
auth_client_secret=auth_credentials,
488+
auth_client_secret=__parse_auth_credentials(auth_credentials),
466489
)
467490

468491

@@ -574,7 +597,7 @@ def use_async_with_custom(
574597
grpc_secure: Whether to use a secure channel for the underlying gRPC API.
575598
headers: Additional headers to include in the requests, e.g. API keys for Cloud vectorization.
576599
additional_config: This includes many additional, rarely used config options. use wvc.init.AdditionalConfig() to configure.
577-
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case use `weaviate.classes.init.Auth.api_key()`,
600+
auth_credentials: The credentials to use for authentication with your Weaviate instance. This can be an API key, in which case pass a string or use `weaviate.classes.init.Auth.api_key()`,
578601
a bearer token, in which case use `weaviate.classes.init.Auth.bearer_token()`, a client secret, in which case use `weaviate.classes.init.Auth.client_credentials()`
579602
or a username and password, in which case use `weaviate.classes.init.Auth.client_password()`.
580603
skip_init_checks: Whether to skip the initialization checks when connecting to Weaviate.
@@ -622,7 +645,7 @@ def use_async_with_custom(
622645
grpc_port=grpc_port,
623646
grpc_secure=grpc_secure,
624647
),
625-
auth_client_secret=auth_credentials,
648+
auth_client_secret=__parse_auth_credentials(auth_credentials),
626649
additional_headers=headers,
627650
additional_config=additional_config,
628651
skip_init_checks=skip_init_checks,

0 commit comments

Comments
 (0)