Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions openhexa/graphql/graphql_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@
GraphQLClientHttpError,
GraphQLClientInvalidResponseError,
)
from .get_connection import (
GetConnection,
GetConnectionConnectionBySlug,
GetConnectionConnectionBySlugFields,
)
from .get_users import GetUsers, GetUsersUsers, GetUsersUsersAvatar
from .input_types import (
AddPipelineOutputInput,
Expand Down Expand Up @@ -568,6 +573,9 @@
"GeneratePipelineWebhookUrlInput",
"GenerateWorkspaceTokenError",
"GenerateWorkspaceTokenInput",
"GetConnection",
"GetConnectionConnectionBySlug",
"GetConnectionConnectionBySlugFields",
"GetUsers",
"GetUsersUsers",
"GetUsersUsersAvatar",
Expand Down
28 changes: 28 additions & 0 deletions openhexa/graphql/graphql_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
DeletePipelineVersionDeletePipelineVersion,
)
from .delete_webapp import DeleteWebapp, DeleteWebappDeleteWebapp
from .get_connection import GetConnection, GetConnectionConnectionBySlug
from .get_users import GetUsers, GetUsersUsers
from .input_types import (
AddToFavoritesInput,
Expand Down Expand Up @@ -990,3 +991,30 @@ def delete_dataset(
)
data = self.get_data(response)
return DeleteDataset.model_validate(data).delete_dataset

def get_connection(
self, workspace_slug: str, connection_slug: str, **kwargs: Any
) -> Optional[GetConnectionConnectionBySlug]:
query = gql(
"""
query getConnection($workspaceSlug: String!, $connectionSlug: String!) {
connectionBySlug(workspaceSlug: $workspaceSlug, connectionSlug: $connectionSlug) {
__typename
type
fields {
code
value
}
}
}
"""
)
variables: Dict[str, object] = {
"workspaceSlug": workspace_slug,
"connectionSlug": connection_slug,
}
response = self.execute(
query=query, operation_name="getConnection", variables=variables, **kwargs
)
data = self.get_data(response)
return GetConnection.model_validate(data).connection_by_slug
38 changes: 38 additions & 0 deletions openhexa/graphql/graphql_client/get_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Generated by ariadne-codegen
# Source: openhexa/graphql/queries.graphql

from typing import List, Literal, Optional

from pydantic import Field

from .base_model import BaseModel
from .enums import ConnectionType


class GetConnection(BaseModel):
connection_by_slug: Optional["GetConnectionConnectionBySlug"] = Field(
alias="connectionBySlug"
)


class GetConnectionConnectionBySlug(BaseModel):
typename__: Literal[
"Connection",
"CustomConnection",
"DHIS2Connection",
"GCSConnection",
"IASOConnection",
"PostgreSQLConnection",
"S3Connection",
] = Field(alias="__typename")
type: ConnectionType
fields: List["GetConnectionConnectionBySlugFields"]


class GetConnectionConnectionBySlugFields(BaseModel):
code: str
value: Optional[str]


GetConnection.model_rebuild()
GetConnectionConnectionBySlug.model_rebuild()
12 changes: 11 additions & 1 deletion openhexa/graphql/queries.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -450,4 +450,14 @@ mutation DeleteDataset($input: DeleteDatasetInput!) {
success
errors
}
}
}

query getConnection($workspaceSlug:String!, $connectionSlug: String!) {
connectionBySlug(workspaceSlug:$workspaceSlug, connectionSlug: $connectionSlug) {
type
fields {
code
value
}
}
}
60 changes: 27 additions & 33 deletions openhexa/sdk/workspaces/current_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,27 @@ def _get_local_connection_fields(self, env_variable_prefix: str):

return connection_fields

def get_connection_from_api(self, identifier: str) -> tuple[dict[str, str], str] | None:
"""Get a connection by its identifier from the OpenHEXA API."""
connection_fields: dict[str, str] = {}
connection = OpenHexaClient().get_connection(workspace_slug=self.slug, connection_slug=identifier.lower())
if not connection:
return None
for f in connection.fields:
connection_fields[f.code] = f.value
connection_type = connection.type.upper()
return connection_fields, connection_type

