diff --git a/.gitignore b/.gitignore index 258a8e94..5c6986ab 100644 --- a/.gitignore +++ b/.gitignore @@ -14,11 +14,14 @@ doc/_build doc/generated doc/api .vscode +.idea Highs.log paper/ monkeytype.sqlite3 .github/copilot-instructions.md uv.lock +*.pyc + # Environments .env diff --git a/doc/release_notes.rst b/doc/release_notes.rst index d3721743..99225933 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -10,6 +10,7 @@ Version 0.5.6 * Improved variable/expression arithmetic methods so that they correctly handle types * Gurobi: Pass dictionary as env argument `env={...}` through to gurobi env creation +* Added integration with OETC platform * Mosek: Remove explicit use of Env, use global env instead * Objectives can now be created from variables via `linopy.Model.add_objective`. diff --git a/linopy/__init__.py b/linopy/__init__.py index 88ee5251..3efc297a 100644 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -20,7 +20,7 @@ from linopy.io import read_netcdf from linopy.model import Model, Variable, Variables, available_solvers from linopy.objective import Objective -from linopy.remote import RemoteHandler +from linopy.remote import OetcHandler, RemoteHandler __all__ = ( "Constraint", @@ -31,6 +31,7 @@ "LinearExpression", "Model", "Objective", + "OetcHandler", "QuadraticExpression", "RemoteHandler", "Variable", diff --git a/linopy/model.py b/linopy/model.py index 0255d00e..ef0d2ff9 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -58,6 +58,7 @@ ) from linopy.matrices import MatrixAccessor from linopy.objective import Objective +from linopy.remote import OetcHandler, RemoteHandler from linopy.solvers import IO_APIS, available_solvers, quadratic_solvers from linopy.types import ( ConstantLike, @@ -1050,7 +1051,7 @@ def solve( sanitize_zeros: bool = True, sanitize_infinities: bool = True, slice_size: int = 2_000_000, - remote: Any = None, + remote: RemoteHandler | OetcHandler = None, # type: ignore progress: bool | None = None, **solver_options: Any, ) -> tuple[str, str]: @@ -1111,7 +1112,7 @@ def solve( Size of the slice to use for writing the lp file. The slice size is used to split large variables and constraints into smaller chunks to avoid memory issues. The default is 2_000_000. - remote : linopy.remote.RemoteHandler + remote : linopy.remote.RemoteHandler | linopy.oetc.OetcHandler, optional Remote handler to use for solving model on a server. Note that when solving on a rSee linopy.remote.RemoteHandler for more details. @@ -1137,20 +1138,23 @@ def solve( f"Keyword argument `io_api` has to be one of {IO_APIS} or None" ) - if remote: - solved = remote.solve_on_remote( - self, - solver_name=solver_name, - io_api=io_api, - problem_fn=problem_fn, - solution_fn=solution_fn, - log_fn=log_fn, - basis_fn=basis_fn, - warmstart_fn=warmstart_fn, - keep_files=keep_files, - sanitize_zeros=sanitize_zeros, - **solver_options, - ) + if remote is not None: + if isinstance(remote, OetcHandler): + solved = remote.solve_on_oetc(self) + else: + solved = remote.solve_on_remote( + self, + solver_name=solver_name, + io_api=io_api, + problem_fn=problem_fn, + solution_fn=solution_fn, + log_fn=log_fn, + basis_fn=basis_fn, + warmstart_fn=warmstart_fn, + keep_files=keep_files, + sanitize_zeros=sanitize_zeros, + **solver_options, + ) self.objective.set_value(solved.objective.value) self.status = solved.status diff --git a/linopy/remote/__init__.py b/linopy/remote/__init__.py new file mode 100644 index 00000000..0ae1df26 --- /dev/null +++ b/linopy/remote/__init__.py @@ -0,0 +1,19 @@ +""" +Remote execution handlers for linopy models. + +This module provides different handlers for executing optimization models +on remote systems: + +- RemoteHandler: SSH-based remote execution using paramiko +- OetcHandler: Cloud-based execution via OET Cloud service +""" + +from linopy.remote.oetc import OetcCredentials, OetcHandler, OetcSettings +from linopy.remote.ssh import RemoteHandler + +__all__ = [ + "RemoteHandler", + "OetcHandler", + "OetcSettings", + "OetcCredentials", +] diff --git a/linopy/remote/oetc.py b/linopy/remote/oetc.py new file mode 100644 index 00000000..7f1c7b22 --- /dev/null +++ b/linopy/remote/oetc.py @@ -0,0 +1,616 @@ +import base64 +import gzip +import json +import logging +import os +import tempfile +import time +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from enum import Enum + +import requests +from google.cloud import storage +from google.oauth2 import service_account +from requests import RequestException + +import linopy + +logger = logging.getLogger(__name__) + + +class ComputeProvider(str, Enum): + GCP = "GCP" + + +@dataclass +class OetcCredentials: + email: str + password: str + + +@dataclass +class OetcSettings: + credentials: OetcCredentials + name: str + authentication_server_url: str + orchestrator_server_url: str + compute_provider: ComputeProvider = ComputeProvider.GCP + solver: str = "highs" + solver_options: dict = field(default_factory=dict) + cpu_cores: int = 2 + disk_space_gb: int = 10 + delete_worker_on_error: bool = False + + +@dataclass +class GcpCredentials: + gcp_project_id: str + gcp_service_key: str + input_bucket: str + solution_bucket: str + + +@dataclass +class AuthenticationResult: + token: str + token_type: str + expires_in: int # value represented in seconds + authenticated_at: datetime + + @property + def expires_at(self) -> datetime: + """Calculate when the token expires""" + return self.authenticated_at + timedelta(seconds=self.expires_in) + + @property + def is_expired(self) -> bool: + """Check if the token has expired""" + return datetime.now() >= self.expires_at + + +@dataclass +class JobResult: + uuid: str + status: str + name: str | None = None + owner: str | None = None + solver: str | None = None + duration_in_seconds: int | None = None + solving_duration_in_seconds: int | None = None + input_files: list | None = None + output_files: list | None = None + created_at: str | None = None + + +class OetcHandler: + def __init__(self, settings: OetcSettings) -> None: + self.settings = settings + self.jwt = self.__sign_in() + self.cloud_provider_credentials = self.__get_cloud_provider_credentials() + + def __sign_in(self) -> AuthenticationResult: + """ + Authenticate with the server and return the authentication result. + + Returns: + AuthenticationResult: The complete authentication result including token and expiration info + + Raises: + Exception: If authentication fails or response is invalid + """ + try: + logger.info("OETC - Signing in...") + payload = { + "email": self.settings.credentials.email, + "password": self.settings.credentials.password, + } + + response = requests.post( + f"{self.settings.authentication_server_url}/sign-in", + json=payload, + headers={"Content-Type": "application/json"}, + timeout=30, + ) + + response.raise_for_status() + jwt_result = response.json() + + logger.info("OETC - Signed in") + + return AuthenticationResult( + token=jwt_result["token"], + token_type=jwt_result["token_type"], + expires_in=jwt_result["expires_in"], + authenticated_at=datetime.now(), + ) + + except RequestException as e: + raise Exception(f"Authentication request failed: {e}") + except KeyError as e: + raise Exception(f"Invalid response format: missing field {e}") + except Exception as e: + raise Exception(f"Authentication error: {e}") + + def _decode_jwt_payload(self, token: str) -> dict: + """ + Decode JWT payload without verification to extract user information. + + Args: + token: The JWT token + + Returns: + dict: The decoded payload containing user information + + Raises: + Exception: If token cannot be decoded + """ + try: + payload_part = token.split(".")[1] + payload_part += "=" * (4 - len(payload_part) % 4) + payload_bytes = base64.urlsafe_b64decode(payload_part) + return json.loads(payload_bytes.decode("utf-8")) + except (IndexError, json.JSONDecodeError, Exception) as e: + raise Exception(f"Failed to decode JWT payload: {e}") + + def __get_cloud_provider_credentials(self) -> GcpCredentials: + """ + Fetch cloud provider credentials based on the configured provider. + + Returns: + GcpCredentials: The cloud provider credentials + + Raises: + Exception: If the compute provider is not supported + """ + if self.settings.compute_provider == ComputeProvider.GCP: + return self.__get_gcp_credentials() + else: + raise Exception( + f"Unsupported compute provider: {self.settings.compute_provider}" + ) + + def __get_gcp_credentials(self) -> GcpCredentials: + """ + Fetch GCP credentials for the authenticated user. + + Returns: + GcpCredentials: The GCP credentials including project ID, service key, and bucket information + + Raises: + Exception: If credentials fetching fails or response is invalid + """ + try: + logger.info("OETC - Fetching user GCP credentials...") + payload = self._decode_jwt_payload(self.jwt.token) + user_uuid = payload.get("sub") + + if not user_uuid: + raise Exception("User UUID not found in JWT token") + + response = requests.get( + f"{self.settings.authentication_server_url}/users/{user_uuid}/gcp-credentials", + headers={ + "Authorization": f"{self.jwt.token_type} {self.jwt.token}", + "Content-Type": "application/json", + }, + timeout=30, + ) + + response.raise_for_status() + credentials_data = response.json() + + logger.info("OETC - Fetched user GCP credentials") + + return GcpCredentials( + gcp_project_id=credentials_data["gcp_project_id"], + gcp_service_key=credentials_data["gcp_service_key"], + input_bucket=credentials_data["input_bucket"], + solution_bucket=credentials_data["solution_bucket"], + ) + + except RequestException as e: + raise Exception(f"Failed to fetch GCP credentials: {e}") + except KeyError as e: + raise Exception(f"Invalid credentials response format: missing field {e}") + except Exception as e: + raise Exception(f"Error fetching GCP credentials: {e}") + + def _submit_job_to_compute_service(self, input_file_name: str) -> str: + """ + Submit a job to the compute service. + + Args: + input_file_name: Name of the input file uploaded to GCP + + Returns: + CreateComputeJobResult: The job creation result with UUID + + Raises: + Exception: If job submission fails + """ + try: + logger.info("OETC - Submitting compute job...") + payload = { + "name": self.settings.name, + "solver": self.settings.solver, + "solver_options": self.settings.solver_options, + "provider": self.settings.compute_provider.value, + "cpu_cores": self.settings.cpu_cores, + "disk_space_gb": self.settings.disk_space_gb, + "input_file_name": input_file_name, + "delete_worker_on_error": self.settings.delete_worker_on_error, + } + + response = requests.post( + f"{self.settings.orchestrator_server_url}/compute-job/create", + json=payload, + headers={ + "Authorization": f"{self.jwt.token_type} {self.jwt.token}", + "Content-Type": "application/json", + }, + timeout=30, + ) + + response.raise_for_status() + job_result = response.json() + + logger.info(f"OETC - Compute job {job_result['uuid']} started") + + return job_result["uuid"] + + except RequestException as e: + raise Exception(f"Failed to submit job to compute service: {e}") + except KeyError as e: + raise Exception( + f"Invalid job submission response format: missing field {e}" + ) + except Exception as e: + raise Exception(f"Error submitting job to compute service: {e}") + + def wait_and_get_job_data( + self, + job_uuid: str, + initial_poll_interval: int = 30, + max_poll_interval: int = 300, + ) -> JobResult: + """ + Wait for job completion and get job data by polling the orchestrator service. + + This method will poll indefinitely until the job finishes (FINISHED) or + encounters an error (SETUP_ERROR, RUNTIME_ERROR). + + Args: + job_uuid: UUID of the job to wait for + initial_poll_interval: Initial polling interval in seconds (default: 30) + max_poll_interval: Maximum polling interval in seconds (default: 300) + + Returns: + JobResult: The job result when complete + + Raises: + Exception: If job encounters errors or network requests consistently fail + """ + poll_interval = initial_poll_interval + consecutive_failures = 0 + max_network_retries = 10 + + logger.info(f"OETC - Waiting for job {job_uuid} to complete...") + + while True: + try: + response = requests.get( + f"{self.settings.orchestrator_server_url}/compute-job/{job_uuid}", + headers={ + "Authorization": f"{self.jwt.token_type} {self.jwt.token}", + "Content-Type": "application/json", + }, + timeout=30, + ) + + response.raise_for_status() + job_data_dict = response.json() + + job_result = JobResult( + uuid=job_data_dict["uuid"], + status=job_data_dict["status"], + name=job_data_dict.get("name"), + owner=job_data_dict.get("owner"), + solver=job_data_dict.get("solver"), + duration_in_seconds=job_data_dict.get("duration_in_seconds"), + solving_duration_in_seconds=job_data_dict.get( + "solving_duration_in_seconds" + ), + input_files=job_data_dict.get("input_files", []), + output_files=job_data_dict.get("output_files", []), + created_at=job_data_dict.get("created_at"), + ) + + consecutive_failures = 0 + + if job_result.status == "FINISHED": + logger.info(f"OETC - Job {job_uuid} completed successfully!") + if not job_result.output_files: + logger.warning( + "OETC - Warning: Job completed but no output files found" + ) + return job_result + + elif job_result.status == "SETUP_ERROR": + error_msg = f"Job failed during setup phase (status: {job_result.status}). Please check the OETC logs for details." + logger.error(f"OETC Error: {error_msg}") + raise Exception(error_msg) + + elif job_result.status == "RUNTIME_ERROR": + error_msg = f"Job failed during execution (status: {job_result.status}). Please check the OETC logs for details." + logger.error(f"OETC Error: {error_msg}") + raise Exception(error_msg) + + elif job_result.status in ["PENDING", "STARTING", "RUNNING"]: + status_msg = f"Job {job_uuid} status: {job_result.status}" + if job_result.duration_in_seconds: + status_msg += ( + f" (running for {job_result.duration_in_seconds}s)" + ) + status_msg += f", checking again in {poll_interval} seconds..." + logger.info(f"OETC - {status_msg}") + + time.sleep(poll_interval) + + # Exponential backoff for polling interval, capped at max_poll_interval + poll_interval = min(int(poll_interval * 1.5), max_poll_interval) + + else: + # Unknown status + error_msg = f"Unknown job status: {job_result.status}. Please check the OETC logs for details." + logger.error(f"OETC Error: {error_msg}") + raise Exception(error_msg) + + except RequestException as e: + consecutive_failures += 1 + + if consecutive_failures >= max_network_retries: + raise Exception( + f"Failed to get job status after {max_network_retries} network retries: {e}" + ) + + # Wait before retrying network request + retry_wait = min(consecutive_failures * 10, 60) + logger.error( + f"OETC - Network error getting job status (attempt {consecutive_failures}/{max_network_retries}), " + f"retrying in {retry_wait} seconds: {e}" + ) + time.sleep(retry_wait) + + except KeyError as e: + raise Exception( + f"Invalid job status response format: missing field {e}" + ) + except Exception as e: + if "status:" in str(e) or "OETC logs" in str(e): + raise + else: + raise Exception(f"Error getting job status: {e}") + + def _gzip_decompress(self, input_path: str) -> str: + """ + Decompress a gzip-compressed file. + + Args: + input_path: Path to the compressed file + + Returns: + str: Path to the decompressed file + + Raises: + Exception: If decompression fails + """ + try: + logger.info(f"OETC - Decompressing file: {input_path}") + output_path = input_path[:-3] + chunk_size = 1024 * 1024 + + with gzip.open(input_path, "rb") as f_in: + with open(output_path, "wb") as f_out: + while True: + chunk = f_in.read(chunk_size) + if not chunk: + break + f_out.write(chunk) + + logger.info(f"OETC - File decompressed successfully: {output_path}") + return output_path + except Exception as e: + raise Exception(f"Failed to decompress file: {e}") + + def _download_file_from_gcp(self, file_name: str) -> str: + """ + Download a file from GCP storage bucket. + + Args: + file_name: Name of the file to download from the solution bucket + + Returns: + str: Path to the downloaded and decompressed file + + Raises: + Exception: If download or decompression fails + """ + try: + logger.info(f"OETC - Downloading file from GCP: {file_name}") + + # Create GCP credentials from service key + service_key_dict = json.loads( + self.cloud_provider_credentials.gcp_service_key + ) + credentials = service_account.Credentials.from_service_account_info( + service_key_dict, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + # Download from GCP solution bucket + storage_client = storage.Client( + credentials=credentials, + project=self.cloud_provider_credentials.gcp_project_id, + ) + bucket = storage_client.bucket( + self.cloud_provider_credentials.solution_bucket + ) + blob = bucket.blob(file_name) + + # Create temporary file for download + with tempfile.NamedTemporaryFile(delete=False, suffix=".gz") as temp_file: + compressed_file_path = temp_file.name + + blob.download_to_filename(compressed_file_path) + logger.info(f"OETC - File downloaded from GCP successfully: {file_name}") + + # Decompress the downloaded file + decompressed_file_path = self._gzip_decompress(compressed_file_path) + + # Clean up compressed file + os.remove(compressed_file_path) + + return decompressed_file_path + + except Exception as e: + raise Exception(f"Failed to download file from GCP: {e}") + + def solve_on_oetc(self, model): # type: ignore + """ + Solve a linopy model on the OET Cloud compute app. + + Parameters + ---------- + model : linopy.model.Model + + Returns + ------- + linopy.model.Model + Solved model. + + Raises + ------ + Exception: If solving fails at any stage + """ + try: + # Save model to temporary file and upload + with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: + model.to_netcdf(fn.name) + input_file_name = self._upload_file_to_gcp(fn.name) + + # Submit job and wait for completion + job_uuid = self._submit_job_to_compute_service(input_file_name) + job_result = self.wait_and_get_job_data(job_uuid) + + # Download and load the solution + if not job_result.output_files: + raise Exception("No output files found in completed job") + + output_file_name = job_result.output_files[0] + if isinstance(output_file_name, dict) and "name" in output_file_name: + output_file_name = output_file_name["name"] + + solution_file_path = self._download_file_from_gcp(output_file_name) + + # Load the solved model + solved_model = linopy.read_netcdf(solution_file_path) + + # Clean up downloaded file + os.remove(solution_file_path) + + logger.info( + f"OETC - Model solved successfully. Status: {solved_model.status}" + ) + if hasattr(solved_model, "objective") and hasattr( + solved_model.objective, "value" + ): + logger.info( + f"OETC - Objective value: {solved_model.objective.value:.2e}" + ) + + return solved_model + + except Exception as e: + raise Exception(f"Error solving model on OETC: {e}") + + def _gzip_compress(self, source_path: str) -> str: + """ + Compress a file using gzip compression. + + Args: + source_path: Path to the source file to compress + + Returns: + str: Path to the compressed file + + Raises: + Exception: If compression fails + """ + try: + logger.info(f"OETC - Compressing file: {source_path}") + output_path = source_path + ".gz" + chunk_size = 1024 * 1024 + + with open(source_path, "rb") as f_in: + with gzip.open(output_path, "wb", compresslevel=9) as f_out: + while True: + chunk = f_in.read(chunk_size) + if not chunk: + break + f_out.write(chunk) + + logger.info(f"OETC - File compressed successfully: {output_path}") + return output_path + except Exception as e: + raise Exception(f"Failed to compress file: {e}") + + def _upload_file_to_gcp(self, file_path: str) -> str: + """ + Upload a file to GCP storage bucket after compression. + + Args: + file_path: Path to the file to upload + + Returns: + str: Name of the uploaded file in the bucket + + Raises: + Exception: If upload fails + """ + try: + compressed_file_path = self._gzip_compress(file_path) + compressed_file_name = os.path.basename(compressed_file_path) + + logger.info(f"OETC - Uploading file to GCP: {compressed_file_name}") + + # Create GCP credentials from service key + service_key_dict = json.loads( + self.cloud_provider_credentials.gcp_service_key + ) + credentials = service_account.Credentials.from_service_account_info( + service_key_dict, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + # Upload to GCP bucket + storage_client = storage.Client( + credentials=credentials, + project=self.cloud_provider_credentials.gcp_project_id, + ) + bucket = storage_client.bucket(self.cloud_provider_credentials.input_bucket) + blob = bucket.blob(compressed_file_name) + + blob.upload_from_filename(compressed_file_path) + + logger.info( + f"OETC - File uploaded to GCP successfully: {compressed_file_name}" + ) + + # Clean up compressed file + os.remove(compressed_file_path) + + return compressed_file_name + + except Exception as e: + raise Exception(f"Failed to upload file to GCP: {e}") diff --git a/linopy/remote.py b/linopy/remote/ssh.py similarity index 100% rename from linopy/remote.py rename to linopy/remote/ssh.py diff --git a/pyproject.toml b/pyproject.toml index c2b48ee5..18fdd8aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,8 @@ dependencies = [ "polars", "tqdm", "deprecation", + "google-cloud-storage", + "requests" ] [project.urls] @@ -63,6 +65,7 @@ dev = [ "netcdf4", "paramiko", "types-paramiko", + "types-requests", "gurobipy", "highspy", ] diff --git a/test/remote/test_oetc.py b/test/remote/test_oetc.py new file mode 100644 index 00000000..2f4372a4 --- /dev/null +++ b/test/remote/test_oetc.py @@ -0,0 +1,1680 @@ +import base64 +import json +from datetime import datetime +from unittest.mock import Mock, patch + +import pytest +import requests +from requests import RequestException + +from linopy.remote.oetc import ( + AuthenticationResult, + ComputeProvider, + GcpCredentials, + JobResult, + OetcCredentials, + OetcHandler, + OetcSettings, +) + + +@pytest.fixture +def sample_jwt_token() -> str: + """Create a sample JWT token with test payload""" + payload = { + "iss": "OETC", + "sub": "user-uuid-123", + "exp": 1640995200, + "jti": "jwt-id-456", + "email": "test@example.com", + "firstname": "Test", + "lastname": "User", + } + + # Create a simple JWT-like token (header.payload.signature) + header = ( + base64.urlsafe_b64encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode()) + .decode() + .rstrip("=") + ) + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) + signature = "fake_signature" + + return f"{header}.{payload_encoded}.{signature}" + + +@pytest.fixture +def mock_gcp_credentials_response() -> dict: + """Create a mock GCP credentials response""" + return { + "gcp_project_id": "test-project-123", + "gcp_service_key": "test-service-key-content", + "input_bucket": "test-input-bucket", + "solution_bucket": "test-solution-bucket", + } + + +@pytest.fixture +def mock_settings() -> OetcSettings: + """Create mock settings for testing""" + credentials = OetcCredentials(email="test@example.com", password="test_password") + return OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + +class TestOetcHandler: + @pytest.fixture + def mock_jwt_response(self) -> dict: + """Create a mock JWT response""" + return { + "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "Bearer", + "expires_in": 3600, + } + + @patch("linopy.remote.oetc.requests.post") + @patch("linopy.remote.oetc.requests.get") + @patch("linopy.remote.oetc.datetime") + def test_successful_authentication( + self, + mock_datetime: Mock, + mock_get: Mock, + mock_post: Mock, + mock_settings: OetcSettings, + mock_jwt_response: dict, + mock_gcp_credentials_response: dict, + sample_jwt_token: str, + ) -> None: + """Test successful authentication flow""" + # Setup mocks + fixed_time = datetime(2024, 1, 15, 12, 0, 0) + mock_datetime.now.return_value = fixed_time + + # Mock authentication response + mock_auth_response = Mock() + mock_jwt_response["token"] = sample_jwt_token + mock_auth_response.json.return_value = mock_jwt_response + mock_auth_response.raise_for_status.return_value = None + mock_post.return_value = mock_auth_response + + # Mock GCP credentials response + mock_gcp_response = Mock() + mock_gcp_response.json.return_value = mock_gcp_credentials_response + mock_gcp_response.raise_for_status.return_value = None + mock_get.return_value = mock_gcp_response + + # Execute + handler = OetcHandler(mock_settings) + + # Verify authentication request + mock_post.assert_called_once_with( + "https://auth.example.com/sign-in", + json={"email": "test@example.com", "password": "test_password"}, + headers={"Content-Type": "application/json"}, + timeout=30, + ) + + # Verify GCP credentials request + mock_get.assert_called_once_with( + "https://auth.example.com/users/user-uuid-123/gcp-credentials", + headers={ + "Authorization": f"Bearer {sample_jwt_token}", + "Content-Type": "application/json", + }, + timeout=30, + ) + + # Verify AuthenticationResult + assert isinstance(handler.jwt, AuthenticationResult) + assert handler.jwt.token == sample_jwt_token + assert handler.jwt.token_type == "Bearer" + assert handler.jwt.expires_in == 3600 + assert handler.jwt.authenticated_at == fixed_time + + # Verify GcpCredentials + assert isinstance(handler.cloud_provider_credentials, GcpCredentials) + assert handler.cloud_provider_credentials.gcp_project_id == "test-project-123" + assert ( + handler.cloud_provider_credentials.gcp_service_key + == "test-service-key-content" + ) + assert handler.cloud_provider_credentials.input_bucket == "test-input-bucket" + assert ( + handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" + ) + + @patch("linopy.remote.oetc.requests.post") + def test_authentication_http_error( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: + """Test authentication failure with HTTP error""" + # Setup mock to raise HTTP error + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError( + "401 Unauthorized" + ) + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + +class TestJwtDecoding: + @pytest.fixture + def handler_with_mocked_auth(self) -> OetcHandler: + """Create handler with mocked authentication for testing JWT decoding""" + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + # Mock the authentication and credentials fetching + mock_auth_result = AuthenticationResult( + token="fake.token.here", + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now(), + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = mock_auth_result + handler.cloud_provider_credentials = None # type: ignore + + return handler + + def test_decode_jwt_payload_success( + self, handler_with_mocked_auth: OetcHandler, sample_jwt_token: str + ) -> None: + """Test successful JWT payload decoding""" + result = handler_with_mocked_auth._decode_jwt_payload(sample_jwt_token) + + assert result["iss"] == "OETC" + assert result["sub"] == "user-uuid-123" + assert result["email"] == "test@example.com" + assert result["firstname"] == "Test" + assert result["lastname"] == "User" + + def test_decode_jwt_payload_invalid_token( + self, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test JWT payload decoding with invalid token""" + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._decode_jwt_payload("invalid.token") + + assert "Failed to decode JWT payload" in str(exc_info.value) + + def test_decode_jwt_payload_malformed_token( + self, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test JWT payload decoding with malformed token""" + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._decode_jwt_payload("not_a_jwt_token") + + assert "Failed to decode JWT payload" in str(exc_info.value) + + +class TestCloudProviderCredentials: + @pytest.fixture + def handler_with_mocked_auth(self, sample_jwt_token: str) -> OetcHandler: + """Create handler with mocked authentication for testing credentials fetching""" + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + # Mock the authentication result + mock_auth_result = AuthenticationResult( + token=sample_jwt_token, + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now(), + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = mock_auth_result + + return handler + + @patch("linopy.remote.oetc.requests.get") + def test_get_gcp_credentials_success( + self, + mock_get: Mock, + handler_with_mocked_auth: OetcHandler, + mock_gcp_credentials_response: dict, + ) -> None: + """Test successful GCP credentials fetching""" + # Setup mock + mock_response = Mock() + mock_response.json.return_value = mock_gcp_credentials_response + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + # Execute + result = handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] + + # Verify request + mock_get.assert_called_once_with( + "https://auth.example.com/users/user-uuid-123/gcp-credentials", + headers={ + "Authorization": f"Bearer {handler_with_mocked_auth.jwt.token}", + "Content-Type": "application/json", + }, + timeout=30, + ) + + # Verify result + assert isinstance(result, GcpCredentials) + assert result.gcp_project_id == "test-project-123" + assert result.gcp_service_key == "test-service-key-content" + assert result.input_bucket == "test-input-bucket" + assert result.solution_bucket == "test-solution-bucket" + + @patch("linopy.remote.oetc.requests.get") + def test_get_gcp_credentials_http_error( + self, mock_get: Mock, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test GCP credentials fetching with HTTP error""" + # Setup mock to raise HTTP error + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("403 Forbidden") + mock_get.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] + + assert "Failed to fetch GCP credentials" in str(exc_info.value) + + @patch("linopy.remote.oetc.requests.get") + def test_get_gcp_credentials_missing_field( + self, mock_get: Mock, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test GCP credentials fetching with missing response field""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "gcp_project_id": "test-project-123", + "gcp_service_key": "test-service-key-content", + "input_bucket": "test-input-bucket", + # Missing "solution_bucket" field + } + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] + + assert ( + "Invalid credentials response format: missing field 'solution_bucket'" + in str(exc_info.value) + ) + + def test_get_cloud_provider_credentials_unsupported_provider( + self, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test cloud provider credentials with unsupported provider""" + # Change to unsupported provider + handler_with_mocked_auth.settings.compute_provider = "AWS" # type: ignore # Not in enum, but for testing + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_cloud_provider_credentials() # type: ignore[attr-defined] + + assert "Unsupported compute provider: AWS" in str(exc_info.value) + + def test_get_gcp_credentials_no_user_uuid_in_token( + self, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test GCP credentials fetching when JWT token has no user UUID""" + # Create token without 'sub' field + payload = {"iss": "OETC", "email": "test@example.com"} + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) + token_without_sub = f"header.{payload_encoded}.signature" + + handler_with_mocked_auth.jwt.token = token_without_sub + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] + + assert "User UUID not found in JWT token" in str(exc_info.value) + + +class TestGcpCredentials: + def test_gcp_credentials_creation(self) -> None: + """Test GcpCredentials dataclass creation""" + credentials = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key="test-service-key-content", + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket", + ) + + assert credentials.gcp_project_id == "test-project-123" + assert credentials.gcp_service_key == "test-service-key-content" + assert credentials.input_bucket == "test-input-bucket" + assert credentials.solution_bucket == "test-solution-bucket" + + +class TestComputeProvider: + def test_compute_provider_enum(self) -> None: + """Test ComputeProvider enum values""" + assert ComputeProvider.GCP == "GCP" + assert ComputeProvider.GCP.value == "GCP" + + @patch("linopy.remote.oetc.requests.post") + def test_authentication_network_error( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: + """Test authentication failure with network error""" + # Setup mock to raise network error + mock_post.side_effect = RequestException("Connection timeout") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + @patch("linopy.remote.oetc.requests.post") + def test_authentication_invalid_response_missing_token( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: + """Test authentication failure with missing token in response""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "token_type": "Bearer", + "expires_in": 3600, + # Missing "token" field + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Invalid response format: missing field 'token'" in str(exc_info.value) + + @patch("linopy.remote.oetc.requests.post") + def test_authentication_invalid_response_missing_expires_in( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: + """Test authentication failure with missing expires_in in response""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "Bearer", + # Missing "expires_in" field + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Invalid response format: missing field 'expires_in'" in str( + exc_info.value + ) + + @patch("linopy.remote.oetc.requests.post") + def test_authentication_timeout_error( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: + """Test authentication failure with timeout""" + # Setup mock to raise timeout error + mock_post.side_effect = requests.Timeout("Request timeout") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + +class TestAuthenticationResult: + @pytest.fixture + def auth_result(self) -> AuthenticationResult: + """Create an AuthenticationResult for testing""" + return AuthenticationResult( + token="test_token", + token_type="Bearer", + expires_in=3600, # 1 hour + authenticated_at=datetime(2024, 1, 15, 12, 0, 0), + ) + + def test_expires_at_calculation(self, auth_result: AuthenticationResult) -> None: + """Test that expires_at correctly calculates expiration time""" + expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later + assert auth_result.expires_at == expected_expiry + + @patch("linopy.remote.oetc.datetime") + def test_is_expired_false_when_not_expired( + self, mock_datetime: Mock, auth_result: AuthenticationResult + ) -> None: + """Test is_expired returns False when token is still valid""" + # Set current time to before expiration + mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 30, 0) + + assert auth_result.is_expired is False + + @patch("linopy.remote.oetc.datetime") + def test_is_expired_true_when_expired( + self, mock_datetime: Mock, auth_result: AuthenticationResult + ) -> None: + """Test is_expired returns True when token has expired""" + # Set current time to after expiration + mock_datetime.now.return_value = datetime(2024, 1, 15, 14, 0, 0) + + assert auth_result.is_expired is True + + @patch("linopy.remote.oetc.datetime") + def test_is_expired_true_when_exactly_expired( + self, mock_datetime: Mock, auth_result: AuthenticationResult + ) -> None: + """Test is_expired returns True when token expires exactly now""" + # Set current time to exact expiration time + mock_datetime.now.return_value = datetime(2024, 1, 15, 13, 0, 0) + + assert auth_result.is_expired is True + + +class TestFileCompression: + @pytest.fixture + def handler_with_mocked_auth(self) -> OetcHandler: + """Create handler with mocked authentication for testing file operations""" + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = Mock() + + return handler + + @patch("linopy.remote.oetc.gzip.open") + @patch("linopy.remote.oetc.os.path.exists") + @patch("builtins.open") + def test_gzip_compress_success( + self, + mock_open: Mock, + mock_exists: Mock, + mock_gzip_open: Mock, + handler_with_mocked_auth: OetcHandler, + ) -> None: + """Test successful file compression""" + # Setup + source_path = "/tmp/test_file.nc" + expected_output = "/tmp/test_file.nc.gz" + + # Mock file operations + mock_exists.return_value = True + mock_file_in = Mock() + mock_file_out = Mock() + mock_open.return_value.__enter__.return_value = mock_file_in + mock_gzip_open.return_value.__enter__.return_value = mock_file_out + + # Mock file reading + mock_file_in.read.side_effect = [ + b"test_data_chunk", + b"", + ] # First read returns data, second returns empty + + # Execute + result = handler_with_mocked_auth._gzip_compress(source_path) + + # Verify + assert result == expected_output + mock_open.assert_called_once_with(source_path, "rb") + mock_gzip_open.assert_called_once_with(expected_output, "wb", compresslevel=9) + mock_file_out.write.assert_called_once_with(b"test_data_chunk") + + @patch("builtins.open") + def test_gzip_compress_file_read_error( + self, mock_open: Mock, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test file compression with read error""" + # Setup + source_path = "/tmp/test_file.nc" + mock_open.side_effect = OSError("File not found") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._gzip_compress(source_path) + + assert "Failed to compress file" in str(exc_info.value) + + +class TestGcpUpload: + @pytest.fixture + def handler_with_gcp_credentials( + self, mock_gcp_credentials_response: dict + ) -> OetcHandler: + """Create handler with GCP credentials for testing upload""" + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + # Create proper GCP credentials + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket", + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds + + return handler + + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.os.path.basename") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") + def test_upload_file_to_gcp_success( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + mock_basename: Mock, + mock_remove: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: + """Test successful file upload to GCP""" + # Setup + file_path = "/tmp/test_file.nc" + compressed_path = "/tmp/test_file.nc.gz" + compressed_name = "test_file.nc.gz" + + # Mock compression + with patch.object( + handler_with_gcp_credentials, "_gzip_compress", return_value=compressed_path + ): + # Mock path operations + mock_basename.return_value = compressed_name + + # Mock GCP components + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_bucket.blob.return_value = mock_blob + + # Execute + result = handler_with_gcp_credentials._upload_file_to_gcp(file_path) + + # Verify + assert result == compressed_name + + # Verify GCP credentials creation + mock_creds_from_info.assert_called_once_with( + {"type": "service_account", "project_id": "test-project-123"}, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + # Verify GCP client creation + mock_storage_client.assert_called_once_with( + credentials=mock_credentials, project="test-project-123" + ) + + # Verify bucket access + mock_client.bucket.assert_called_once_with("test-input-bucket") + + # Verify blob operations + mock_bucket.blob.assert_called_once_with(compressed_name) + mock_blob.upload_from_filename.assert_called_once_with(compressed_path) + + # Verify cleanup + mock_remove.assert_called_once_with(compressed_path) + + @patch("linopy.remote.oetc.json.loads") + def test_upload_file_to_gcp_invalid_service_key( + self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler + ) -> None: + """Test upload failure with invalid service key""" + # Setup + file_path = "/tmp/test_file.nc" + mock_json_loads.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._upload_file_to_gcp(file_path) + + assert "Failed to upload file to GCP" in str(exc_info.value) + + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") + def test_upload_file_to_gcp_upload_error( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: + """Test upload failure during blob upload""" + # Setup + file_path = "/tmp/test_file.nc" + compressed_path = "/tmp/test_file.nc.gz" + + # Mock compression + with patch.object( + handler_with_gcp_credentials, "_gzip_compress", return_value=compressed_path + ): + # Mock GCP setup + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_blob.upload_from_filename.side_effect = Exception("Upload failed") + mock_bucket.blob.return_value = mock_blob + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._upload_file_to_gcp(file_path) + + assert "Failed to upload file to GCP" in str(exc_info.value) + + +class TestFileDecompression: + @pytest.fixture + def handler_with_mocked_auth(self) -> OetcHandler: + """Create handler with mocked authentication for testing file operations""" + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = Mock() + + return handler + + @patch("linopy.remote.oetc.gzip.open") + @patch("builtins.open") + def test_gzip_decompress_success( + self, + mock_open_file: Mock, + mock_gzip_open: Mock, + handler_with_mocked_auth: OetcHandler, + ) -> None: + """Test successful file decompression""" + # Setup + input_path = "/tmp/test_file.nc.gz" + expected_output = "/tmp/test_file.nc" + + # Mock file operations + mock_file_in = Mock() + mock_file_out = Mock() + mock_gzip_open.return_value.__enter__.return_value = mock_file_in + mock_open_file.return_value.__enter__.return_value = mock_file_out + + # Mock file reading - simulate reading compressed data in chunks + mock_file_in.read.side_effect = [ + b"decompressed_data_chunk", + b"", + ] # First read returns data, second returns empty + + # Execute + result = handler_with_mocked_auth._gzip_decompress(input_path) + + # Verify + assert result == expected_output + mock_gzip_open.assert_called_once_with(input_path, "rb") + mock_open_file.assert_called_once_with(expected_output, "wb") + mock_file_out.write.assert_called_once_with(b"decompressed_data_chunk") + + @patch("linopy.remote.oetc.gzip.open") + def test_gzip_decompress_gzip_open_error( + self, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test file decompression with gzip open error""" + # Setup + input_path = "/tmp/test_file.nc.gz" + mock_gzip_open.side_effect = OSError("Failed to open gzip file") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._gzip_decompress(input_path) + + assert "Failed to decompress file" in str(exc_info.value) + + @patch("linopy.remote.oetc.gzip.open") + @patch("builtins.open") + def test_gzip_decompress_write_error( + self, + mock_open_file: Mock, + mock_gzip_open: Mock, + handler_with_mocked_auth: OetcHandler, + ) -> None: + """Test file decompression with write error""" + # Setup + input_path = "/tmp/test_file.nc.gz" + + # Mock file operations + mock_file_in = Mock() + mock_gzip_open.return_value.__enter__.return_value = mock_file_in + mock_open_file.side_effect = OSError("Permission denied") + + # Mock file reading + mock_file_in.read.return_value = b"test_data" + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._gzip_decompress(input_path) + + assert "Failed to decompress file" in str(exc_info.value) + + def test_gzip_decompress_output_path_generation( + self, handler_with_mocked_auth: OetcHandler + ) -> None: + """Test correct output path generation for decompression""" + # Test first path + with patch("linopy.remote.oetc.gzip.open") as mock_gzip_open: + with patch("builtins.open") as mock_open_file: + mock_file_in = Mock() + mock_file_out = Mock() + mock_gzip_open.return_value.__enter__.return_value = mock_file_in + mock_open_file.return_value.__enter__.return_value = mock_file_out + mock_file_in.read.side_effect = [b"test", b""] + + result = handler_with_mocked_auth._gzip_decompress("/tmp/file.nc.gz") + assert result == "/tmp/file.nc" + + # Test second path with fresh mocks + with patch("linopy.remote.oetc.gzip.open") as mock_gzip_open: + with patch("builtins.open") as mock_open_file: + mock_file_in = Mock() + mock_file_out = Mock() + mock_gzip_open.return_value.__enter__.return_value = mock_file_in + mock_open_file.return_value.__enter__.return_value = mock_file_out + mock_file_in.read.side_effect = [b"test", b""] + + result = handler_with_mocked_auth._gzip_decompress( + "/path/to/model.data.gz" + ) + assert result == "/path/to/model.data" + + +class TestGcpDownload: + @pytest.fixture + def handler_with_gcp_credentials( + self, mock_gcp_credentials_response: dict + ) -> OetcHandler: + """Create handler with GCP credentials for testing download""" + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + # Create proper GCP credentials + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket", + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds + + return handler + + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_success( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + mock_tempfile: Mock, + mock_remove: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: + """Test successful file download from GCP""" + # Setup + file_name = "solution_file.nc.gz" + compressed_path = "/tmp/tmpfile.gz" + decompressed_path = "/tmp/tmpfile" + + # Mock temporary file creation + mock_temp_file = Mock() + mock_temp_file.name = compressed_path + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock decompression + with patch.object( + handler_with_gcp_credentials, + "_gzip_decompress", + return_value=decompressed_path, + ): + # Mock GCP components + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_bucket.blob.return_value = mock_blob + + # Execute + result = handler_with_gcp_credentials._download_file_from_gcp(file_name) + + # Verify + assert result == decompressed_path + + # Verify GCP credentials creation + mock_creds_from_info.assert_called_once_with( + {"type": "service_account", "project_id": "test-project-123"}, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + # Verify GCP client creation + mock_storage_client.assert_called_once_with( + credentials=mock_credentials, project="test-project-123" + ) + + # Verify bucket access (solution bucket, not input bucket) + mock_client.bucket.assert_called_once_with("test-solution-bucket") + + # Verify blob operations + mock_bucket.blob.assert_called_once_with(file_name) + mock_blob.download_to_filename.assert_called_once_with(compressed_path) + + # Verify cleanup + mock_remove.assert_called_once_with(compressed_path) + + @patch("linopy.remote.oetc.json.loads") + def test_download_file_from_gcp_invalid_service_key( + self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler + ) -> None: + """Test download failure with invalid service key""" + # Setup + file_name = "solution_file.nc.gz" + mock_json_loads.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._download_file_from_gcp(file_name) + + assert "Failed to download file from GCP" in str(exc_info.value) + + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_download_error( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + mock_tempfile: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: + """Test download failure during blob download""" + # Setup + file_name = "solution_file.nc.gz" + compressed_path = "/tmp/tmpfile.gz" + + # Mock temporary file creation + mock_temp_file = Mock() + mock_temp_file.name = compressed_path + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock GCP setup + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_blob.download_to_filename.side_effect = Exception("Download failed") + mock_bucket.blob.return_value = mock_blob + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._download_file_from_gcp(file_name) + + assert "Failed to download file from GCP" in str(exc_info.value) + + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_decompression_error( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + mock_tempfile: Mock, + mock_remove: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: + """Test download failure during decompression""" + # Setup + file_name = "solution_file.nc.gz" + compressed_path = "/tmp/tmpfile.gz" + + # Mock temporary file creation + mock_temp_file = Mock() + mock_temp_file.name = compressed_path + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock successful GCP download but failed decompression + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_bucket.blob.return_value = mock_blob + + # Mock decompression failure + with patch.object( + handler_with_gcp_credentials, + "_gzip_decompress", + side_effect=Exception("Decompression failed"), + ): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._download_file_from_gcp(file_name) + + assert "Failed to download file from GCP" in str(exc_info.value) + + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_credentials_error( + self, mock_creds_from_info: Mock, handler_with_gcp_credentials: OetcHandler + ) -> None: + """Test download failure during credentials creation""" + # Setup + file_name = "solution_file.nc.gz" + mock_creds_from_info.side_effect = Exception("Invalid credentials") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._download_file_from_gcp(file_name) + + assert "Failed to download file from GCP" in str(exc_info.value) + + +class TestJobSubmission: + @pytest.fixture + def handler_with_auth_setup(self, sample_jwt_token: str) -> OetcHandler: + """Create handler with authentication setup for testing job submission""" + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Test Optimization Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + solver="gurobi", + cpu_cores=4, + disk_space_gb=20, + ) + + # Mock the authentication result + mock_auth_result = AuthenticationResult( + token=sample_jwt_token, + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now(), + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = mock_auth_result + handler.cloud_provider_credentials = Mock() + + return handler + + @patch("linopy.remote.oetc.requests.post") + def test_submit_job_success( + self, mock_post: Mock, handler_with_auth_setup: OetcHandler + ) -> None: + """Test successful job submission to compute service""" + # Setup + input_file_name = "test_model.nc.gz" + expected_job_uuid = "job-uuid-123" + + # Mock successful response + mock_response = Mock() + mock_response.json.return_value = {"uuid": expected_job_uuid} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute + result = handler_with_auth_setup._submit_job_to_compute_service(input_file_name) + + # Verify request + expected_payload = { + "name": "Test Optimization Job", + "solver": "gurobi", + "solver_options": {}, + "provider": "GCP", + "cpu_cores": 4, + "disk_space_gb": 20, + "input_file_name": input_file_name, + "delete_worker_on_error": False, + } + + mock_post.assert_called_once_with( + "https://orchestrator.example.com/compute-job/create", + json=expected_payload, + headers={ + "Authorization": f"Bearer {handler_with_auth_setup.jwt.token}", + "Content-Type": "application/json", + }, + timeout=30, + ) + + # Verify result + assert result == expected_job_uuid + + @patch("linopy.remote.oetc.requests.post") + def test_submit_job_http_error( + self, mock_post: Mock, handler_with_auth_setup: OetcHandler + ) -> None: + """Test job submission with HTTP error""" + # Setup + input_file_name = "test_model.nc.gz" + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError( + "400 Bad Request" + ) + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_auth_setup._submit_job_to_compute_service(input_file_name) + + assert "Failed to submit job to compute service" in str(exc_info.value) + + @patch("linopy.remote.oetc.requests.post") + def test_submit_job_missing_uuid_in_response( + self, mock_post: Mock, handler_with_auth_setup: OetcHandler + ) -> None: + """Test job submission with missing UUID in response""" + # Setup + input_file_name = "test_model.nc.gz" + mock_response = Mock() + mock_response.json.return_value = {} # Missing "uuid" field + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_auth_setup._submit_job_to_compute_service(input_file_name) + + assert "Invalid job submission response format: missing field 'uuid'" in str( + exc_info.value + ) + + @patch("linopy.remote.oetc.requests.post") + def test_submit_job_network_error( + self, mock_post: Mock, handler_with_auth_setup: OetcHandler + ) -> None: + """Test job submission with network error""" + # Setup + input_file_name = "test_model.nc.gz" + mock_post.side_effect = RequestException("Connection timeout") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_auth_setup._submit_job_to_compute_service(input_file_name) + + assert "Failed to submit job to compute service" in str(exc_info.value) + + +class TestSolveOnOetc: + @pytest.fixture + def handler_with_complete_setup( + self, mock_gcp_credentials_response: dict + ) -> OetcHandler: + """Create handler with complete setup for testing solve functionality""" + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket", + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds + + return handler + + @patch("linopy.remote.oetc.linopy.read_netcdf") + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_file_upload( + self, + mock_tempfile: Mock, + mock_remove: Mock, + mock_read_netcdf: Mock, + handler_with_complete_setup: OetcHandler, + ) -> None: + """Test solve_on_oetc method complete workflow""" + # Setup + mock_model = Mock() + mock_solved_model = Mock() + mock_solved_model.status = "optimal" + mock_solved_model.objective.value = 42.0 + + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + mock_read_netcdf.return_value = mock_solved_model + + # Mock file upload, job submission, job waiting, and download + with patch.object( + handler_with_complete_setup, + "_upload_file_to_gcp", + return_value="uploaded_file.nc.gz", + ) as mock_upload: + with patch.object( + handler_with_complete_setup, + "_submit_job_to_compute_service", + return_value="test-job-uuid", + ) as mock_submit: + with patch.object( + handler_with_complete_setup, + "wait_and_get_job_data", + return_value=JobResult( + uuid="test-job-uuid", + status="FINISHED", + output_files=[{"name": "result.nc.gz"}], + ), + ) as mock_wait: + with patch.object( + handler_with_complete_setup, + "_download_file_from_gcp", + return_value="/tmp/downloaded_result.nc", + ) as mock_download: + # Execute + result = handler_with_complete_setup.solve_on_oetc(mock_model) + + # Verify + assert ( + result == mock_solved_model + ) # Now returns the solved model + mock_model.to_netcdf.assert_called_once_with( + "/tmp/linopy-abc123.nc" + ) + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_submit.assert_called_once_with("uploaded_file.nc.gz") + mock_wait.assert_called_once_with("test-job-uuid") + mock_download.assert_called_once_with("result.nc.gz") + mock_read_netcdf.assert_called_once_with( + "/tmp/downloaded_result.nc" + ) + mock_remove.assert_called_once_with("/tmp/downloaded_result.nc") + + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_upload_failure( + self, mock_tempfile: Mock, handler_with_complete_setup: OetcHandler + ) -> None: + """Test solve_on_oetc method with upload failure""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock upload failure + with patch.object( + handler_with_complete_setup, + "_upload_file_to_gcp", + side_effect=Exception("Upload failed"), + ): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_complete_setup.solve_on_oetc(mock_model) + + assert "Upload failed" in str(exc_info.value) + + +class TestSolveOnOetcWithJobSubmission: + @pytest.fixture + def handler_with_full_setup(self) -> OetcHandler: + """Create handler with full setup for testing complete solve flow""" + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) + settings = OetcSettings( + credentials=credentials, + name="Linopy Solve Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + solver="highs", + cpu_cores=2, + disk_space_gb=15, + ) + + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket", + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds + + return handler + + @patch("linopy.remote.oetc.linopy.read_netcdf") + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_with_job_submission( + self, + mock_tempfile: Mock, + mock_remove: Mock, + mock_read_netcdf: Mock, + handler_with_full_setup: OetcHandler, + ) -> None: + """Test solve_on_oetc method including job submission, waiting, and download""" + # Setup + mock_model = Mock() + mock_solved_model = Mock() + mock_solved_model.status = "optimal" + mock_solved_model.objective.value = 100.5 + + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + mock_read_netcdf.return_value = mock_solved_model + + uploaded_file_name = "model_file.nc.gz" + job_uuid = "job-uuid-456" + + # Mock complete workflow + with patch.object( + handler_with_full_setup, + "_upload_file_to_gcp", + return_value=uploaded_file_name, + ) as mock_upload: + with patch.object( + handler_with_full_setup, + "_submit_job_to_compute_service", + return_value=job_uuid, + ) as mock_submit: + with patch.object( + handler_with_full_setup, + "wait_and_get_job_data", + return_value=JobResult( + uuid=job_uuid, + status="FINISHED", + output_files=[{"name": "result.nc.gz"}], + ), + ) as mock_wait: + with patch.object( + handler_with_full_setup, + "_download_file_from_gcp", + return_value="/tmp/solution_file.nc", + ) as mock_download: + # Execute + result = handler_with_full_setup.solve_on_oetc(mock_model) + + # Verify + assert ( + result == mock_solved_model + ) # Now returns the solved model + mock_model.to_netcdf.assert_called_once_with( + "/tmp/linopy-abc123.nc" + ) + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_submit.assert_called_once_with(uploaded_file_name) + mock_wait.assert_called_once_with(job_uuid) + mock_download.assert_called_once_with("result.nc.gz") + mock_read_netcdf.assert_called_once_with( + "/tmp/solution_file.nc" + ) + mock_remove.assert_called_once_with("/tmp/solution_file.nc") + + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_job_submission_failure( + self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler + ) -> None: + """Test solve_on_oetc method with job submission failure""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + uploaded_file_name = "model_file.nc.gz" + + # Mock successful upload but failed job submission - no need to mock waiting since submission fails + with patch.object( + handler_with_full_setup, + "_upload_file_to_gcp", + return_value=uploaded_file_name, + ): + with patch.object( + handler_with_full_setup, + "_submit_job_to_compute_service", + side_effect=Exception("Job submission failed"), + ): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_full_setup.solve_on_oetc(mock_model) + + assert "Job submission failed" in str(exc_info.value) + + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_job_waiting_failure( + self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler + ) -> None: + """Test solve_on_oetc method with job waiting failure""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + uploaded_file_name = "model_file.nc.gz" + job_uuid = "job-uuid-failed" + + # Mock successful upload and job submission but failed job waiting + with patch.object( + handler_with_full_setup, + "_upload_file_to_gcp", + return_value=uploaded_file_name, + ): + with patch.object( + handler_with_full_setup, + "_submit_job_to_compute_service", + return_value=job_uuid, + ): + with patch.object( + handler_with_full_setup, + "wait_and_get_job_data", + side_effect=Exception("Job failed: solver error"), + ): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_full_setup.solve_on_oetc(mock_model) + + assert "Job failed: solver error" in str(exc_info.value) + + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_no_output_files_error( + self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler + ) -> None: + """Test solve_on_oetc method when job completes but has no output files""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + uploaded_file_name = "model_file.nc.gz" + job_uuid = "job-uuid-456" + + # Mock successful workflow until job completion with no output files + with patch.object( + handler_with_full_setup, + "_upload_file_to_gcp", + return_value=uploaded_file_name, + ): + with patch.object( + handler_with_full_setup, + "_submit_job_to_compute_service", + return_value=job_uuid, + ): + with patch.object( + handler_with_full_setup, + "wait_and_get_job_data", + return_value=JobResult( + uuid=job_uuid, status="FINISHED", output_files=[] + ), + ): # No output files + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_full_setup.solve_on_oetc(mock_model) + + assert "No output files found in completed job" in str( + exc_info.value + ) + + +# Additional integration-style test +class TestOetcHandlerIntegration: + @patch("linopy.remote.oetc.requests.post") + @patch("linopy.remote.oetc.requests.get") + @patch("linopy.remote.oetc.datetime") + def test_complete_authentication_flow( + self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock + ) -> None: + """Test complete authentication and credentials flow with realistic data""" + # Setup + fixed_time = datetime(2024, 1, 15, 12, 0, 0) + mock_datetime.now.return_value = fixed_time + + credentials = OetcCredentials( + email="user@company.com", password="secure_password_123" + ) + settings = OetcSettings( + credentials=credentials, + name="Integration Test Job", + authentication_server_url="https://api.company.com/auth", + orchestrator_server_url="https://api.company.com/orchestrator", + compute_provider=ComputeProvider.GCP, + ) + + # Create realistic JWT token + payload = { + "iss": "OETC", + "sub": "user-uuid-456", + "exp": 1640995200, + "jti": "jwt-id-789", + "email": "user@company.com", + "firstname": "John", + "lastname": "Doe", + } + header = ( + base64.urlsafe_b64encode( + json.dumps({"alg": "HS256", "typ": "JWT"}).encode() + ) + .decode() + .rstrip("=") + ) + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) + realistic_token = f"{header}.{payload_encoded}.realistic_signature" + + # Mock authentication response + mock_auth_response = Mock() + mock_auth_response.json.return_value = { + "token": realistic_token, + "token_type": "Bearer", + "expires_in": 7200, # 2 hours + } + mock_auth_response.raise_for_status.return_value = None + mock_post.return_value = mock_auth_response + + # Mock GCP credentials response + mock_gcp_response = Mock() + mock_gcp_response.json.return_value = { + "gcp_project_id": "production-project-456", + "gcp_service_key": "production-service-key-content", + "input_bucket": "prod-input-bucket", + "solution_bucket": "prod-solution-bucket", + } + mock_gcp_response.raise_for_status.return_value = None + mock_get.return_value = mock_gcp_response + + # Execute + handler = OetcHandler(settings) + + # Verify authentication + assert handler.jwt.token == realistic_token + assert handler.jwt.token_type == "Bearer" + assert handler.jwt.expires_in == 7200 + assert handler.jwt.authenticated_at == fixed_time + assert handler.jwt.expires_at == datetime( + 2024, 1, 15, 14, 0, 0 + ) # 2 hours later + assert handler.jwt.is_expired is False + + # Verify GCP credentials + assert isinstance(handler.cloud_provider_credentials, GcpCredentials) + assert ( + handler.cloud_provider_credentials.gcp_project_id + == "production-project-456" + ) + assert ( + handler.cloud_provider_credentials.gcp_service_key + == "production-service-key-content" + ) + assert handler.cloud_provider_credentials.input_bucket == "prod-input-bucket" + assert ( + handler.cloud_provider_credentials.solution_bucket == "prod-solution-bucket" + ) + + # Verify correct API calls were made + mock_post.assert_called_once_with( + "https://api.company.com/auth/sign-in", + json={"email": "user@company.com", "password": "secure_password_123"}, + headers={"Content-Type": "application/json"}, + timeout=30, + ) + + mock_get.assert_called_once_with( + "https://api.company.com/auth/users/user-uuid-456/gcp-credentials", + headers={ + "Authorization": f"Bearer {realistic_token}", + "Content-Type": "application/json", + }, + timeout=30, + ) diff --git a/test/remote/test_oetc_job_polling.py b/test/remote/test_oetc_job_polling.py new file mode 100644 index 00000000..96ec98b4 --- /dev/null +++ b/test/remote/test_oetc_job_polling.py @@ -0,0 +1,327 @@ +""" +Tests for OetcHandler job polling and status monitoring. + +This module tests the wait_and_get_job_data method which polls for job completion +and handles various job states and error conditions. +""" + +from datetime import datetime +from unittest.mock import Mock, patch + +import pytest +from requests import RequestException + +from linopy.remote.oetc import ( + AuthenticationResult, + ComputeProvider, + OetcCredentials, + OetcHandler, + OetcSettings, +) + + +@pytest.fixture +def mock_settings() -> OetcSettings: + """Create mock settings for testing.""" + credentials = OetcCredentials(email="test@example.com", password="test_password") + return OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + +@pytest.fixture +def mock_auth_result() -> AuthenticationResult: + """Create mock authentication result.""" + return AuthenticationResult( + token="mock_token", + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now(), + ) + + +@pytest.fixture +def oetc_handler( + mock_settings: OetcSettings, mock_auth_result: AuthenticationResult +) -> OetcHandler: + """Create OetcHandler with mocked authentication.""" + with patch( + "linopy.remote.oetc.OetcHandler._OetcHandler__sign_in", + return_value=mock_auth_result, + ): + with patch( + "linopy.remote.oetc.OetcHandler._OetcHandler__get_cloud_provider_credentials" + ): + handler = OetcHandler(mock_settings) + return handler + + +class TestJobPollingSuccess: + """Test successful job polling scenarios.""" + + def test_job_completes_immediately(self, oetc_handler: OetcHandler) -> None: + """Test job that completes on first poll.""" + job_data = { + "uuid": "job-123", + "status": "FINISHED", + "name": "test-job", + "owner": "test-user", + "solver": "highs", + "duration_in_seconds": 120, + "solving_duration_in_seconds": 90, + "input_files": ["input.nc"], + "output_files": ["output.nc"], + "created_at": "2024-01-01T00:00:00Z", + } + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + result = oetc_handler.wait_and_get_job_data("job-123") + + assert result.uuid == "job-123" + assert result.status == "FINISHED" + assert result.output_files == ["output.nc"] + mock_get.assert_called_once() + + def test_job_completes_with_no_output_files_warning( + self, oetc_handler: OetcHandler + ) -> None: + """Test job completion with no output files generates warning.""" + job_data = {"uuid": "job-123", "status": "FINISHED", "output_files": []} + + with patch("requests.get") as mock_get: + with patch("linopy.remote.oetc.logger.warning") as mock_warning: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + result = oetc_handler.wait_and_get_job_data("job-123") + + assert result.status == "FINISHED" + mock_warning.assert_called_once_with( + "OETC - Warning: Job completed but no output files found" + ) + + @patch("time.sleep") # Mock sleep to speed up test + def test_job_polling_progression( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: + """Test job progresses through multiple states before completion.""" + responses = [ + {"uuid": "job-123", "status": "PENDING"}, + {"uuid": "job-123", "status": "STARTING"}, + {"uuid": "job-123", "status": "RUNNING", "duration_in_seconds": 30}, + {"uuid": "job-123", "status": "RUNNING", "duration_in_seconds": 60}, + {"uuid": "job-123", "status": "FINISHED", "output_files": ["output.nc"]}, + ] + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.side_effect = responses + mock_get.return_value = mock_response + + result = oetc_handler.wait_and_get_job_data( + "job-123", initial_poll_interval=1 + ) + + assert result.status == "FINISHED" + assert mock_get.call_count == 5 + assert mock_sleep.call_count == 4 # Sleep called 4 times between 5 polls + + @patch("time.sleep") + def test_polling_interval_backoff( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: + """Test polling interval increases with exponential backoff.""" + responses = [ + {"uuid": "job-123", "status": "PENDING"}, + {"uuid": "job-123", "status": "RUNNING"}, + {"uuid": "job-123", "status": "FINISHED", "output_files": ["output.nc"]}, + ] + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.side_effect = responses + mock_get.return_value = mock_response + + oetc_handler.wait_and_get_job_data( + "job-123", initial_poll_interval=10, max_poll_interval=100 + ) + + # Verify sleep was called with increasing intervals + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] + assert sleep_calls[0] == 10 # Initial interval + assert sleep_calls[1] == 15 # 10 * 1.5 = 15 + + +class TestJobPollingErrors: + """Test job polling error scenarios.""" + + def test_setup_error_status(self, oetc_handler: OetcHandler) -> None: + """Test job with SETUP_ERROR status raises exception.""" + job_data = {"uuid": "job-123", "status": "SETUP_ERROR"} + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Job failed during setup phase"): + oetc_handler.wait_and_get_job_data("job-123") + + def test_runtime_error_status(self, oetc_handler: OetcHandler) -> None: + """Test job with RUNTIME_ERROR status raises exception.""" + job_data = {"uuid": "job-123", "status": "RUNTIME_ERROR"} + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Job failed during execution"): + oetc_handler.wait_and_get_job_data("job-123") + + def test_unknown_status_error(self, oetc_handler: OetcHandler) -> None: + """Test job with unknown status raises exception.""" + job_data = {"uuid": "job-123", "status": "UNKNOWN_STATUS"} + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Unknown job status: UNKNOWN_STATUS"): + oetc_handler.wait_and_get_job_data("job-123") + + +class TestJobPollingNetworkErrors: + """Test network error handling during job polling.""" + + @patch("time.sleep") + def test_network_retry_success( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: + """Test network errors are retried and eventually succeed.""" + successful_response = { + "uuid": "job-123", + "status": "FINISHED", + "output_files": ["output.nc"], + } + + with patch("requests.get") as mock_get: + # First two calls fail, third succeeds + mock_get.side_effect = [ + RequestException("Network error 1"), + RequestException("Network error 2"), + Mock( + raise_for_status=Mock(), json=Mock(return_value=successful_response) + ), + ] + + result = oetc_handler.wait_and_get_job_data("job-123") + + assert result.status == "FINISHED" + assert mock_get.call_count == 3 + assert mock_sleep.call_count == 2 # Retry delays + + # Verify retry delays increase + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] + assert sleep_calls[0] == 10 # First retry: 1 * 10 = 10 + assert sleep_calls[1] == 20 # Second retry: 2 * 10 = 20 + + @patch("time.sleep") + def test_max_network_retries_exceeded( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: + """Test max network retries causes exception.""" + with patch("requests.get") as mock_get: + # All calls fail with RequestException + mock_get.side_effect = RequestException("Network error") + + with pytest.raises( + Exception, match="Failed to get job status after 10 network retries" + ): + oetc_handler.wait_and_get_job_data("job-123") + + # Should retry exactly 10 times before failing + assert mock_get.call_count == 10 + + @patch("time.sleep") + def test_network_retry_delay_cap( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: + """Test network retry delay is capped at 60 seconds.""" + with patch("requests.get") as mock_get: + mock_get.side_effect = RequestException("Network error") + + with pytest.raises(Exception): + oetc_handler.wait_and_get_job_data("job-123") + + # Check that delay is capped at 60 seconds + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] + assert all(delay <= 60 for delay in sleep_calls) + + def test_keyerror_in_response(self, oetc_handler: OetcHandler) -> None: + """Test KeyError in response parsing raises exception.""" + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = {} # Missing required 'uuid' field + mock_get.return_value = mock_response + + with pytest.raises( + Exception, match="Invalid job status response format: missing field" + ): + oetc_handler.wait_and_get_job_data("job-123") + + def test_generic_exception_handling(self, oetc_handler: OetcHandler) -> None: + """Test generic exception handling in polling loop.""" + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.side_effect = ValueError("Unexpected error") + mock_get.return_value = mock_response + + with pytest.raises( + Exception, match="Error getting job status: Unexpected error" + ): + oetc_handler.wait_and_get_job_data("job-123") + + def test_status_error_exception_preserved(self, oetc_handler: OetcHandler) -> None: + """Test that status-related exceptions are preserved.""" + with patch("requests.get") as mock_get: + # Simulate an exception that mentions "status:" - should be re-raised as-is + mock_response = Mock() + mock_response.raise_for_status.side_effect = Exception( + "Custom status: error" + ) + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Custom status: error"): + oetc_handler.wait_and_get_job_data("job-123") + + def test_oetc_logs_exception_preserved(self, oetc_handler: OetcHandler) -> None: + """Test that OETC logs exceptions are preserved.""" + with patch("requests.get") as mock_get: + # Simulate an exception that mentions "OETC logs" - should be re-raised as-is + mock_response = Mock() + mock_response.raise_for_status.side_effect = Exception( + "Check the OETC logs for details" + ) + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Check the OETC logs for details"): + oetc_handler.wait_and_get_job_data("job-123")