Skip to content

Commit f3ed79a

Browse files
yannforgetnazarfil
andauthored
feat: prioritize connections from env (#297)
* feat: prioritize connections from env * style: unused import * fix: use openehxa client * fix: fix tests mock --------- Co-authored-by: nazarfil <nfilipchuk@bluesquarehub.com>
1 parent 172ecf6 commit f3ed79a

6 files changed

Lines changed: 131 additions & 53 deletions

File tree

openhexa/graphql/graphql_client/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@
210210
GraphQLClientHttpError,
211211
GraphQLClientInvalidResponseError,
212212
)
213+
from .get_connection import (
214+
GetConnection,
215+
GetConnectionConnectionBySlug,
216+
GetConnectionConnectionBySlugFields,
217+
)
213218
from .get_users import GetUsers, GetUsersUsers, GetUsersUsersAvatar
214219
from .input_types import (
215220
AddPipelineOutputInput,
@@ -568,6 +573,9 @@
568573
"GeneratePipelineWebhookUrlInput",
569574
"GenerateWorkspaceTokenError",
570575
"GenerateWorkspaceTokenInput",
576+
"GetConnection",
577+
"GetConnectionConnectionBySlug",
578+
"GetConnectionConnectionBySlugFields",
571579
"GetUsers",
572580
"GetUsersUsers",
573581
"GetUsersUsersAvatar",

openhexa/graphql/graphql_client/client.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
DeletePipelineVersionDeletePipelineVersion,
4242
)
4343
from .delete_webapp import DeleteWebapp, DeleteWebappDeleteWebapp
44+
from .get_connection import GetConnection, GetConnectionConnectionBySlug
4445
from .get_users import GetUsers, GetUsersUsers
4546
from .input_types import (
4647
AddToFavoritesInput,
@@ -990,3 +991,30 @@ def delete_dataset(
990991
)
991992
data = self.get_data(response)
992993
return DeleteDataset.model_validate(data).delete_dataset
994+
995+
def get_connection(
996+
self, workspace_slug: str, connection_slug: str, **kwargs: Any
997+
) -> Optional[GetConnectionConnectionBySlug]:
998+
query = gql(
999+
"""
1000+
query getConnection($workspaceSlug: String!, $connectionSlug: String!) {
1001+
connectionBySlug(workspaceSlug: $workspaceSlug, connectionSlug: $connectionSlug) {
1002+
__typename
1003+
type
1004+
fields {
1005+
code
1006+
value
1007+
}
1008+
}
1009+
}
1010+
"""
1011+
)
1012+
variables: Dict[str, object] = {
1013+
"workspaceSlug": workspace_slug,
1014+
"connectionSlug": connection_slug,
1015+
}
1016+
response = self.execute(
1017+
query=query, operation_name="getConnection", variables=variables, **kwargs
1018+
)
1019+
data = self.get_data(response)
1020+
return GetConnection.model_validate(data).connection_by_slug
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Generated by ariadne-codegen
2+
# Source: openhexa/graphql/queries.graphql
3+
4+
from typing import List, Literal, Optional
5+
6+
from pydantic import Field
7+
8+
from .base_model import BaseModel
9+
from .enums import ConnectionType
10+
11+
12+
class GetConnection(BaseModel):
13+
connection_by_slug: Optional["GetConnectionConnectionBySlug"] = Field(
14+
alias="connectionBySlug"
15+
)
16+
17+
18+
class GetConnectionConnectionBySlug(BaseModel):
19+
typename__: Literal[
20+
"Connection",
21+
"CustomConnection",
22+
"DHIS2Connection",
23+
"GCSConnection",
24+
"IASOConnection",
25+
"PostgreSQLConnection",
26+
"S3Connection",
27+
] = Field(alias="__typename")
28+
type: ConnectionType
29+
fields: List["GetConnectionConnectionBySlugFields"]
30+
31+
32+
class GetConnectionConnectionBySlugFields(BaseModel):
33+
code: str
34+
value: Optional[str]
35+
36+
37+
GetConnection.model_rebuild()
38+
GetConnectionConnectionBySlug.model_rebuild()

openhexa/graphql/queries.graphql

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,4 +450,14 @@ mutation DeleteDataset($input: DeleteDatasetInput!) {
450450
success
451451
errors
452452
}
453-
}
453+
}
454+
455+
query getConnection($workspaceSlug:String!, $connectionSlug: String!) {
456+
connectionBySlug(workspaceSlug:$workspaceSlug, connectionSlug: $connectionSlug) {
457+
type
458+
fields {
459+
code
460+
value
461+
}
462+
}
463+
}

openhexa/sdk/workspaces/current_workspace.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,27 @@ def _get_local_connection_fields(self, env_variable_prefix: str):
190190

191191
return connection_fields
192192