def get_connection_from_env(self, identifier: str) -> tuple[dict[str, str], str] | None:
"""Get a connection by its identifier from the environment variables."""
env_variable_prefix = stringcase.constcase(identifier.lower())
try:
connection_type = os.environ[f"{env_variable_prefix}"].upper()
connection_fields = self._get_local_connection_fields(env_variable_prefix)
return connection_fields, connection_type
except KeyError:
return None

def get_connection(
self, identifier: str
) -> (
Expand All @@ -212,42 +233,15 @@ def get_connection(
ValueError
If the connection does not exist
"""
connection_fields = {}
connection_type = None
if self._connected:
response = graphql(
"""
query getConnection($workspaceSlug:String!, $connectionSlug: String!) {
connectionBySlug(workspaceSlug:$workspaceSlug, connectionSlug: $connectionSlug) {
type
fields {
code
value
}
}
}
""",
{"workspaceSlug": self.slug, "connectionSlug": identifier.lower()},
)
data = response["connectionBySlug"]
if data is None:
raise ValueError(f"Connection {identifier} does not exist.")

for d in data["fields"]:
connection_fields[d.get("code")] = d.get("value")
connection = self.get_connection_from_env(identifier)
if not connection and self._connected:
connection = self.get_connection_from_api(identifier)

connection_type = data["type"].upper()
else:
try:
env_variable_prefix = stringcase.constcase(identifier.lower())
connection_type = os.environ[f"{env_variable_prefix}"].upper()
connection_fields = self._get_local_connection_fields(env_variable_prefix)
except KeyError:
raise ValueError

if not connection_type:
if not connection:
raise ValueError(f"Connection {identifier} does not exist.")

connection_fields, connection_type = connection

# In connected mode (API call) the secret_access_key field and db_name name are
# different from the offline ones
if connection_type == "S3":
Expand Down
38 changes: 19 additions & 19 deletions tests/test_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pytest

from openhexa.graphql import GetConnectionConnectionBySlug
from openhexa.sdk.workspaces.connection import (
CustomConnection,
DHIS2Connection,
Expand Down Expand Up @@ -318,10 +319,10 @@ def test_workspace_tmp_path(self, monkeypatch, workspace):

def test_workspace_get_connection_not_exist(self, workspace):
"""Test get connection not found."""
data = {"connectionBySlug": None}
data = None

with mock.patch(
"openhexa.sdk.workspaces.current_workspace.graphql",
"openhexa.sdk.workspaces.current_workspace.OpenHexaClient.get_connection",
return_value=data,
):
with pytest.raises(ValueError):
Expand All @@ -330,34 +331,33 @@ def test_workspace_get_connection_not_exist(self, workspace):
def test_workspace_get_connection_case_insensitive(self, workspace):
"""Test get connection."""
data = {
"connectionBySlug": {
"type": "CUSTOM",
"fields": [{"code": "field_1", "value": "field_1_value"}],
}
"__typename": "CustomConnection",
"type": "CUSTOM",
"fields": [{"code": "field_1", "value": "field_1_value"}],
}
mocked_data = GetConnectionConnectionBySlug(**data)
with mock.patch(
"openhexa.sdk.workspaces.current_workspace.graphql",
return_value=data,
"openhexa.sdk.workspaces.current_workspace.OpenHexaClient.get_connection",
return_value=mocked_data,
):
connection = workspace.get_connection("RaNDom")
assert isinstance(connection, CustomConnection)

def test_workspace_get_connection(self, workspace):
"""Test get connection."""
data = {
"connectionBySlug": {
"type": "S3",
"fields": [
{"code": "bucket_name", "value": "bucket_name"},
{"code": "access_key_id", "value": "access_key_id"},
{"code": "access_key_secret", "value": "secret_access_key"},
],
}
"__typename": "S3Connection",
"type": "S3",
"fields": [
{"code": "bucket_name", "value": "bucket_name"},
{"code": "access_key_id", "value": "access_key_id"},
{"code": "access_key_secret", "value": "secret_access_key"},
],
}

mocked_data = GetConnectionConnectionBySlug(**data)
with mock.patch(
"openhexa.sdk.workspaces.current_workspace.graphql",
return_value=data,
"openhexa.sdk.workspaces.current_workspace.OpenHexaClient.get_connection",
return_value=mocked_data,
):
connection = workspace.get_connection("s3-connection")
assert isinstance(connection, S3Connection)
Expand Down