diff --git a/openhexa/__init__.py b/openhexa/__init__.py new file mode 100644 index 00000000..ee557fde --- /dev/null +++ b/openhexa/__init__.py @@ -0,0 +1 @@ +"""OpenHexa package initialization.""" diff --git a/openhexa/cli/api.py b/openhexa/cli/api.py index 69f2d16b..2bd81b87 100644 --- a/openhexa/cli/api.py +++ b/openhexa/cli/api.py @@ -8,24 +8,19 @@ import os import tempfile import typing -from datetime import datetime -from importlib.metadata import version from pathlib import Path from zipfile import ZipFile import click import docker -import requests from docker.models.containers import Container -from graphql import build_client_schema, build_schema, get_introspection_query -from graphql.utilities import find_breaking_changes from jinja2 import Template -from openhexa.cli.graphql.graphql_client import Client from openhexa.cli.settings import settings +from openhexa.graphql.openhexa_client import graphql from openhexa.sdk.pipelines import get_local_workspace_config from openhexa.sdk.pipelines.runtime import get_pipeline -from openhexa.utils import create_requests_session, stringcase +from openhexa.utils import stringcase class InvalidDefinitionError(Exception): @@ -46,24 +41,6 @@ class OutputDirectoryError(Exception): pass -class APIError(Exception): - """Raised when an error occurs while interacting with the API.""" - - pass - - -class InvalidTokenError(APIError): - """Raised when the token is invalid.""" - - pass - - -class GraphQLError(APIError): - """Raised when a GraphQL request returns an error.""" - - pass - - class PipelineDirectoryError(Exception): """Raised when the pipeline directory is not a directory or does not exist.""" @@ -90,108 +67,6 @@ class PermissionDenied(Exception): pass -def get_library_versions() -> tuple[str, str]: - """Return the current version and the one on PyPi.""" - # Get the currently installed version - installed_version = version("openhexa.sdk") - - # Get the latest version available on PyPI - try: - response = requests.get("https://pypi.org/pypi/openhexa.sdk/json") - latest_version = response.json()["info"]["version"] - return installed_version, latest_version - except requests.RequestException: - logging.error( - "Could not check for the latest version of the openhexa.sdk package.", - exc_info=True, - ) - return installed_version, installed_version - - -def _detect_graphql_breaking_changes_if_needed(token): - """Detect breaking changes if not done recently between the schema referenced in the SDK and the server using graphql-core.""" - ONE_HOUR = 60 * 60 - now_timestamp = int(datetime.now().timestamp()) - if not settings.last_breaking_change_check or now_timestamp - settings.last_breaking_change_check > ONE_HOUR: - _detect_graphql_breaking_changes(token) - settings.last_breaking_change_check = now_timestamp - - -def _detect_graphql_breaking_changes(token): - """Detect breaking changes between the schema referenced in the SDK and the server using graphql-core.""" - stored_schema_obj = build_schema((Path(__file__).parent / "graphql" / "schema.generated.graphql").open().read()) - server_schema_obj = build_client_schema( - _query_graphql(get_introspection_query(input_value_deprecation=True), token=token) - ) - - breaking_changes = find_breaking_changes(stored_schema_obj, server_schema_obj) - if breaking_changes: - current_version, latest_version = get_library_versions() - click.secho( - f"⚠️ Breaking changes detected between the SDK (version {current_version}) and the server:", - fg="red", - ) - for change in breaking_changes: - click.secho(f"- {change.description}", fg="yellow") - click.secho( - "This could lead to unexpected results.\n" - f"Please update the SDK to the latest version {latest_version} " - f"(using `pip install openhexa-sdk=={latest_version}`) or use a version of the SDK compatible with the server.", - fg="red", - ) - - -def graphql(query: str, variables=None, token=None): - """Check that there is no breaking change and perform a GraphQL request.""" - _detect_graphql_breaking_changes_if_needed(token) - return _query_graphql(query, variables, token) - - -def _query_graphql(query: str, variables=None, token=None): - """Perform a GraphQL request.""" - url = settings.api_url + "/graphql/" - if token is None: - token = settings.access_token - - if token is None: - raise InvalidTokenError("No token found for workspace") - - if settings.debug: - click.echo("") - click.echo("Graphql Query:") - click.echo(f"URL: {url}") - click.echo(f"Query: {query}") - click.echo(f"Variables: {variables}") - - session = create_requests_session() - - response = session.post( - url, - headers={ - "User-Agent": f"openhexa-cli/{version('openhexa.sdk')}", - "Authorization": f"Bearer {token}", - }, - json={"query": query, "variables": variables}, - ) - try: - response.raise_for_status() - except requests.exceptions.HTTPError as e: - raise GraphQLError(str(e)) - - data = response.json() - - if settings.debug: - click.echo("Graphql Response:") - click.echo(data) - click.echo("") - - if data.get("errors"): - if data.get("errors")[0].get("extensions", {}).get("code") == "UNAUTHENTICATED": - raise InvalidTokenError - raise GraphQLError(data["errors"]) - return data["data"] - - def get_skeleton_dir(): """Get the path to the skeleton directory.""" return Path(__file__).parent / "skeleton" @@ -745,47 +620,3 @@ def is_dhis2_connection_up(workspace_slug: str, connection_slug: str) -> bool: }, ) return response["data"]["connectionBySlug"]["status"] == "UP" - - -class OpenHexaClient(Client): - """OpenHexaClient is a class that provides methods to interact with the OpenHexa GraphQL API.""" - - def __init__(self, token=None): - """Initialize the OpenHexaClient with the OpenHexa API URL and headers.""" - self._url = settings.api_url + "/graphql/" - self._token = token or settings.access_token - - if not self._token: - raise InvalidTokenError("No token found for workspace") - - super().__init__( - url=self._url, - headers={ - "User-Agent": f"openhexa-cli/{version('openhexa.sdk')}", - "Authorization": f"Bearer {self._token}", - }, - ) - logging.getLogger("httpx").setLevel( - logging.WARNING - ) # HTTPX logs queries by default, we disable them here with WARNING level - - def execute(self, query, **kwargs): - """Decorate parent execute method to log the GraphQL query and response.""" - _detect_graphql_breaking_changes(token=self._token) - - if settings.debug: - click.echo("") - click.echo("Graphql Query:") - click.echo(f"URL: {self.url}") - click.echo(f"Query: {query}") - variables = kwargs.get("variables", {}) - click.echo(f"Variables: {variables}") - - response = super().execute(query=query, **kwargs) - - if settings.debug: - click.echo("") - click.echo("Graphql Response:") - click.echo(f"Response: {response}") - - return response diff --git a/openhexa/cli/cli.py b/openhexa/cli/cli.py index 145fd8c5..377519fb 100644 --- a/openhexa/cli/cli.py +++ b/openhexa/cli/cli.py @@ -14,7 +14,6 @@ DockerError, InvalidDefinitionError, NoActiveWorkspaceError, - OpenHexaClient, OutputDirectoryError, PipelineDirectoryError, create_pipeline, @@ -23,7 +22,6 @@ delete_pipeline, download_pipeline_sourcecode, ensure_is_pipeline_dir, - get_library_versions, get_pipeline_from_code, get_pipelines_pages, get_workspace, @@ -31,6 +29,7 @@ upload_pipeline, ) from openhexa.cli.settings import settings, setup_logging +from openhexa.graphql.openhexa_client import OpenHexaClient, get_library_versions from openhexa.sdk.pipelines.exceptions import PipelineNotFound from openhexa.sdk.pipelines.runtime import get_pipeline @@ -597,7 +596,9 @@ def pipelines_list(): _terminate("No workspace activated", err=True) workspace_pipelines = ( - OpenHexaClient().get_workspace_pipelines(workspace_slug=settings.current_workspace).pipelines.items + OpenHexaClient(settings.api_url, settings.access_token) + .get_workspace_pipelines(workspace_slug=settings.current_workspace) + .pipelines.items ) if len(workspace_pipelines) == 0: click.echo(f"No pipelines in workspace {settings.current_workspace}") diff --git a/openhexa/cli/settings.py b/openhexa/cli/settings.py index 67915fae..2a322570 100644 --- a/openhexa/cli/settings.py +++ b/openhexa/cli/settings.py @@ -6,8 +6,6 @@ import click -from openhexa.sdk.pipelines.log_level import LogLevel - CONFIGFILE_PATH = os.path.expanduser("~") + "/.openhexa.ini" @@ -78,11 +76,6 @@ def workspaces(self): """Return the workspaces from the settings file.""" return self._file_config["workspaces"] - @property - def log_level(self) -> LogLevel: - """Return the log level from the environment variables.""" - return LogLevel.parse_log_level(os.getenv("HEXA_LOG_LEVEL")) - def activate(self, workspace: str): """Set the current workspace in the settings file.""" if workspace not in self.workspaces: diff --git a/openhexa/graphql/__init__.py b/openhexa/graphql/__init__.py new file mode 100644 index 00000000..3678d524 --- /dev/null +++ b/openhexa/graphql/__init__.py @@ -0,0 +1 @@ +"""OpenHexa-SDK GraphQL client to communicate with OpenHexa API.""" diff --git a/openhexa/cli/graphql/graphql_client/__init__.py b/openhexa/graphql/graphql_client/__init__.py similarity index 100% rename from openhexa/cli/graphql/graphql_client/__init__.py rename to openhexa/graphql/graphql_client/__init__.py diff --git a/openhexa/cli/graphql/graphql_client/async_base_client.py b/openhexa/graphql/graphql_client/async_base_client.py similarity index 77% rename from openhexa/cli/graphql/graphql_client/async_base_client.py rename to openhexa/graphql/graphql_client/async_base_client.py index 311e672a..f3a25da4 100644 --- a/openhexa/cli/graphql/graphql_client/async_base_client.py +++ b/openhexa/graphql/graphql_client/async_base_client.py @@ -2,7 +2,8 @@ import enum import json -from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from collections.abc import AsyncIterator +from typing import IO, Any, Optional, TypeVar, cast from uuid import uuid4 import httpx @@ -20,6 +21,8 @@ try: from websockets.client import ( # type: ignore[import-not-found,unused-ignore] WebSocketClientProtocol, + ) + from websockets.client import ( connect as ws_connect, ) from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] @@ -63,18 +66,16 @@ class AsyncBaseClient: def __init__( self, url: str = "", - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, http_client: Optional[httpx.AsyncClient] = None, ws_url: str = "", - ws_headers: Optional[Dict[str, Any]] = None, + ws_headers: Optional[dict[str, Any]] = None, ws_origin: Optional[str] = None, - ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ws_connection_init_payload: Optional[dict[str, Any]] = None, ) -> None: self.url = url self.headers = headers - self.http_client = ( - http_client if http_client else httpx.AsyncClient(headers=headers) - ) + self.http_client = http_client if http_client else httpx.AsyncClient(headers=headers) self.ws_url = ws_url self.ws_headers = ws_headers or {} @@ -96,7 +97,7 @@ async def execute( self, query: str, operation_name: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, + variables: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -118,43 +119,37 @@ async def execute( **kwargs, ) - def get_data(self, response: httpx.Response) -> Dict[str, Any]: + def get_data(self, response: httpx.Response) -> dict[str, Any]: if not response.is_success: - raise GraphQLClientHttpError( - status_code=response.status_code, response=response - ) + raise GraphQLClientHttpError(status_code=response.status_code, response=response) try: response_json = response.json() except ValueError as exc: raise GraphQLClientInvalidResponseError(response=response) from exc - if (not isinstance(response_json, dict)) or ( - "data" not in response_json and "errors" not in response_json - ): + if (not isinstance(response_json, dict)) or ("data" not in response_json and "errors" not in response_json): raise GraphQLClientInvalidResponseError(response=response) data = response_json.get("data") errors = response_json.get("errors") if errors: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=errors, data=data - ) + raise GraphQLClientGraphQLMultiError.from_errors_dicts(errors_dicts=errors, data=data) - return cast(Dict[str, Any], data) + return cast(dict[str, Any], data) async def execute_ws( self, query: str, operation_name: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, + variables: Optional[dict[str, Any]] = None, **kwargs: Any, - ) -> AsyncIterator[Dict[str, Any]]: + ) -> AsyncIterator[dict[str, Any]]: headers = self.ws_headers.copy() headers.update(kwargs.get("extra_headers", {})) - merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs: dict[str, Any] = {"origin": self.ws_origin} merged_kwargs.update(kwargs) merged_kwargs["extra_headers"] = headers @@ -185,24 +180,16 @@ async def execute_ws( yield data def _process_variables( - self, variables: Optional[Dict[str, Any]] - ) -> Tuple[ - Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] - ]: + self, variables: Optional[dict[str, Any]] + ) -> tuple[dict[str, Any], dict[str, tuple[str, IO[bytes], str]], dict[str, list[str]]]: if not variables: return {}, {}, {} serializable_variables = self._convert_dict_to_json_serializable(variables) return self._get_files_from_variables(serializable_variables) - def _convert_dict_to_json_serializable( - self, dict_: Dict[str, Any] - ) -> Dict[str, Any]: - return { - key: self._convert_value(value) - for key, value in dict_.items() - if value is not UNSET - } + def _convert_dict_to_json_serializable(self, dict_: dict[str, Any]) -> dict[str, Any]: + return {key: self._convert_value(value) for key, value in dict_.items() if value is not UNSET} def _convert_value(self, value: Any) -> Any: if isinstance(value, BaseModel): @@ -212,12 +199,10 @@ def _convert_value(self, value: Any) -> Any: return value def _get_files_from_variables( - self, variables: Dict[str, Any] - ) -> Tuple[ - Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] - ]: - files_map: Dict[str, List[str]] = {} - files_list: List[Upload] = [] + self, variables: dict[str, Any] + ) -> tuple[dict[str, Any], dict[str, tuple[str, IO[bytes], str]], dict[str, list[str]]]: + files_map: dict[str, list[str]] = {} + files_list: list[Upload] = [] def separate_files(path: str, obj: Any) -> Any: if isinstance(obj, list): @@ -247,7 +232,7 @@ def separate_files(path: str, obj: Any) -> Any: return obj nulled_variables = separate_files("variables", variables) - files: Dict[str, Tuple[str, IO[bytes], str]] = { + files: dict[str, tuple[str, IO[bytes], str]] = { str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) for i, file_ in enumerate(files_list) } @@ -257,9 +242,9 @@ async def _execute_multipart( self, query: str, operation_name: Optional[str], - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], + variables: dict[str, Any], + files: dict[str, tuple[str, IO[bytes], str]], + files_map: dict[str, list[str]], **kwargs: Any, ) -> httpx.Response: data = { @@ -274,21 +259,19 @@ async def _execute_multipart( "map": json.dumps(files_map, default=to_jsonable_python), } - return await self.http_client.post( - url=self.url, data=data, files=files, **kwargs - ) + return await self.http_client.post(url=self.url, data=data, files=files, **kwargs) async def _execute_json( self, query: str, operation_name: Optional[str], - variables: Dict[str, Any], + variables: dict[str, Any], **kwargs: Any, ) -> httpx.Response: - headers: Dict[str, str] = {"Content-Type": "application/json"} + headers: dict[str, str] = {"Content-Type": "application/json"} headers.update(kwargs.get("headers", {})) - merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs: dict[str, Any] = kwargs.copy() merged_kwargs["headers"] = headers return await self.http_client.post( @@ -305,9 +288,7 @@ async def _execute_json( ) async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: - payload: Dict[str, Any] = { - "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value - } + payload: dict[str, Any] = {"type": GraphQLTransportWSMessageType.CONNECTION_INIT.value} if self.ws_connection_init_payload: payload["payload"] = self.ws_connection_init_payload await websocket.send(json.dumps(payload)) @@ -318,17 +299,15 @@ async def _send_subscribe( operation_id: str, query: str, operation_name: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, + variables: Optional[dict[str, Any]] = None, ) -> None: - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "id": operation_id, "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, "payload": {"query": query, "operationName": operation_name}, } if variables: - payload["payload"]["variables"] = self._convert_dict_to_json_serializable( - variables - ) + payload["payload"]["variables"] = self._convert_dict_to_json_serializable(variables) await websocket.send(json.dumps(payload)) async def _handle_ws_message( @@ -336,7 +315,7 @@ async def _handle_ws_message( message: Data, websocket: WebSocketClientProtocol, expected_type: Optional[GraphQLTransportWSMessageType] = None, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[dict[str, Any]]: try: message_dict = json.loads(message) except json.JSONDecodeError as exc: @@ -349,24 +328,18 @@ async def _handle_ws_message( raise GraphQLClientInvalidMessageFormat(message=message) if expected_type and expected_type != type_: - raise GraphQLClientInvalidMessageFormat( - f"Invalid message received. Expected: {expected_type.value}" - ) + raise GraphQLClientInvalidMessageFormat(f"Invalid message received. Expected: {expected_type.value}") if type_ == GraphQLTransportWSMessageType.NEXT: if "data" not in payload: raise GraphQLClientInvalidMessageFormat(message=message) - return cast(Dict[str, Any], payload["data"]) + return cast(dict[str, Any], payload["data"]) if type_ == GraphQLTransportWSMessageType.COMPLETE: await websocket.close() elif type_ == GraphQLTransportWSMessageType.PING: - await websocket.send( - json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) - ) + await websocket.send(json.dumps({"type": GraphQLTransportWSMessageType.PONG.value})) elif type_ == GraphQLTransportWSMessageType.ERROR: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=payload, data=message_dict - ) + raise GraphQLClientGraphQLMultiError.from_errors_dicts(errors_dicts=payload, data=message_dict) return None diff --git a/openhexa/cli/graphql/graphql_client/base_client.py b/openhexa/graphql/graphql_client/base_client.py similarity index 76% rename from openhexa/cli/graphql/graphql_client/base_client.py rename to openhexa/graphql/graphql_client/base_client.py index e7ffb904..07d8213a 100644 --- a/openhexa/cli/graphql/graphql_client/base_client.py +++ b/openhexa/graphql/graphql_client/base_client.py @@ -1,7 +1,7 @@ # Generated by ariadne-codegen import json -from typing import IO, Any, Dict, List, Optional, Tuple, TypeVar, cast +from typing import IO, Any, Optional, TypeVar, cast import httpx from pydantic import BaseModel @@ -21,7 +21,7 @@ class BaseClient: def __init__( self, url: str = "", - headers: Optional[Dict[str, str]] = None, + headers: Optional[dict[str, str]] = None, http_client: Optional[httpx.Client] = None, ) -> None: self.url = url @@ -44,7 +44,7 @@ def execute( self, query: str, operation_name: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, + variables: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> httpx.Response: processed_variables, files, files_map = self._process_variables(variables) @@ -66,51 +66,37 @@ def execute( **kwargs, ) - def get_data(self, response: httpx.Response) -> Dict[str, Any]: + def get_data(self, response: httpx.Response) -> dict[str, Any]: if not response.is_success: - raise GraphQLClientHttpError( - status_code=response.status_code, response=response - ) + raise GraphQLClientHttpError(status_code=response.status_code, response=response) try: response_json = response.json() except ValueError as exc: raise GraphQLClientInvalidResponseError(response=response) from exc - if (not isinstance(response_json, dict)) or ( - "data" not in response_json and "errors" not in response_json - ): + if (not isinstance(response_json, dict)) or ("data" not in response_json and "errors" not in response_json): raise GraphQLClientInvalidResponseError(response=response) data = response_json.get("data") errors = response_json.get("errors") if errors: - raise GraphQLClientGraphQLMultiError.from_errors_dicts( - errors_dicts=errors, data=data - ) + raise GraphQLClientGraphQLMultiError.from_errors_dicts(errors_dicts=errors, data=data) - return cast(Dict[str, Any], data) + return cast(dict[str, Any], data) def _process_variables( - self, variables: Optional[Dict[str, Any]] - ) -> Tuple[ - Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] - ]: + self, variables: Optional[dict[str, Any]] + ) -> tuple[dict[str, Any], dict[str, tuple[str, IO[bytes], str]], dict[str, list[str]]]: if not variables: return {}, {}, {} serializable_variables = self._convert_dict_to_json_serializable(variables) return self._get_files_from_variables(serializable_variables) - def _convert_dict_to_json_serializable( - self, dict_: Dict[str, Any] - ) -> Dict[str, Any]: - return { - key: self._convert_value(value) - for key, value in dict_.items() - if value is not UNSET - } + def _convert_dict_to_json_serializable(self, dict_: dict[str, Any]) -> dict[str, Any]: + return {key: self._convert_value(value) for key, value in dict_.items() if value is not UNSET} def _convert_value(self, value: Any) -> Any: if isinstance(value, BaseModel): @@ -120,12 +106,10 @@ def _convert_value(self, value: Any) -> Any: return value def _get_files_from_variables( - self, variables: Dict[str, Any] - ) -> Tuple[ - Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] - ]: - files_map: Dict[str, List[str]] = {} - files_list: List[Upload] = [] + self, variables: dict[str, Any] + ) -> tuple[dict[str, Any], dict[str, tuple[str, IO[bytes], str]], dict[str, list[str]]]: + files_map: dict[str, list[str]] = {} + files_list: list[Upload] = [] def separate_files(path: str, obj: Any) -> Any: if isinstance(obj, list): @@ -155,7 +139,7 @@ def separate_files(path: str, obj: Any) -> Any: return obj nulled_variables = separate_files("variables", variables) - files: Dict[str, Tuple[str, IO[bytes], str]] = { + files: dict[str, tuple[str, IO[bytes], str]] = { str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) for i, file_ in enumerate(files_list) } @@ -165,9 +149,9 @@ def _execute_multipart( self, query: str, operation_name: Optional[str], - variables: Dict[str, Any], - files: Dict[str, Tuple[str, IO[bytes], str]], - files_map: Dict[str, List[str]], + variables: dict[str, Any], + files: dict[str, tuple[str, IO[bytes], str]], + files_map: dict[str, list[str]], **kwargs: Any, ) -> httpx.Response: data = { @@ -188,13 +172,13 @@ def _execute_json( self, query: str, operation_name: Optional[str], - variables: Dict[str, Any], + variables: dict[str, Any], **kwargs: Any, ) -> httpx.Response: - headers: Dict[str, str] = {"Content-Type": "application/json"} + headers: dict[str, str] = {"Content-Type": "application/json"} headers.update(kwargs.get("headers", {})) - merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs: dict[str, Any] = kwargs.copy() merged_kwargs["headers"] = headers return self.http_client.post( diff --git a/openhexa/cli/graphql/graphql_client/base_model.py b/openhexa/graphql/graphql_client/base_model.py similarity index 86% rename from openhexa/cli/graphql/graphql_client/base_model.py rename to openhexa/graphql/graphql_client/base_model.py index 76b84873..a93b416e 100644 --- a/openhexa/cli/graphql/graphql_client/base_model.py +++ b/openhexa/graphql/graphql_client/base_model.py @@ -2,7 +2,8 @@ from io import IOBase -from pydantic import BaseModel as PydanticBaseModel, ConfigDict +from pydantic import BaseModel as PydanticBaseModel +from pydantic import ConfigDict class UnsetType: diff --git a/openhexa/cli/graphql/graphql_client/client.py b/openhexa/graphql/graphql_client/client.py similarity index 81% rename from openhexa/cli/graphql/graphql_client/client.py rename to openhexa/graphql/graphql_client/client.py index 1bb62ae1..74d4992a 100644 --- a/openhexa/cli/graphql/graphql_client/client.py +++ b/openhexa/graphql/graphql_client/client.py @@ -1,7 +1,7 @@ # Generated by ariadne-codegen # Source: openhexa/cli/graphql/queries.graphql -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from .base_client import BaseClient from .base_model import UNSET, UnsetType @@ -20,7 +20,7 @@ def get_workspace_pipelines( name: Union[Optional[str], UnsetType] = UNSET, page: Union[Optional[int], UnsetType] = UNSET, per_page: Union[Optional[int], UnsetType] = UNSET, - **kwargs: Any + **kwargs: Any, ) -> GetWorkspacePipelines: query = gql( """ @@ -47,18 +47,13 @@ def get_workspace_pipelines( } """ ) - variables: Dict[str, object] = { + variables: dict[str, object] = { "workspaceSlug": workspace_slug, "name": name, "page": page, "perPage": per_page, } - response = self.execute( - query=query, - operation_name="getWorkspacePipelines", - variables=variables, - **kwargs - ) + response = self.execute(query=query, operation_name="getWorkspacePipelines", variables=variables, **kwargs) data = self.get_data(response) return GetWorkspacePipelines.model_validate(data) @@ -77,9 +72,7 @@ def get_countries(self, workspace_slug: str, **kwargs: Any) -> GetCountries: } """ ) - variables: Dict[str, object] = {"workspaceSlug": workspace_slug} - response = self.execute( - query=query, operation_name="getCountries", variables=variables, **kwargs - ) + variables: dict[str, object] = {"workspaceSlug": workspace_slug} + response = self.execute(query=query, operation_name="getCountries", variables=variables, **kwargs) data = self.get_data(response) return GetCountries.model_validate(data) diff --git a/openhexa/cli/graphql/graphql_client/enums.py b/openhexa/graphql/graphql_client/enums.py similarity index 100% rename from openhexa/cli/graphql/graphql_client/enums.py rename to openhexa/graphql/graphql_client/enums.py diff --git a/openhexa/cli/graphql/graphql_client/exceptions.py b/openhexa/graphql/graphql_client/exceptions.py similarity index 78% rename from openhexa/cli/graphql/graphql_client/exceptions.py rename to openhexa/graphql/graphql_client/exceptions.py index 9fbe116d..fe178d51 100644 --- a/openhexa/cli/graphql/graphql_client/exceptions.py +++ b/openhexa/graphql/graphql_client/exceptions.py @@ -1,6 +1,6 @@ # Generated by ariadne-codegen -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import httpx @@ -30,10 +30,10 @@ class GraphQLClientGraphQLError(GraphQLClientError): def __init__( self, message: str, - locations: Optional[List[Dict[str, int]]] = None, - path: Optional[List[str]] = None, - extensions: Optional[Dict[str, object]] = None, - orginal: Optional[Dict[str, object]] = None, + locations: Optional[list[dict[str, int]]] = None, + path: Optional[list[str]] = None, + extensions: Optional[dict[str, object]] = None, + orginal: Optional[dict[str, object]] = None, ): self.message = message self.locations = locations @@ -45,7 +45,7 @@ def __str__(self) -> str: return self.message @classmethod - def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + def from_dict(cls, error: dict[str, Any]) -> "GraphQLClientGraphQLError": return cls( message=error["message"], locations=error.get("locations"), @@ -58,8 +58,8 @@ def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": class GraphQLClientGraphQLMultiError(GraphQLClientError): def __init__( self, - errors: List[GraphQLClientGraphQLError], - data: Optional[Dict[str, Any]] = None, + errors: list[GraphQLClientGraphQLError], + data: Optional[dict[str, Any]] = None, ): self.errors = errors self.data = data @@ -69,7 +69,7 @@ def __str__(self) -> str: @classmethod def from_errors_dicts( - cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + cls, errors_dicts: list[dict[str, Any]], data: Optional[dict[str, Any]] = None ) -> "GraphQLClientGraphQLMultiError": return cls( errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], diff --git a/openhexa/cli/graphql/graphql_client/get_countries.py b/openhexa/graphql/graphql_client/get_countries.py similarity index 84% rename from openhexa/cli/graphql/graphql_client/get_countries.py rename to openhexa/graphql/graphql_client/get_countries.py index 2292c773..d7270c2a 100644 --- a/openhexa/cli/graphql/graphql_client/get_countries.py +++ b/openhexa/graphql/graphql_client/get_countries.py @@ -1,7 +1,7 @@ # Generated by ariadne-codegen # Source: openhexa/cli/graphql/queries.graphql -from typing import List, Optional +from typing import Optional from pydantic import Field @@ -13,7 +13,7 @@ class GetCountries(BaseModel): class GetCountriesWorkspace(BaseModel): - countries: List["GetCountriesWorkspaceCountries"] + countries: list["GetCountriesWorkspaceCountries"] class GetCountriesWorkspaceCountries(BaseModel): diff --git a/openhexa/cli/graphql/graphql_client/get_workspace.py b/openhexa/graphql/graphql_client/get_workspace.py similarity index 100% rename from openhexa/cli/graphql/graphql_client/get_workspace.py rename to openhexa/graphql/graphql_client/get_workspace.py diff --git a/openhexa/cli/graphql/graphql_client/get_workspace_pipelines.py b/openhexa/graphql/graphql_client/get_workspace_pipelines.py similarity index 84% rename from openhexa/cli/graphql/graphql_client/get_workspace_pipelines.py rename to openhexa/graphql/graphql_client/get_workspace_pipelines.py index e9976b8b..a24a48e1 100644 --- a/openhexa/cli/graphql/graphql_client/get_workspace_pipelines.py +++ b/openhexa/graphql/graphql_client/get_workspace_pipelines.py @@ -1,7 +1,7 @@ # Generated by ariadne-codegen # Source: openhexa/cli/graphql/queries.graphql -from typing import Any, List, Optional +from typing import Any, Optional from pydantic import Field @@ -15,7 +15,7 @@ class GetWorkspacePipelines(BaseModel): class GetWorkspacePipelinesPipelines(BaseModel): total_pages: int = Field(alias="totalPages") - items: List["GetWorkspacePipelinesPipelinesItems"] + items: list["GetWorkspacePipelinesPipelinesItems"] class GetWorkspacePipelinesPipelinesItems(BaseModel): @@ -23,9 +23,7 @@ class GetWorkspacePipelinesPipelinesItems(BaseModel): code: str name: Optional[str] type: PipelineType - current_version: Optional["GetWorkspacePipelinesPipelinesItemsCurrentVersion"] = ( - Field(alias="currentVersion") - ) + current_version: Optional["GetWorkspacePipelinesPipelinesItemsCurrentVersion"] = Field(alias="currentVersion") class GetWorkspacePipelinesPipelinesItemsCurrentVersion(BaseModel): diff --git a/openhexa/cli/graphql/graphql_client/input_types.py b/openhexa/graphql/graphql_client/input_types.py similarity index 95% rename from openhexa/cli/graphql/graphql_client/input_types.py rename to openhexa/graphql/graphql_client/input_types.py index f270fb10..a31bf826 100644 --- a/openhexa/cli/graphql/graphql_client/input_types.py +++ b/openhexa/graphql/graphql_client/input_types.py @@ -1,7 +1,7 @@ # Generated by ariadne-codegen # Source: openhexa/cli/graphql/schema.generated.graphql -from typing import Any, List, Optional +from typing import Any, Optional from pydantic import Field @@ -72,7 +72,7 @@ class CreateAccessmodProjectInput(BaseModel): country: "CountryInput" crs: int description: Optional[str] = None - extent: Optional[List[List[float]]] = None + extent: Optional[list[list[float]]] = None name: str spatial_resolution: int = Field(alias="spatialResolution") @@ -96,7 +96,7 @@ class CreateBucketFolderInput(BaseModel): class CreateConnectionInput(BaseModel): description: Optional[str] = None - fields: Optional[List["ConnectionFieldInput"]] = None + fields: Optional[list["ConnectionFieldInput"]] = None name: str slug: Optional[str] = None type: ConnectionType @@ -169,7 +169,7 @@ class CreateWebappInput(BaseModel): class CreateWorkspaceInput(BaseModel): - countries: Optional[List["CountryInput"]] = None + countries: Optional[list["CountryInput"]] = None description: Optional[str] = None load_sample_data: Optional[bool] = Field(alias="loadSampleData", default=None) name: str @@ -344,7 +344,7 @@ class OrganizationInput(BaseModel): class ParameterInput(BaseModel): - choices: Optional[List[Any]] = None + choices: Optional[list[Any]] = None code: str connection: Optional[str] = None default: Optional[Any] = None @@ -434,9 +434,7 @@ class RunPipelineInput(BaseModel): config: Any enable_debug_logs: Optional[bool] = Field(alias="enableDebugLogs", default=None) id: Any - send_mail_notifications: Optional[bool] = Field( - alias="sendMailNotifications", default=None - ) + send_mail_notifications: Optional[bool] = Field(alias="sendMailNotifications", default=None) version_id: Optional[Any] = Field(alias="versionId", default=None) @@ -468,9 +466,7 @@ class UpdateAccessmodAccessibilityAnalysisInput(BaseModel): algorithm: Optional[AccessmodAccessibilityAnalysisAlgorithm] = None barrier_id: Optional[str] = Field(alias="barrierId", default=None) dem_id: Optional[str] = Field(alias="demId", default=None) - health_facilities_id: Optional[str] = Field( - alias="healthFacilitiesId", default=None - ) + health_facilities_id: Optional[str] = Field(alias="healthFacilitiesId", default=None) id: str invert_direction: Optional[bool] = Field(alias="invertDirection", default=None) knight_move: Optional[bool] = Field(alias="knightMove", default=None) @@ -480,9 +476,7 @@ class UpdateAccessmodAccessibilityAnalysisInput(BaseModel): name: Optional[str] = None stack_id: Optional[str] = Field(alias="stackId", default=None) stack_priorities: Optional[Any] = Field(alias="stackPriorities", default=None) - transport_network_id: Optional[str] = Field( - alias="transportNetworkId", default=None - ) + transport_network_id: Optional[str] = Field(alias="transportNetworkId", default=None) water_all_touched: Optional[bool] = Field(alias="waterAllTouched", default=None) water_id: Optional[str] = Field(alias="waterId", default=None) @@ -515,14 +509,14 @@ class UpdateAccessmodZonalStatisticsInput(BaseModel): class UpdateConnectionInput(BaseModel): description: Optional[str] = None - fields: Optional[List["ConnectionFieldInput"]] = None + fields: Optional[list["ConnectionFieldInput"]] = None id: str name: Optional[str] = None slug: Optional[str] = None class UpdateDAGInput(BaseModel): - countries: Optional[List["CountryInput"]] = None + countries: Optional[list["CountryInput"]] = None description: Optional[str] = None id: Any label: Optional[str] = None @@ -604,7 +598,7 @@ class UpdateWebappInput(BaseModel): class UpdateWorkspaceInput(BaseModel): - countries: Optional[List["CountryInput"]] = None + countries: Optional[list["CountryInput"]] = None description: Optional[str] = None docker_image: Optional[str] = Field(alias="dockerImage", default=None) name: Optional[str] = None @@ -626,7 +620,7 @@ class UploadPipelineInput(BaseModel): description: Optional[str] = None external_link: Optional[Any] = Field(alias="externalLink", default=None) name: Optional[str] = None - parameters: List["ParameterInput"] + parameters: list["ParameterInput"] pipeline_code: Optional[str] = Field(alias="pipelineCode", default=None) timeout: Optional[int] = None workspace_slug: str = Field(alias="workspaceSlug") diff --git a/openhexa/cli/graphql/graphql_client/me.py b/openhexa/graphql/graphql_client/me.py similarity index 100% rename from openhexa/cli/graphql/graphql_client/me.py rename to openhexa/graphql/graphql_client/me.py diff --git a/openhexa/graphql/openhexa_client.py b/openhexa/graphql/openhexa_client.py new file mode 100644 index 00000000..c0fdf5ad --- /dev/null +++ b/openhexa/graphql/openhexa_client.py @@ -0,0 +1,178 @@ +"""OpenHexaClient is a class that provides methods to interact with the OpenHexa GraphQL API.""" +import logging +from datetime import datetime +from pathlib import Path + +import click +import requests +from graphql import build_client_schema, build_schema, get_introspection_query +from graphql.utilities import find_breaking_changes + +from openhexa.cli.settings import settings +from openhexa.graphql.graphql_client import Client +from openhexa.utils import create_requests_session + + +class APIError(Exception): + """Raised when an error occurs while interacting with the API.""" + + pass + + +class InvalidTokenError(APIError): + """Raised when the token is invalid.""" + + pass + + +class GraphQLError(APIError): + """Raised when a GraphQL request returns an error.""" + + pass + + +class OpenHexaClient(Client): + """OpenHexaClient is a class that provides methods to interact with the OpenHexa GraphQL API.""" + + def __init__(self, *, api_url: str, token: str): + """Initialize the OpenHexaClient with the OpenHexa API URL and headers.""" + self._url = settings.api_url + "/graphql/" + self._token = token or settings.access_token + + if not self._token: + raise InvalidTokenError("No token found for workspace") + + super().__init__(url=self._url, headers={}) + logging.getLogger("httpx").setLevel( + logging.WARNING + ) # HTTPX logs queries by default, we disable them here with WARNING level + + def execute(self, query, **kwargs): + """Decorate parent execute method to log the GraphQL query and response.""" + from openhexa.version import __version__ + + self.headers["User-Agent"] = f"openhexa-cli/{__version__}" + self.headers["Authorization"] = f"Bearer {self._token}" + + _detect_graphql_breaking_changes(token=self._token) + + if settings.debug: + click.echo("") + click.echo("Graphql Query:") + click.echo(f"URL: {self.url}") + click.echo(f"Query: {query}") + variables = kwargs.get("variables", {}) + click.echo(f"Variables: {variables}") + + response = super().execute(query=query, **kwargs) + + if settings.debug: + click.echo("") + click.echo("Graphql Response:") + click.echo(f"Response: {response}") + + return response + + +def _detect_graphql_breaking_changes_if_needed(token): + """Detect breaking changes if not done recently between the schema referenced in the SDK and the server using graphql-core.""" + ONE_HOUR = 60 * 60 + now_timestamp = int(datetime.now().timestamp()) + if not settings.last_breaking_change_check or now_timestamp - settings.last_breaking_change_check > ONE_HOUR: + _detect_graphql_breaking_changes(token) + settings.last_breaking_change_check = now_timestamp + + +def _detect_graphql_breaking_changes(token): + """Detect breaking changes between the schema referenced in the SDK and the server using graphql-core.""" + stored_schema_obj = build_schema((Path(__file__).parent / "graphql" / "schema.generated.graphql").open().read()) + server_schema_obj = build_client_schema( + _query_graphql(get_introspection_query(input_value_deprecation=True), token=token) + ) + + breaking_changes = find_breaking_changes(stored_schema_obj, server_schema_obj) + if breaking_changes: + current_version, latest_version = get_library_versions() + click.secho( + f"⚠️ Breaking changes detected between the SDK (version {current_version}) and the server:", + fg="red", + ) + for change in breaking_changes: + click.secho(f"- {change.description}", fg="yellow") + click.secho( + "This could lead to unexpected results.\n" + f"Please update the SDK to the latest version {latest_version} " + f"(using `pip install openhexa-sdk=={latest_version}`) or use a version of the SDK compatible with the server.", + fg="red", + ) + + +def graphql(query: str, variables=None, token=None): + """Check that there is no breaking change and perform a GraphQL request.""" + _detect_graphql_breaking_changes_if_needed(token) + return _query_graphql(query, variables, token) + + +def _query_graphql(query: str, variables=None, token=None): + """Perform a GraphQL request.""" + from openhexa.version import __version__ + + url = settings.api_url + "/graphql/" + if token is None: + token = settings.access_token + + if token is None: + raise InvalidTokenError("No token found for workspace") + + if settings.debug: + click.echo("") + click.echo("Graphql Query:") + click.echo(f"URL: {url}") + click.echo(f"Query: {query}") + click.echo(f"Variables: {variables}") + + session = create_requests_session() + + response = session.post( + url, + headers={ + "User-Agent": f"openhexa-cli/{__version__}", + "Authorization": f"Bearer {token}", + }, + json={"query": query, "variables": variables}, + ) + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + raise GraphQLError(str(e)) + + data = response.json() + + if settings.debug: + click.echo("Graphql Response:") + click.echo(data) + click.echo("") + + if data.get("errors"): + if data.get("errors")[0].get("extensions", {}).get("code") == "UNAUTHENTICATED": + raise InvalidTokenError + raise GraphQLError(data["errors"]) + return data["data"] + + +def get_library_versions() -> tuple[str, str]: + """Return the current version and the one on PyPi.""" + # Get the currently installed version + from openhexa.version import __version__ as installed_version + + # Get the latest version available on PyPI + try: + response = requests.get("https://pypi.org/pypi/openhexa.sdk/json") + latest_version = response.json()["info"]["version"] + return installed_version, latest_version + except requests.RequestException: + logging.error( + "Could not check for the latest version of the openhexa.sdk package.", + exc_info=True, + ) + return installed_version, installed_version diff --git a/openhexa/cli/graphql/queries.graphql b/openhexa/graphql/queries.graphql similarity index 100% rename from openhexa/cli/graphql/queries.graphql rename to openhexa/graphql/queries.graphql diff --git a/openhexa/cli/graphql/schema.generated.graphql b/openhexa/graphql/schema.generated.graphql similarity index 100% rename from openhexa/cli/graphql/schema.generated.graphql rename to openhexa/graphql/schema.generated.graphql diff --git a/openhexa/sdk/workspaces/current_workspace.py b/openhexa/sdk/workspaces/current_workspace.py index 189edbab..3ea4bc09 100644 --- a/openhexa/sdk/workspaces/current_workspace.py +++ b/openhexa/sdk/workspaces/current_workspace.py @@ -7,9 +7,10 @@ from dataclasses import fields, make_dataclass from warnings import warn +from openhexa.graphql.graphql_client import GetCountriesWorkspaceCountries +from openhexa.graphql.openhexa_client import OpenHexaClient from openhexa.utils import stringcase -from ...cli.graphql.graphql_client import GetCountriesWorkspaceCountries from ..datasets import Dataset from ..utils import graphql from .connection import ( @@ -63,8 +64,6 @@ def slug(self) -> str: @property def countries(self) -> list[GetCountriesWorkspaceCountries]: """The countries of the workspace.""" - from openhexa.cli.api import OpenHexaClient - try: return OpenHexaClient().get_countries(workspace_slug=self.slug).workspace.countries except KeyError: diff --git a/openhexa/version.py b/openhexa/version.py new file mode 100644 index 00000000..5c8dc85a --- /dev/null +++ b/openhexa/version.py @@ -0,0 +1,7 @@ +"""Openhexa SDK version module.""" +from importlib.metadata import version + +try: + __version__ = version("openhexa.sdk") +except Exception: + __version__ = "unknown" diff --git a/pyproject.toml b/pyproject.toml index 05db4259..9193529f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ include-package-data = true [tool.ruff] line-length = 120 ignore = ["E501"] -exclude = ["openhexa/cli/graphql/graphql_client"] +exclude = ["openhexa/graphql/graphql_client"] per-file-ignores = { "tests/**/test_*.py" = ["D100","D101","D102", "D103"] } # Ignore missing docstrings in tests [tool.ruff.lint] diff --git a/tests/test_detect_breaking_changes.py b/tests/test_detect_breaking_changes.py index 4805b01c..6db8b824 100644 --- a/tests/test_detect_breaking_changes.py +++ b/tests/test_detect_breaking_changes.py @@ -1,12 +1,12 @@ import time from unittest import TestCase, mock -from openhexa.cli.api import _detect_graphql_breaking_changes, graphql +from openhexa.graphql.openhexa_client import _detect_graphql_breaking_changes, graphql class TestGraphQLFunctions(TestCase): - @mock.patch("openhexa.cli.api._query_graphql") - @mock.patch("openhexa.cli.api.get_library_versions") + @mock.patch("openhexa.graphql.openhexa_client._query_graphql") + @mock.patch("openhexa.graphql.openhexa_client.get_library_versions") def test_detect_graphql_breaking_changes_with_mocked_server_schema( self, mock_get_library_versions, mock_query_graphql ): @@ -53,11 +53,11 @@ def test_detect_graphql_breaking_changes_with_mocked_server_schema( ) mock_click_secho.assert_any_call("- Query.testField changed type from Int to String.", fg="yellow") - @mock.patch("openhexa.cli.api._query_graphql") - @mock.patch("openhexa.cli.api._detect_graphql_breaking_changes") + @mock.patch("openhexa.graphql.openhexa_client._query_graphql") + @mock.patch("openhexa.graphql.openhexa_client._detect_graphql_breaking_changes") def test_graphql(self, mock_detect_graphql_breaking_changes, mock_query_graphql): """Test that the graphql function is caching the breaking change detection for 1 hour.""" - with mock.patch("openhexa.cli.api.settings") as mock_settings: + with mock.patch("openhexa.graphql.openhexa_client.settings") as mock_settings: mock_settings.last_breaking_change_check = time.time() - 59 * 60 # Last checked 59 minutes ago mock_query_graphql.return_value = {"data": "response"}