193+
def get_connection_from_api(self, identifier: str) -> tuple[dict[str, str], str] | None:
194+
"""Get a connection by its identifier from the OpenHEXA API."""
195+
connection_fields: dict[str, str] = {}
196+
connection = OpenHexaClient().get_connection(workspace_slug=self.slug, connection_slug=identifier.lower())
197+
if not connection:
198+
return None
199+
for f in connection.fields:
200+
connection_fields[f.code] = f.value
201+
connection_type = connection.type.upper()
202+
return connection_fields, connection_type
203+
204+
def get_connection_from_env(self, identifier: str) -> tuple[dict[str, str], str] | None:
205+
"""Get a connection by its identifier from the environment variables."""
206+
env_variable_prefix = stringcase.constcase(identifier.lower())
207+
try:
208+
connection_type = os.environ[f"{env_variable_prefix}"].upper()
209+
connection_fields = self._get_local_connection_fields(env_variable_prefix)
210+
return connection_fields, connection_type
211+
except KeyError:
212+
return None
213+
193214
def get_connection(
194215
self, identifier: str
195216
) -> (
@@ -212,42 +233,15 @@ def get_connection(
212233
ValueError
213234
If the connection does not exist
214235
"""
215-
connection_fields = {}
216-
connection_type = None
217-
if self._connected:
218-
response = graphql(
219-
"""
220-
query getConnection($workspaceSlug:String!, $connectionSlug: String!) {
221-
connectionBySlug(workspaceSlug:$workspaceSlug, connectionSlug: $connectionSlug) {
222-
type
223-
fields {
224-
code
225-
value
226-
}
227-
}
228-
}
229-
""",
230-
{"workspaceSlug": self.slug, "connectionSlug": identifier.lower()},
231-
)
232-
data = response["connectionBySlug"]
233-
if data is None:
234-
raise ValueError(f"Connection {identifier} does not exist.")
235-
236-
for d in data["fields"]:
237-
connection_fields[d.get("code")] = d.get("value")
236+
connection = self.get_connection_from_env(identifier)
237+
if not connection and self._connected:
238+
connection = self.get_connection_from_api(identifier)
238239

239-
connection_type = data["type"].upper()
240-
else:
241-
try:
242-
env_variable_prefix = stringcase.constcase(identifier.lower())
243-
connection_type = os.environ[f"{env_variable_prefix}"].upper()
244-
connection_fields = self._get_local_connection_fields(env_variable_prefix)
245-
except KeyError:
246-
raise ValueError
247-
248-
if not connection_type:
240+
if not connection:
249241
raise ValueError(f"Connection {identifier} does not exist.")
250242

243+
connection_fields, connection_type = connection
244+
251245
# In connected mode (API call) the secret_access_key field and db_name name are
252246
# different from the offline ones
253247
if connection_type == "S3":

tests/test_workspace.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pytest
1010

11+
from openhexa.graphql import GetConnectionConnectionBySlug
1112
from openhexa.sdk.workspaces.connection import (
1213
CustomConnection,
1314
DHIS2Connection,
@@ -318,10 +319,10 @@ def test_workspace_tmp_path(self, monkeypatch, workspace):
318319

319320
def test_workspace_get_connection_not_exist(self, workspace):
320321
"""Test get connection not found."""
321-
data = {"connectionBySlug": None}
322+
data = None
322323

323324
with mock.patch(
324-
"openhexa.sdk.workspaces.current_workspace.graphql",
325+
"openhexa.sdk.workspaces.current_workspace.OpenHexaClient.get_connection",
325326
return_value=data,
326327
):
327328
with pytest.raises(ValueError):
@@ -330,34 +331,33 @@ def test_workspace_get_connection_not_exist(self, workspace):
330331
def test_workspace_get_connection_case_insensitive(self, workspace):
331332
"""Test get connection."""
332333
data = {
333-
"connectionBySlug": {
334-
"type": "CUSTOM",
335-
"fields": [{"code": "field_1", "value": "field_1_value"}],
336-
}
334+
"__typename": "CustomConnection",
335+
"type": "CUSTOM",
336+
"fields": [{"code": "field_1", "value": "field_1_value"}],
337337
}
338+
mocked_data = GetConnectionConnectionBySlug(**data)
338339
with mock.patch(
339-
"openhexa.sdk.workspaces.current_workspace.graphql",
340-
return_value=data,
340+
"openhexa.sdk.workspaces.current_workspace.OpenHexaClient.get_connection",
341+
return_value=mocked_data,
341342
):
342343
connection = workspace.get_connection("RaNDom")
343344
assert isinstance(connection, CustomConnection)
344345

345346
def test_workspace_get_connection(self, workspace):
346347
"""Test get connection."""
347348
data = {
348-
"connectionBySlug": {
349-
"type": "S3",
350-
"fields": [
351-
{"code": "bucket_name", "value": "bucket_name"},
352-
{"code": "access_key_id", "value": "access_key_id"},
353-
{"code": "access_key_secret", "value": "secret_access_key"},
354-
],
355-
}
349+
"__typename": "S3Connection",
350+
"type": "S3",
351+
"fields": [
352+
{"code": "bucket_name", "value": "bucket_name"},
353+
{"code": "access_key_id", "value": "access_key_id"},
354+
{"code": "access_key_secret", "value": "secret_access_key"},
355+
],
356356
}
357-
357+
mocked_data = GetConnectionConnectionBySlug(**data)
358358
with mock.patch(
359-
"openhexa.sdk.workspaces.current_workspace.graphql",
360-
return_value=data,
359+
"openhexa.sdk.workspaces.current_workspace.OpenHexaClient.get_connection",
360+
return_value=mocked_data,
361361
):
362362
connection = workspace.get_connection("s3-connection")
363363
assert isinstance(connection, S3Connection)

0 commit comments

Comments
 (0)