From f1f8d4ec3c618318c5191a9cdfa492cbba15f56c Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Mon, 9 Jun 2025 11:34:29 +0200 Subject: [PATCH 01/28] Implement OETC authentication step --- .gitignore | 1 + linopy/model.py | 1 + linopy/oetc.py | 79 ++++++++++++++++ test/test_oetc.py | 225 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 306 insertions(+) create mode 100644 linopy/oetc.py create mode 100644 test/test_oetc.py diff --git a/.gitignore b/.gitignore index 2088cae0..4caa465c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ doc/_build doc/generated doc/api .vscode +.idea Highs.log paper/ monkeytype.sqlite3 diff --git a/linopy/model.py b/linopy/model.py index c8024843..5767d653 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1049,6 +1049,7 @@ def solve( slice_size: int = 2_000_000, remote: Any = None, progress: bool | None = None, + oetc_settings: dict | None = {}, **solver_options: Any, ) -> tuple[str, str]: """ diff --git a/linopy/oetc.py b/linopy/oetc.py new file mode 100644 index 00000000..6c493014 --- /dev/null +++ b/linopy/oetc.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass +from datetime import datetime, timedelta + +import requests +from requests import RequestException + + +@dataclass +class OetcCredentials: + email: str + password: str + +@dataclass +class OetcSettings: + credentials: OetcCredentials + authentication_server_url: str + + +@dataclass +class AuthenticationResult: + token: str + token_type: str + expires_in: int # value represented in seconds + authenticated_at: datetime + + @property + def expires_at(self) -> datetime: + """Calculate when the token expires""" + return self.authenticated_at + timedelta(seconds=self.expires_in) + + @property + def is_expired(self) -> bool: + """Check if the token has expired""" + return datetime.now() >= self.expires_at + +class OetcHandler: + def __init__(self, settings: OetcSettings) -> None: + self.settings = settings + self.jwt = self.__sign_in() + + def __sign_in(self) -> AuthenticationResult: + """ + Authenticate with the server and return the authentication result. + + Returns: + AuthenticationResult: The complete authentication result including token and expiration info + + Raises: + Exception: If authentication fails or response is invalid + """ + try: + payload = { + "email": self.settings.credentials.email, + "password": self.settings.credentials.password + } + + response = requests.post( + f"{self.settings.authentication_server_url}/sign-in", + json=payload, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + response.raise_for_status() + jwt_result = response.json() + + return AuthenticationResult( + token=jwt_result["token"], + token_type=jwt_result["token_type"], + expires_in=jwt_result["expires_in"], + authenticated_at=datetime.now() + ) + + except RequestException as e: + raise Exception(f"Authentication request failed: {e}") + except KeyError as e: + raise Exception(f"Invalid response format: missing field {e}") + except Exception as e: + raise Exception(f"Authentication error: {e}") \ No newline at end of file diff --git a/test/test_oetc.py b/test/test_oetc.py new file mode 100644 index 00000000..640e68bc --- /dev/null +++ b/test/test_oetc.py @@ -0,0 +1,225 @@ +import pytest +from datetime import datetime +from unittest.mock import patch, Mock + +import requests +from requests import RequestException + +from linopy.oetc import OetcCredentials, OetcSettings, OetcHandler, AuthenticationResult + + +class TestOetcHandler: + + @pytest.fixture + def mock_settings(self): + """Create mock settings for testing""" + credentials = OetcCredentials( + email="test@example.com", + password="test_password" + ) + return OetcSettings( + credentials=credentials, + authentication_server_url="https://auth.example.com" + ) + + @pytest.fixture + def mock_jwt_response(self): + """Create a mock JWT response""" + return { + "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "Bearer", + "expires_in": 3600 + } + + @patch('linopy.oetc.requests.post') + @patch('linopy.oetc.datetime') + def test_successful_authentication(self, mock_datetime, mock_post, mock_settings, mock_jwt_response): + """Test successful authentication flow""" + # Setup mocks + fixed_time = datetime(2024, 1, 15, 12, 0, 0) + mock_datetime.now.return_value = fixed_time + + mock_response = Mock() + mock_response.json.return_value = mock_jwt_response + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute + handler = OetcHandler(mock_settings) + + # Verify requests.post was called correctly + mock_post.assert_called_once_with( + "https://auth.example.com/sign-in", + json={ + "email": "test@example.com", + "password": "test_password" + }, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + # Verify AuthenticationResult + assert isinstance(handler.jwt, AuthenticationResult) + assert handler.jwt.token == "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." + assert handler.jwt.token_type == "Bearer" + assert handler.jwt.expires_in == 3600 + assert handler.jwt.authenticated_at == fixed_time + + @patch('linopy.oetc.requests.post') + def test_authentication_http_error(self, mock_post, mock_settings): + """Test authentication failure with HTTP error""" + # Setup mock to raise HTTP error + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("401 Unauthorized") + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + @patch('linopy.oetc.requests.post') + def test_authentication_network_error(self, mock_post, mock_settings): + """Test authentication failure with network error""" + # Setup mock to raise network error + mock_post.side_effect = RequestException("Connection timeout") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + @patch('linopy.oetc.requests.post') + def test_authentication_invalid_response_missing_token(self, mock_post, mock_settings): + """Test authentication failure with missing token in response""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "token_type": "Bearer", + "expires_in": 3600 + # Missing "token" field + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Invalid response format: missing field 'token'" in str(exc_info.value) + + @patch('linopy.oetc.requests.post') + def test_authentication_invalid_response_missing_expires_in(self, mock_post, mock_settings): + """Test authentication failure with missing expires_in in response""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "Bearer" + # Missing "expires_in" field + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Invalid response format: missing field 'expires_in'" in str(exc_info.value) + + @patch('linopy.oetc.requests.post') + def test_authentication_timeout_error(self, mock_post, mock_settings): + """Test authentication failure with timeout""" + # Setup mock to raise timeout error + mock_post.side_effect = requests.Timeout("Request timeout") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + +class TestAuthenticationResult: + + @pytest.fixture + def auth_result(self): + """Create an AuthenticationResult for testing""" + return AuthenticationResult( + token="test_token", + token_type="Bearer", + expires_in=3600, # 1 hour + authenticated_at=datetime(2024, 1, 15, 12, 0, 0) + ) + + def test_expires_at_calculation(self, auth_result): + """Test that expires_at correctly calculates expiration time""" + expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later + assert auth_result.expires_at == expected_expiry + + @patch('linopy.oetc.datetime') + def test_is_expired_false_when_not_expired(self, mock_datetime, auth_result): + """Test is_expired returns False when token is still valid""" + # Set current time to before expiration + mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 30, 0) + + assert auth_result.is_expired is False + + @patch('linopy.oetc.datetime') + def test_is_expired_true_when_expired(self, mock_datetime, auth_result): + """Test is_expired returns True when token has expired""" + # Set current time to after expiration + mock_datetime.now.return_value = datetime(2024, 1, 15, 14, 0, 0) + + assert auth_result.is_expired is True + + @patch('linopy.oetc.datetime') + def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): + """Test is_expired returns True when token expires exactly now""" + # Set current time to exact expiration time + mock_datetime.now.return_value = datetime(2024, 1, 15, 13, 0, 0) + + assert auth_result.is_expired is True + + +class TestOetcHandlerIntegration: + @patch('linopy.oetc.requests.post') + @patch('linopy.oetc.datetime') + def test_complete_authentication_flow(self, mock_datetime, mock_post): + """Test complete authentication flow with realistic data""" + # Setup + fixed_time = datetime(2024, 1, 15, 12, 0, 0) + mock_datetime.now.return_value = fixed_time + + credentials = OetcCredentials( + email="user@company.com", + password="secure_password_123" + ) + settings = OetcSettings( + credentials=credentials, + authentication_server_url="https://api.company.com/auth" + ) + + mock_response = Mock() + mock_response.json.return_value = { + "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", + "token_type": "Bearer", + "expires_in": 7200 # 2 hours + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute + handler = OetcHandler(settings) + + # Verify + assert handler.jwt.token.startswith("eyJhbGciOiJIUzI1NiI") + assert handler.jwt.token_type == "Bearer" + assert handler.jwt.expires_in == 7200 + assert handler.jwt.authenticated_at == fixed_time + assert handler.jwt.expires_at == datetime(2024, 1, 15, 14, 0, 0) # 2 hours later + + # Test that token is not expired immediately after authentication + assert handler.jwt.is_expired is False \ No newline at end of file From cf88264b006bd1483b6264ee156b74f5c3af6f7e Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Mon, 9 Jun 2025 13:49:29 +0200 Subject: [PATCH 02/28] Implement cloud provider credentials fetch --- linopy/oetc.py | 101 ++++++++++++- test/test_oetc.py | 376 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 444 insertions(+), 33 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 6c493014..406ed83c 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -1,19 +1,37 @@ from dataclasses import dataclass from datetime import datetime, timedelta +from enum import Enum +from typing import Union +import json +import base64 import requests from requests import RequestException +class ComputeProvider(str, Enum): + GCP = "GCP" + + @dataclass class OetcCredentials: email: str password: str + @dataclass class OetcSettings: credentials: OetcCredentials authentication_server_url: str + compute_provider: ComputeProvider = ComputeProvider.GCP + + +@dataclass +class GcpCredentials: + gcp_project_id: str + gcp_service_key: str + input_bucket: str + solution_bucket: str @dataclass @@ -33,10 +51,12 @@ def is_expired(self) -> bool: """Check if the token has expired""" return datetime.now() >= self.expires_at + class OetcHandler: def __init__(self, settings: OetcSettings) -> None: self.settings = settings self.jwt = self.__sign_in() + self.cloud_provider_credentials = self.__get_cloud_provider_credentials() def __sign_in(self) -> AuthenticationResult: """ @@ -76,4 +96,83 @@ def __sign_in(self) -> AuthenticationResult: except KeyError as e: raise Exception(f"Invalid response format: missing field {e}") except Exception as e: - raise Exception(f"Authentication error: {e}") \ No newline at end of file + raise Exception(f"Authentication error: {e}") + + def _decode_jwt_payload(self, token: str) -> dict: + """ + Decode JWT payload without verification to extract user information. + + Args: + token: The JWT token + + Returns: + dict: The decoded payload containing user information + + Raises: + Exception: If token cannot be decoded + """ + try: + payload_part = token.split('.')[1] + payload_part += '=' * (4 - len(payload_part) % 4) + payload_bytes = base64.urlsafe_b64decode(payload_part) + return json.loads(payload_bytes.decode('utf-8')) + except (IndexError, json.JSONDecodeError, Exception) as e: + raise Exception(f"Failed to decode JWT payload: {e}") + + def __get_cloud_provider_credentials(self) -> Union[GcpCredentials, None]: + """ + Fetch cloud provider credentials based on the configured provider. + + Returns: + Union[GcpCredentials, None]: The cloud provider credentials + + Raises: + Exception: If the compute provider is not supported + """ + if self.settings.compute_provider == ComputeProvider.GCP: + return self.__get_gcp_credentials() + else: + raise Exception(f"Unsupported compute provider: {self.settings.compute_provider}") + + def __get_gcp_credentials(self) -> GcpCredentials: + """ + Fetch GCP credentials for the authenticated user. + + Returns: + GcpCredentials: The GCP credentials including project ID, service key, and bucket information + + Raises: + Exception: If credentials fetching fails or response is invalid + """ + try: + payload = self._decode_jwt_payload(self.jwt.token) + user_uuid = payload.get('sub') + + if not user_uuid: + raise Exception("User UUID not found in JWT token") + + response = requests.get( + f"{self.settings.authentication_server_url}/users/{user_uuid}/gcp-credentials", + headers={ + "Authorization": f"{self.jwt.token_type} {self.jwt.token}", + "Content-Type": "application/json" + }, + timeout=30 + ) + + response.raise_for_status() + credentials_data = response.json() + + return GcpCredentials( + gcp_project_id=credentials_data["gcp_project_id"], + gcp_service_key=credentials_data["gcp_service_key"], + input_bucket=credentials_data["input_bucket"], + solution_bucket=credentials_data["solution_bucket"] + ) + + except RequestException as e: + raise Exception(f"Failed to fetch GCP credentials: {e}") + except KeyError as e: + raise Exception(f"Invalid credentials response format: missing field {e}") + except Exception as e: + raise Exception(f"Error fetching GCP credentials: {e}") \ No newline at end of file diff --git a/test/test_oetc.py b/test/test_oetc.py index 640e68bc..62b63244 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -1,27 +1,66 @@ import pytest from datetime import datetime from unittest.mock import patch, Mock +import base64 +import json import requests from requests import RequestException -from linopy.oetc import OetcCredentials, OetcSettings, OetcHandler, AuthenticationResult +from linopy.oetc import ( + OetcHandler, OetcSettings, OetcCredentials, AuthenticationResult, + ComputeProvider, GcpCredentials +) + + +@pytest.fixture +def sample_jwt_token(): + """Create a sample JWT token with test payload""" + payload = { + "iss": "OETC", + "sub": "user-uuid-123", + "exp": 1640995200, + "jti": "jwt-id-456", + "email": "test@example.com", + "firstname": "Test", + "lastname": "User" + } + + # Create a simple JWT-like token (header.payload.signature) + header = base64.urlsafe_b64encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode()).decode().rstrip('=') + payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + signature = "fake_signature" + + return f"{header}.{payload_encoded}.{signature}" + + +@pytest.fixture +def mock_gcp_credentials_response(): + """Create a mock GCP credentials response""" + return { + "gcp_project_id": "test-project-123", + "gcp_service_key": "test-service-key-content", + "input_bucket": "test-input-bucket", + "solution_bucket": "test-solution-bucket" + } + + +@pytest.fixture +def mock_settings(): + """Create mock settings for testing""" + credentials = OetcCredentials( + email="test@example.com", + password="test_password" + ) + return OetcSettings( + credentials=credentials, + authentication_server_url="https://auth.example.com", + compute_provider=ComputeProvider.GCP + ) class TestOetcHandler: - @pytest.fixture - def mock_settings(self): - """Create mock settings for testing""" - credentials = OetcCredentials( - email="test@example.com", - password="test_password" - ) - return OetcSettings( - credentials=credentials, - authentication_server_url="https://auth.example.com" - ) - @pytest.fixture def mock_jwt_response(self): """Create a mock JWT response""" @@ -32,22 +71,32 @@ def mock_jwt_response(self): } @patch('linopy.oetc.requests.post') + @patch('linopy.oetc.requests.get') @patch('linopy.oetc.datetime') - def test_successful_authentication(self, mock_datetime, mock_post, mock_settings, mock_jwt_response): + def test_successful_authentication(self, mock_datetime, mock_get, mock_post, mock_settings, mock_jwt_response, + mock_gcp_credentials_response, sample_jwt_token): """Test successful authentication flow""" # Setup mocks fixed_time = datetime(2024, 1, 15, 12, 0, 0) mock_datetime.now.return_value = fixed_time - mock_response = Mock() - mock_response.json.return_value = mock_jwt_response - mock_response.raise_for_status.return_value = None - mock_post.return_value = mock_response + # Mock authentication response + mock_auth_response = Mock() + mock_jwt_response["token"] = sample_jwt_token + mock_auth_response.json.return_value = mock_jwt_response + mock_auth_response.raise_for_status.return_value = None + mock_post.return_value = mock_auth_response + + # Mock GCP credentials response + mock_gcp_response = Mock() + mock_gcp_response.json.return_value = mock_gcp_credentials_response + mock_gcp_response.raise_for_status.return_value = None + mock_get.return_value = mock_gcp_response # Execute handler = OetcHandler(mock_settings) - # Verify requests.post was called correctly + # Verify authentication request mock_post.assert_called_once_with( "https://auth.example.com/sign-in", json={ @@ -58,13 +107,30 @@ def test_successful_authentication(self, mock_datetime, mock_post, mock_settings timeout=30 ) + # Verify GCP credentials request + mock_get.assert_called_once_with( + "https://auth.example.com/users/user-uuid-123/gcp-credentials", + headers={ + "Authorization": f"Bearer {sample_jwt_token}", + "Content-Type": "application/json" + }, + timeout=30 + ) + # Verify AuthenticationResult assert isinstance(handler.jwt, AuthenticationResult) - assert handler.jwt.token == "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." + assert handler.jwt.token == sample_jwt_token assert handler.jwt.token_type == "Bearer" assert handler.jwt.expires_in == 3600 assert handler.jwt.authenticated_at == fixed_time + # Verify GcpCredentials + assert isinstance(handler.cloud_provider_credentials, GcpCredentials) + assert handler.cloud_provider_credentials.gcp_project_id == "test-project-123" + assert handler.cloud_provider_credentials.gcp_service_key == "test-service-key-content" + assert handler.cloud_provider_credentials.input_bucket == "test-input-bucket" + assert handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" + @patch('linopy.oetc.requests.post') def test_authentication_http_error(self, mock_post, mock_settings): """Test authentication failure with HTTP error""" @@ -79,6 +145,200 @@ def test_authentication_http_error(self, mock_post, mock_settings): assert "Authentication request failed" in str(exc_info.value) + +class TestJwtDecoding: + + @pytest.fixture + def handler_with_mocked_auth(self): + """Create handler with mocked authentication for testing JWT decoding""" + with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + authentication_server_url="https://auth.example.com", + compute_provider=ComputeProvider.GCP + ) + + # Mock the authentication and credentials fetching + mock_auth_result = AuthenticationResult( + token="fake.token.here", + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now() + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = mock_auth_result + handler.cloud_provider_credentials = None + + return handler + + def test_decode_jwt_payload_success(self, handler_with_mocked_auth, sample_jwt_token): + """Test successful JWT payload decoding""" + result = handler_with_mocked_auth._decode_jwt_payload(sample_jwt_token) + + assert result["iss"] == "OETC" + assert result["sub"] == "user-uuid-123" + assert result["email"] == "test@example.com" + assert result["firstname"] == "Test" + assert result["lastname"] == "User" + + def test_decode_jwt_payload_invalid_token(self, handler_with_mocked_auth): + """Test JWT payload decoding with invalid token""" + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._decode_jwt_payload("invalid.token") + + assert "Failed to decode JWT payload" in str(exc_info.value) + + def test_decode_jwt_payload_malformed_token(self, handler_with_mocked_auth): + """Test JWT payload decoding with malformed token""" + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._decode_jwt_payload("not_a_jwt_token") + + assert "Failed to decode JWT payload" in str(exc_info.value) + + +class TestCloudProviderCredentials: + + @pytest.fixture + def handler_with_mocked_auth(self, sample_jwt_token): + """Create handler with mocked authentication for testing credentials fetching""" + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + authentication_server_url="https://auth.example.com", + compute_provider=ComputeProvider.GCP + ) + + # Mock the authentication result + mock_auth_result = AuthenticationResult( + token=sample_jwt_token, + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now() + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = mock_auth_result + + return handler + + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_success(self, mock_get, handler_with_mocked_auth, mock_gcp_credentials_response): + """Test successful GCP credentials fetching""" + # Setup mock + mock_response = Mock() + mock_response.json.return_value = mock_gcp_credentials_response + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + # Execute + result = handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + + # Verify request + mock_get.assert_called_once_with( + "https://auth.example.com/users/user-uuid-123/gcp-credentials", + headers={ + "Authorization": f"Bearer {handler_with_mocked_auth.jwt.token}", + "Content-Type": "application/json" + }, + timeout=30 + ) + + # Verify result + assert isinstance(result, GcpCredentials) + assert result.gcp_project_id == "test-project-123" + assert result.gcp_service_key == "test-service-key-content" + assert result.input_bucket == "test-input-bucket" + assert result.solution_bucket == "test-solution-bucket" + + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_http_error(self, mock_get, handler_with_mocked_auth): + """Test GCP credentials fetching with HTTP error""" + # Setup mock to raise HTTP error + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("403 Forbidden") + mock_get.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + + assert "Failed to fetch GCP credentials" in str(exc_info.value) + + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_missing_field(self, mock_get, handler_with_mocked_auth): + """Test GCP credentials fetching with missing response field""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "gcp_project_id": "test-project-123", + "gcp_service_key": "test-service-key-content", + "input_bucket": "test-input-bucket" + # Missing "solution_bucket" field + } + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + + assert "Invalid credentials response format: missing field 'solution_bucket'" in str(exc_info.value) + + def test_get_cloud_provider_credentials_unsupported_provider(self, handler_with_mocked_auth): + """Test cloud provider credentials with unsupported provider""" + # Change to unsupported provider + handler_with_mocked_auth.settings.compute_provider = "AWS" # Not in enum, but for testing + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_cloud_provider_credentials() + + assert "Unsupported compute provider: AWS" in str(exc_info.value) + + def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_auth): + """Test GCP credentials fetching when JWT token has no user UUID""" + # Create token without 'sub' field + payload = {"iss": "OETC", "email": "test@example.com"} + payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + token_without_sub = f"header.{payload_encoded}.signature" + + handler_with_mocked_auth.jwt.token = token_without_sub + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + + assert "User UUID not found in JWT token" in str(exc_info.value) + + +class TestGcpCredentials: + + def test_gcp_credentials_creation(self): + """Test GcpCredentials dataclass creation""" + credentials = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key="test-service-key-content", + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket" + ) + + assert credentials.gcp_project_id == "test-project-123" + assert credentials.gcp_service_key == "test-service-key-content" + assert credentials.input_bucket == "test-input-bucket" + assert credentials.solution_bucket == "test-solution-bucket" + + +class TestComputeProvider: + + def test_compute_provider_enum(self): + """Test ComputeProvider enum values""" + assert ComputeProvider.GCP == "GCP" + assert ComputeProvider.GCP.value == "GCP" + @patch('linopy.oetc.requests.post') def test_authentication_network_error(self, mock_post, mock_settings): """Test authentication failure with network error""" @@ -184,11 +444,14 @@ def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): assert auth_result.is_expired is True +# Additional integration-style test class TestOetcHandlerIntegration: + @patch('linopy.oetc.requests.post') + @patch('linopy.oetc.requests.get') @patch('linopy.oetc.datetime') - def test_complete_authentication_flow(self, mock_datetime, mock_post): - """Test complete authentication flow with realistic data""" + def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): + """Test complete authentication and credentials flow with realistic data""" # Setup fixed_time = datetime(2024, 1, 15, 12, 0, 0) mock_datetime.now.return_value = fixed_time @@ -199,27 +462,76 @@ def test_complete_authentication_flow(self, mock_datetime, mock_post): ) settings = OetcSettings( credentials=credentials, - authentication_server_url="https://api.company.com/auth" + authentication_server_url="https://api.company.com/auth", + compute_provider=ComputeProvider.GCP ) - mock_response = Mock() - mock_response.json.return_value = { - "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", + # Create realistic JWT token + payload = { + "iss": "OETC", + "sub": "user-uuid-456", + "exp": 1640995200, + "jti": "jwt-id-789", + "email": "user@company.com", + "firstname": "John", + "lastname": "Doe" + } + header = base64.urlsafe_b64encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode()).decode().rstrip('=') + payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + realistic_token = f"{header}.{payload_encoded}.realistic_signature" + + # Mock authentication response + mock_auth_response = Mock() + mock_auth_response.json.return_value = { + "token": realistic_token, "token_type": "Bearer", "expires_in": 7200 # 2 hours } - mock_response.raise_for_status.return_value = None - mock_post.return_value = mock_response + mock_auth_response.raise_for_status.return_value = None + mock_post.return_value = mock_auth_response + + # Mock GCP credentials response + mock_gcp_response = Mock() + mock_gcp_response.json.return_value = { + "gcp_project_id": "production-project-456", + "gcp_service_key": "production-service-key-content", + "input_bucket": "prod-input-bucket", + "solution_bucket": "prod-solution-bucket" + } + mock_gcp_response.raise_for_status.return_value = None + mock_get.return_value = mock_gcp_response # Execute handler = OetcHandler(settings) - # Verify - assert handler.jwt.token.startswith("eyJhbGciOiJIUzI1NiI") + # Verify authentication + assert handler.jwt.token == realistic_token assert handler.jwt.token_type == "Bearer" assert handler.jwt.expires_in == 7200 assert handler.jwt.authenticated_at == fixed_time assert handler.jwt.expires_at == datetime(2024, 1, 15, 14, 0, 0) # 2 hours later + assert handler.jwt.is_expired is False + + # Verify GCP credentials + assert isinstance(handler.cloud_provider_credentials, GcpCredentials) + assert handler.cloud_provider_credentials.gcp_project_id == "production-project-456" + assert handler.cloud_provider_credentials.gcp_service_key == "production-service-key-content" + assert handler.cloud_provider_credentials.input_bucket == "prod-input-bucket" + assert handler.cloud_provider_credentials.solution_bucket == "prod-solution-bucket" + + # Verify correct API calls were made + mock_post.assert_called_once_with( + "https://api.company.com/auth/sign-in", + json={"email": "user@company.com", "password": "secure_password_123"}, + headers={"Content-Type": "application/json"}, + timeout=30 + ) - # Test that token is not expired immediately after authentication - assert handler.jwt.is_expired is False \ No newline at end of file + mock_get.assert_called_once_with( + "https://api.company.com/auth/users/user-uuid-456/gcp-credentials", + headers={ + "Authorization": f"Bearer {realistic_token}", + "Content-Type": "application/json" + }, + timeout=30 + ) From 89c2da511574a991173e83e4b27652544b725ad6 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Mon, 9 Jun 2025 17:00:46 +0200 Subject: [PATCH 03/28] Implement OETC model compression and upload --- linopy/oetc.py | 104 ++++++++++++++++- pyproject.toml | 1 + test/test_oetc.py | 277 ++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 360 insertions(+), 22 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 406ed83c..4dd66018 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -4,9 +4,14 @@ from typing import Union import json import base64 +import gzip +import os +import tempfile import requests from requests import RequestException +from google.cloud import storage +from google.oauth2 import service_account class ComputeProvider(str, Enum): @@ -175,4 +180,101 @@ def __get_gcp_credentials(self) -> GcpCredentials: except KeyError as e: raise Exception(f"Invalid credentials response format: missing field {e}") except Exception as e: - raise Exception(f"Error fetching GCP credentials: {e}") \ No newline at end of file + raise Exception(f"Error fetching GCP credentials: {e}") + + def solve_on_oetc(self, model): + """ + Solve a linopy model on the OET Cloud compute app. + + Parameters + ---------- + model : linopy.model.Model + **kwargs : + Keyword arguments passed to `linopy.model.Model.solve`. + + Returns + ------- + linopy.model.Model + Solved model. + """ + with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: + model.to_netcdf(fn.name) + input_file_name = self._upload_file_to_gcp(fn.name) + + # TODO: Submit job to compute service + # TODO: Wait for job completion + # TODO: Download result from GCP + # TODO: Return solved model + + return model + + def _gzip_compress(self, source_path: str) -> str: + """ + Compress a file using gzip compression. + + Args: + source_path: Path to the source file to compress + + Returns: + str: Path to the compressed file + + Raises: + Exception: If compression fails + """ + try: + output_path = source_path + ".gz" + chunk_size = 1024 * 1024 + + with open(source_path, "rb") as f_in: + with gzip.open(output_path, "wb", compresslevel=9) as f_out: + while True: + chunk = f_in.read(chunk_size) + if not chunk: + break + f_out.write(chunk) + + return output_path + except Exception as e: + raise Exception(f"Failed to compress file: {e}") + + def _upload_file_to_gcp(self, file_path: str) -> str: + """ + Upload a file to GCP storage bucket after compression. + + Args: + file_path: Path to the file to upload + + Returns: + str: Name of the uploaded file in the bucket + + Raises: + Exception: If upload fails + """ + try: + compressed_file_path = self._gzip_compress(file_path) + compressed_file_name = os.path.basename(compressed_file_path) + + # Create GCP credentials from service key + service_key_dict = json.loads(self.cloud_provider_credentials.gcp_service_key) + credentials = service_account.Credentials.from_service_account_info( + service_key_dict, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + # Upload to GCP bucket + storage_client = storage.Client( + credentials=credentials, + project=self.cloud_provider_credentials.gcp_project_id + ) + bucket = storage_client.bucket(self.cloud_provider_credentials.input_bucket) + blob = bucket.blob(compressed_file_name) + + blob.upload_from_filename(compressed_file_path) + + # Clean up compressed file + os.remove(compressed_file_path) + + return compressed_file_name + + except Exception as e: + raise Exception(f"Failed to upload file to GCP: {e}") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 88773701..eb23115f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "polars", "tqdm", "deprecation", + "google-cloud-storage", ] [project.urls] diff --git a/test/test_oetc.py b/test/test_oetc.py index 62b63244..a3156881 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -107,29 +107,263 @@ def test_successful_authentication(self, mock_datetime, mock_get, mock_post, moc timeout=30 ) - # Verify GCP credentials request - mock_get.assert_called_once_with( - "https://auth.example.com/users/user-uuid-123/gcp-credentials", - headers={ - "Authorization": f"Bearer {sample_jwt_token}", - "Content-Type": "application/json" - }, - timeout=30 - ) - # Verify AuthenticationResult - assert isinstance(handler.jwt, AuthenticationResult) - assert handler.jwt.token == sample_jwt_token - assert handler.jwt.token_type == "Bearer" - assert handler.jwt.expires_in == 3600 - assert handler.jwt.authenticated_at == fixed_time +class TestFileCompression: - # Verify GcpCredentials - assert isinstance(handler.cloud_provider_credentials, GcpCredentials) - assert handler.cloud_provider_credentials.gcp_project_id == "test-project-123" - assert handler.cloud_provider_credentials.gcp_service_key == "test-service-key-content" - assert handler.cloud_provider_credentials.input_bucket == "test-input-bucket" - assert handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" + @pytest.fixture + def handler_with_mocked_auth(self): + """Create handler with mocked authentication for testing file operations""" + with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + authentication_server_url="https://auth.example.com", + compute_provider=ComputeProvider.GCP + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = Mock() + + return handler + + @patch('linopy.oetc.gzip.open') + @patch('linopy.oetc.os.path.exists') + @patch('builtins.open') + def test_gzip_compress_success(self, mock_open, mock_exists, mock_gzip_open, handler_with_mocked_auth): + """Test successful file compression""" + # Setup + source_path = "/tmp/test_file.nc" + expected_output = "/tmp/test_file.nc.gz" + + # Mock file operations + mock_exists.return_value = True + mock_file_in = Mock() + mock_file_out = Mock() + mock_open.return_value.__enter__.return_value = mock_file_in + mock_gzip_open.return_value.__enter__.return_value = mock_file_out + + # Mock file reading + mock_file_in.read.side_effect = [b"test_data_chunk", b""] # First read returns data, second returns empty + + # Execute + result = handler_with_mocked_auth._gzip_compress(source_path) + + # Verify + assert result == expected_output + mock_open.assert_called_once_with(source_path, "rb") + mock_gzip_open.assert_called_once_with(expected_output, "wb", compresslevel=9) + mock_file_out.write.assert_called_once_with(b"test_data_chunk") + + @patch('builtins.open') + def test_gzip_compress_file_read_error(self, mock_open, handler_with_mocked_auth): + """Test file compression with read error""" + # Setup + source_path = "/tmp/test_file.nc" + mock_open.side_effect = IOError("File not found") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._gzip_compress(source_path) + + assert "Failed to compress file" in str(exc_info.value) + + +class TestGcpUpload: + + @pytest.fixture + def handler_with_gcp_credentials(self, mock_gcp_credentials_response): + """Create handler with GCP credentials for testing upload""" + with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + authentication_server_url="https://auth.example.com", + compute_provider=ComputeProvider.GCP + ) + + # Create proper GCP credentials + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket" + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds + + return handler + + @patch('linopy.oetc.os.remove') + @patch('linopy.oetc.os.path.basename') + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_upload_file_to_gcp_success(self, mock_creds_from_info, mock_storage_client, mock_basename, mock_remove, + handler_with_gcp_credentials): + """Test successful file upload to GCP""" + # Setup + file_path = "/tmp/test_file.nc" + compressed_path = "/tmp/test_file.nc.gz" + compressed_name = "test_file.nc.gz" + + # Mock compression + with patch.object(handler_with_gcp_credentials, '_gzip_compress', return_value=compressed_path): + # Mock path operations + mock_basename.return_value = compressed_name + + # Mock GCP components + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_bucket.blob.return_value = mock_blob + + # Execute + result = handler_with_gcp_credentials._upload_file_to_gcp(file_path) + + # Verify + assert result == compressed_name + + # Verify GCP credentials creation + mock_creds_from_info.assert_called_once_with( + {"type": "service_account", "project_id": "test-project-123"}, + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + + # Verify GCP client creation + mock_storage_client.assert_called_once_with( + credentials=mock_credentials, + project="test-project-123" + ) + + # Verify bucket access + mock_client.bucket.assert_called_once_with("test-input-bucket") + + # Verify blob operations + mock_bucket.blob.assert_called_once_with(compressed_name) + mock_blob.upload_from_filename.assert_called_once_with(compressed_path) + + # Verify cleanup + mock_remove.assert_called_once_with(compressed_path) + + @patch('linopy.oetc.json.loads') + def test_upload_file_to_gcp_invalid_service_key(self, mock_json_loads, handler_with_gcp_credentials): + """Test upload failure with invalid service key""" + # Setup + file_path = "/tmp/test_file.nc" + mock_json_loads.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._upload_file_to_gcp(file_path) + + assert "Failed to upload file to GCP" in str(exc_info.value) + + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_upload_file_to_gcp_upload_error(self, mock_creds_from_info, mock_storage_client, + handler_with_gcp_credentials): + """Test upload failure during blob upload""" + # Setup + file_path = "/tmp/test_file.nc" + compressed_path = "/tmp/test_file.nc.gz" + + # Mock compression + with patch.object(handler_with_gcp_credentials, '_gzip_compress', return_value=compressed_path): + # Mock GCP setup + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_blob.upload_from_filename.side_effect = Exception("Upload failed") + mock_bucket.blob.return_value = mock_blob + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._upload_file_to_gcp(file_path) + + assert "Failed to upload file to GCP" in str(exc_info.value) + + +class TestSolveOnOetc: + + @pytest.fixture + def handler_with_complete_setup(self, mock_gcp_credentials_response): + """Create handler with complete setup for testing solve functionality""" + with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + authentication_server_url="https://auth.example.com", + compute_provider=ComputeProvider.GCP + ) + + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket" + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds + + return handler + + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_file_upload(self, mock_tempfile, handler_with_complete_setup): + """Test solve_on_oetc method file upload flow""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock file upload + with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', + return_value="uploaded_file.nc.gz") as mock_upload: + # Execute + result = handler_with_complete_setup.solve_on_oetc(mock_model) + + # Verify + assert result == mock_model # Currently returns the input model + mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_upload_failure(self, mock_tempfile, handler_with_complete_setup): + """Test solve_on_oetc method with upload failure""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock upload failure + with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', side_effect=Exception("Upload failed")): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_complete_setup.solve_on_oetc(mock_model) + + assert "Upload failed" in str(exc_info.value) @patch('linopy.oetc.requests.post') def test_authentication_http_error(self, mock_post, mock_settings): @@ -535,3 +769,4 @@ def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): }, timeout=30 ) + From 888cfb6dbcffba32f8d577acc2815740e4ca5423 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Tue, 10 Jun 2025 13:41:43 +0200 Subject: [PATCH 04/28] Implement job submission to oetc --- linopy/oetc.py | 66 +++- test/test_oetc.py | 840 +++++++++++++++++++++++++++++----------------- 2 files changed, 601 insertions(+), 305 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 4dd66018..fb63ba9e 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta from enum import Enum from typing import Union @@ -27,8 +27,15 @@ class OetcCredentials: @dataclass class OetcSettings: credentials: OetcCredentials + name: str authentication_server_url: str + orchestrator_server_url: str compute_provider: ComputeProvider = ComputeProvider.GCP + solver: str = "highs" + solver_options: dict = field(default_factory=dict) + cpu_cores: int = 2 + disk_space_gb: int = 10 + delete_worker_on_error: bool = False @dataclass @@ -57,6 +64,11 @@ def is_expired(self) -> bool: return datetime.now() >= self.expires_at +@dataclass +class CreateComputeJobResult: + uuid: str + + class OetcHandler: def __init__(self, settings: OetcSettings) -> None: self.settings = settings @@ -182,6 +194,53 @@ def __get_gcp_credentials(self) -> GcpCredentials: except Exception as e: raise Exception(f"Error fetching GCP credentials: {e}") + def _submit_job_to_compute_service(self, input_file_name: str) -> CreateComputeJobResult: + """ + Submit a job to the compute service. + + Args: + input_file_name: Name of the input file uploaded to GCP + + Returns: + CreateComputeJobResult: The job creation result with UUID + + Raises: + Exception: If job submission fails + """ + try: + payload = { + "name": self.settings.name, + "solver": self.settings.solver, + "solver_options": self.settings.solver_options, + "provider": self.settings.compute_provider.value, + "cpu_cores": self.settings.cpu_cores, + "disk_space_gb": self.settings.disk_space_gb, + "input_file_name": input_file_name, + "delete_worker_on_error": self.settings.delete_worker_on_error + } + + response = requests.post( + f"{self.settings.orchestrator_server_url}/create", + json=payload, + headers={ + "Authorization": f"{self.jwt.token_type} {self.jwt.token}", + "Content-Type": "application/json" + }, + timeout=30 + ) + + response.raise_for_status() + job_result = response.json() + + return CreateComputeJobResult(uuid=job_result["uuid"]) + + except RequestException as e: + raise Exception(f"Failed to submit job to compute service: {e}") + except KeyError as e: + raise Exception(f"Invalid job submission response format: missing field {e}") + except Exception as e: + raise Exception(f"Error submitting job to compute service: {e}") + def solve_on_oetc(self, model): """ Solve a linopy model on the OET Cloud compute app. @@ -189,8 +248,6 @@ def solve_on_oetc(self, model): Parameters ---------- model : linopy.model.Model - **kwargs : - Keyword arguments passed to `linopy.model.Model.solve`. Returns ------- @@ -201,7 +258,8 @@ def solve_on_oetc(self, model): model.to_netcdf(fn.name) input_file_name = self._upload_file_to_gcp(fn.name) - # TODO: Submit job to compute service + job_result = self._submit_job_to_compute_service(input_file_name) + # TODO: Wait for job completion # TODO: Download result from GCP # TODO: Return solved model diff --git a/test/test_oetc.py b/test/test_oetc.py index a3156881..56df6ee7 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -9,7 +9,7 @@ from linopy.oetc import ( OetcHandler, OetcSettings, OetcCredentials, AuthenticationResult, - ComputeProvider, GcpCredentials + ComputeProvider, GcpCredentials, CreateComputeJobResult ) @@ -54,7 +54,9 @@ def mock_settings(): ) return OetcSettings( credentials=credentials, + name="Test Job", authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", compute_provider=ComputeProvider.GCP ) @@ -107,6 +109,346 @@ def test_successful_authentication(self, mock_datetime, mock_get, mock_post, moc timeout=30 ) + # Verify GCP credentials request + mock_get.assert_called_once_with( + "https://auth.example.com/users/user-uuid-123/gcp-credentials", + headers={ + "Authorization": f"Bearer {sample_jwt_token}", + "Content-Type": "application/json" + }, + timeout=30 + ) + + # Verify AuthenticationResult + assert isinstance(handler.jwt, AuthenticationResult) + assert handler.jwt.token == sample_jwt_token + assert handler.jwt.token_type == "Bearer" + assert handler.jwt.expires_in == 3600 + assert handler.jwt.authenticated_at == fixed_time + + # Verify GcpCredentials + assert isinstance(handler.cloud_provider_credentials, GcpCredentials) + assert handler.cloud_provider_credentials.gcp_project_id == "test-project-123" + assert handler.cloud_provider_credentials.gcp_service_key == "test-service-key-content" + assert handler.cloud_provider_credentials.input_bucket == "test-input-bucket" + assert handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" + + @patch('linopy.oetc.requests.post') + def test_authentication_http_error(self, mock_post, mock_settings): + """Test authentication failure with HTTP error""" + # Setup mock to raise HTTP error + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("401 Unauthorized") + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + +class TestJwtDecoding: + + @pytest.fixture + def handler_with_mocked_auth(self): + """Create handler with mocked authentication for testing JWT decoding""" + with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP + ) + + # Mock the authentication and credentials fetching + mock_auth_result = AuthenticationResult( + token="fake.token.here", + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now() + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = mock_auth_result + handler.cloud_provider_credentials = None + + return handler + + def test_decode_jwt_payload_success(self, handler_with_mocked_auth, sample_jwt_token): + """Test successful JWT payload decoding""" + result = handler_with_mocked_auth._decode_jwt_payload(sample_jwt_token) + + assert result["iss"] == "OETC" + assert result["sub"] == "user-uuid-123" + assert result["email"] == "test@example.com" + assert result["firstname"] == "Test" + assert result["lastname"] == "User" + + def test_decode_jwt_payload_invalid_token(self, handler_with_mocked_auth): + """Test JWT payload decoding with invalid token""" + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._decode_jwt_payload("invalid.token") + + assert "Failed to decode JWT payload" in str(exc_info.value) + + def test_decode_jwt_payload_malformed_token(self, handler_with_mocked_auth): + """Test JWT payload decoding with malformed token""" + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._decode_jwt_payload("not_a_jwt_token") + + assert "Failed to decode JWT payload" in str(exc_info.value) + + +class TestCloudProviderCredentials: + + @pytest.fixture + def handler_with_mocked_auth(self, sample_jwt_token): + """Create handler with mocked authentication for testing credentials fetching""" + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP + ) + + # Mock the authentication result + mock_auth_result = AuthenticationResult( + token=sample_jwt_token, + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now() + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = mock_auth_result + + return handler + + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_success(self, mock_get, handler_with_mocked_auth, mock_gcp_credentials_response): + """Test successful GCP credentials fetching""" + # Setup mock + mock_response = Mock() + mock_response.json.return_value = mock_gcp_credentials_response + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + # Execute + result = handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + + # Verify request + mock_get.assert_called_once_with( + "https://auth.example.com/users/user-uuid-123/gcp-credentials", + headers={ + "Authorization": f"Bearer {handler_with_mocked_auth.jwt.token}", + "Content-Type": "application/json" + }, + timeout=30 + ) + + # Verify result + assert isinstance(result, GcpCredentials) + assert result.gcp_project_id == "test-project-123" + assert result.gcp_service_key == "test-service-key-content" + assert result.input_bucket == "test-input-bucket" + assert result.solution_bucket == "test-solution-bucket" + + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_http_error(self, mock_get, handler_with_mocked_auth): + """Test GCP credentials fetching with HTTP error""" + # Setup mock to raise HTTP error + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("403 Forbidden") + mock_get.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + + assert "Failed to fetch GCP credentials" in str(exc_info.value) + + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_missing_field(self, mock_get, handler_with_mocked_auth): + """Test GCP credentials fetching with missing response field""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "gcp_project_id": "test-project-123", + "gcp_service_key": "test-service-key-content", + "input_bucket": "test-input-bucket" + # Missing "solution_bucket" field + } + mock_response.raise_for_status.return_value = None + mock_get.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + + assert "Invalid credentials response format: missing field 'solution_bucket'" in str(exc_info.value) + + def test_get_cloud_provider_credentials_unsupported_provider(self, handler_with_mocked_auth): + """Test cloud provider credentials with unsupported provider""" + # Change to unsupported provider + handler_with_mocked_auth.settings.compute_provider = "AWS" # Not in enum, but for testing + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_cloud_provider_credentials() + + assert "Unsupported compute provider: AWS" in str(exc_info.value) + + def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_auth): + """Test GCP credentials fetching when JWT token has no user UUID""" + # Create token without 'sub' field + payload = {"iss": "OETC", "email": "test@example.com"} + payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + token_without_sub = f"header.{payload_encoded}.signature" + + handler_with_mocked_auth.jwt.token = token_without_sub + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + + assert "User UUID not found in JWT token" in str(exc_info.value) + + +class TestGcpCredentials: + + def test_gcp_credentials_creation(self): + """Test GcpCredentials dataclass creation""" + credentials = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key="test-service-key-content", + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket" + ) + + assert credentials.gcp_project_id == "test-project-123" + assert credentials.gcp_service_key == "test-service-key-content" + assert credentials.input_bucket == "test-input-bucket" + assert credentials.solution_bucket == "test-solution-bucket" + + +class TestComputeProvider: + + def test_compute_provider_enum(self): + """Test ComputeProvider enum values""" + assert ComputeProvider.GCP == "GCP" + assert ComputeProvider.GCP.value == "GCP" + + @patch('linopy.oetc.requests.post') + def test_authentication_network_error(self, mock_post, mock_settings): + """Test authentication failure with network error""" + # Setup mock to raise network error + mock_post.side_effect = RequestException("Connection timeout") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + @patch('linopy.oetc.requests.post') + def test_authentication_invalid_response_missing_token(self, mock_post, mock_settings): + """Test authentication failure with missing token in response""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "token_type": "Bearer", + "expires_in": 3600 + # Missing "token" field + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Invalid response format: missing field 'token'" in str(exc_info.value) + + @patch('linopy.oetc.requests.post') + def test_authentication_invalid_response_missing_expires_in(self, mock_post, mock_settings): + """Test authentication failure with missing expires_in in response""" + # Setup mock with invalid response + mock_response = Mock() + mock_response.json.return_value = { + "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "Bearer" + # Missing "expires_in" field + } + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Invalid response format: missing field 'expires_in'" in str(exc_info.value) + + @patch('linopy.oetc.requests.post') + def test_authentication_timeout_error(self, mock_post, mock_settings): + """Test authentication failure with timeout""" + # Setup mock to raise timeout error + mock_post.side_effect = requests.Timeout("Request timeout") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + OetcHandler(mock_settings) + + assert "Authentication request failed" in str(exc_info.value) + + +class TestAuthenticationResult: + + @pytest.fixture + def auth_result(self): + """Create an AuthenticationResult for testing""" + return AuthenticationResult( + token="test_token", + token_type="Bearer", + expires_in=3600, # 1 hour + authenticated_at=datetime(2024, 1, 15, 12, 0, 0) + ) + + def test_expires_at_calculation(self, auth_result): + """Test that expires_at correctly calculates expiration time""" + expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later + assert auth_result.expires_at == expected_expiry + + @patch('linopy.oetc.datetime') + def test_is_expired_false_when_not_expired(self, mock_datetime, auth_result): + """Test is_expired returns False when token is still valid""" + # Set current time to before expiration + mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 30, 0) + + assert auth_result.is_expired is False + + @patch('linopy.oetc.datetime') + def test_is_expired_true_when_expired(self, mock_datetime, auth_result): + """Test is_expired returns True when token has expired""" + # Set current time to after expiration + mock_datetime.now.return_value = datetime(2024, 1, 15, 14, 0, 0) + + assert auth_result.is_expired is True + + @patch('linopy.oetc.datetime') + def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): + """Test is_expired returns True when token expires exactly now""" + # Set current time to exact expiration time + mock_datetime.now.return_value = datetime(2024, 1, 15, 13, 0, 0) + + assert auth_result.is_expired is True + class TestFileCompression: @@ -117,7 +459,9 @@ def handler_with_mocked_auth(self): credentials = OetcCredentials(email="test@example.com", password="test_password") settings = OetcSettings( credentials=credentials, + name="Test Job", authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", compute_provider=ComputeProvider.GCP ) @@ -179,7 +523,9 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response): credentials = OetcCredentials(email="test@example.com", password="test_password") settings = OetcSettings( credentials=credentials, + name="Test Job", authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", compute_provider=ComputeProvider.GCP ) @@ -295,154 +641,27 @@ def test_upload_file_to_gcp_upload_error(self, mock_creds_from_info, mock_storag mock_bucket.blob.return_value = mock_blob # Execute and verify exception - with pytest.raises(Exception) as exc_info: - handler_with_gcp_credentials._upload_file_to_gcp(file_path) - - assert "Failed to upload file to GCP" in str(exc_info.value) - - -class TestSolveOnOetc: - - @pytest.fixture - def handler_with_complete_setup(self, mock_gcp_credentials_response): - """Create handler with complete setup for testing solve functionality""" - with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): - credentials = OetcCredentials(email="test@example.com", password="test_password") - settings = OetcSettings( - credentials=credentials, - authentication_server_url="https://auth.example.com", - compute_provider=ComputeProvider.GCP - ) - - gcp_creds = GcpCredentials( - gcp_project_id="test-project-123", - gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', - input_bucket="test-input-bucket", - solution_bucket="test-solution-bucket" - ) - - handler = OetcHandler.__new__(OetcHandler) - handler.settings = settings - handler.jwt = Mock() - handler.cloud_provider_credentials = gcp_creds - - return handler - - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_file_upload(self, mock_tempfile, handler_with_complete_setup): - """Test solve_on_oetc method file upload flow""" - # Setup - mock_model = Mock() - mock_temp_file = Mock() - mock_temp_file.name = "/tmp/linopy-abc123.nc" - mock_tempfile.return_value.__enter__.return_value = mock_temp_file - - # Mock file upload - with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', - return_value="uploaded_file.nc.gz") as mock_upload: - # Execute - result = handler_with_complete_setup.solve_on_oetc(mock_model) - - # Verify - assert result == mock_model # Currently returns the input model - mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") - - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_upload_failure(self, mock_tempfile, handler_with_complete_setup): - """Test solve_on_oetc method with upload failure""" - # Setup - mock_model = Mock() - mock_temp_file = Mock() - mock_temp_file.name = "/tmp/linopy-abc123.nc" - mock_tempfile.return_value.__enter__.return_value = mock_temp_file - - # Mock upload failure - with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', side_effect=Exception("Upload failed")): - # Execute and verify exception - with pytest.raises(Exception) as exc_info: - handler_with_complete_setup.solve_on_oetc(mock_model) - - assert "Upload failed" in str(exc_info.value) - - @patch('linopy.oetc.requests.post') - def test_authentication_http_error(self, mock_post, mock_settings): - """Test authentication failure with HTTP error""" - # Setup mock to raise HTTP error - mock_response = Mock() - mock_response.raise_for_status.side_effect = requests.HTTPError("401 Unauthorized") - mock_post.return_value = mock_response - - # Execute and verify exception - with pytest.raises(Exception) as exc_info: - OetcHandler(mock_settings) - - assert "Authentication request failed" in str(exc_info.value) - - -class TestJwtDecoding: - - @pytest.fixture - def handler_with_mocked_auth(self): - """Create handler with mocked authentication for testing JWT decoding""" - with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): - credentials = OetcCredentials(email="test@example.com", password="test_password") - settings = OetcSettings( - credentials=credentials, - authentication_server_url="https://auth.example.com", - compute_provider=ComputeProvider.GCP - ) - - # Mock the authentication and credentials fetching - mock_auth_result = AuthenticationResult( - token="fake.token.here", - token_type="Bearer", - expires_in=3600, - authenticated_at=datetime.now() - ) - - handler = OetcHandler.__new__(OetcHandler) - handler.settings = settings - handler.jwt = mock_auth_result - handler.cloud_provider_credentials = None - - return handler - - def test_decode_jwt_payload_success(self, handler_with_mocked_auth, sample_jwt_token): - """Test successful JWT payload decoding""" - result = handler_with_mocked_auth._decode_jwt_payload(sample_jwt_token) - - assert result["iss"] == "OETC" - assert result["sub"] == "user-uuid-123" - assert result["email"] == "test@example.com" - assert result["firstname"] == "Test" - assert result["lastname"] == "User" - - def test_decode_jwt_payload_invalid_token(self, handler_with_mocked_auth): - """Test JWT payload decoding with invalid token""" - with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._decode_jwt_payload("invalid.token") - - assert "Failed to decode JWT payload" in str(exc_info.value) - - def test_decode_jwt_payload_malformed_token(self, handler_with_mocked_auth): - """Test JWT payload decoding with malformed token""" - with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._decode_jwt_payload("not_a_jwt_token") + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._upload_file_to_gcp(file_path) - assert "Failed to decode JWT payload" in str(exc_info.value) + assert "Failed to upload file to GCP" in str(exc_info.value) -class TestCloudProviderCredentials: +class TestJobSubmission: @pytest.fixture - def handler_with_mocked_auth(self, sample_jwt_token): - """Create handler with mocked authentication for testing credentials fetching""" + def handler_with_auth_setup(self, sample_jwt_token): + """Create handler with authentication setup for testing job submission""" credentials = OetcCredentials(email="test@example.com", password="test_password") settings = OetcSettings( credentials=credentials, + name="Test Optimization Job", authentication_server_url="https://auth.example.com", - compute_provider=ComputeProvider.GCP + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + solver="gurobi", + cpu_cores=4, + disk_space_gb=20 ) # Mock the authentication result @@ -456,226 +675,244 @@ def handler_with_mocked_auth(self, sample_jwt_token): handler = OetcHandler.__new__(OetcHandler) handler.settings = settings handler.jwt = mock_auth_result + handler.cloud_provider_credentials = Mock() return handler - @patch('linopy.oetc.requests.get') - def test_get_gcp_credentials_success(self, mock_get, handler_with_mocked_auth, mock_gcp_credentials_response): - """Test successful GCP credentials fetching""" - # Setup mock + @patch('linopy.oetc.requests.post') + def test_submit_job_success(self, mock_post, handler_with_auth_setup): + """Test successful job submission to compute service""" + # Setup + input_file_name = "test_model.nc.gz" + expected_job_uuid = "job-uuid-123" + + # Mock successful response mock_response = Mock() - mock_response.json.return_value = mock_gcp_credentials_response + mock_response.json.return_value = {"uuid": expected_job_uuid} mock_response.raise_for_status.return_value = None - mock_get.return_value = mock_response + mock_post.return_value = mock_response # Execute - result = handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + result = handler_with_auth_setup._submit_job_to_compute_service(input_file_name) # Verify request - mock_get.assert_called_once_with( - "https://auth.example.com/users/user-uuid-123/gcp-credentials", + expected_payload = { + "name": "Test Optimization Job", + "solver": "gurobi", + "solver_options": {}, + "provider": "GCP", + "cpu_cores": 4, + "disk_space_gb": 20, + "input_file_name": input_file_name, + "delete_worker_on_error": False + } + + mock_post.assert_called_once_with( + "https://orchestrator.example.com/create", + json=expected_payload, headers={ - "Authorization": f"Bearer {handler_with_mocked_auth.jwt.token}", + "Authorization": f"Bearer {handler_with_auth_setup.jwt.token}", "Content-Type": "application/json" }, timeout=30 ) # Verify result - assert isinstance(result, GcpCredentials) - assert result.gcp_project_id == "test-project-123" - assert result.gcp_service_key == "test-service-key-content" - assert result.input_bucket == "test-input-bucket" - assert result.solution_bucket == "test-solution-bucket" + assert isinstance(result, CreateComputeJobResult) + assert result.uuid == expected_job_uuid - @patch('linopy.oetc.requests.get') - def test_get_gcp_credentials_http_error(self, mock_get, handler_with_mocked_auth): - """Test GCP credentials fetching with HTTP error""" - # Setup mock to raise HTTP error + @patch('linopy.oetc.requests.post') + def test_submit_job_http_error(self, mock_post, handler_with_auth_setup): + """Test job submission with HTTP error""" + # Setup + input_file_name = "test_model.nc.gz" mock_response = Mock() - mock_response.raise_for_status.side_effect = requests.HTTPError("403 Forbidden") - mock_get.return_value = mock_response + mock_response.raise_for_status.side_effect = requests.HTTPError("400 Bad Request") + mock_post.return_value = mock_response # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + handler_with_auth_setup._submit_job_to_compute_service(input_file_name) - assert "Failed to fetch GCP credentials" in str(exc_info.value) + assert "Failed to submit job to compute service" in str(exc_info.value) - @patch('linopy.oetc.requests.get') - def test_get_gcp_credentials_missing_field(self, mock_get, handler_with_mocked_auth): - """Test GCP credentials fetching with missing response field""" - # Setup mock with invalid response + @patch('linopy.oetc.requests.post') + def test_submit_job_missing_uuid_in_response(self, mock_post, handler_with_auth_setup): + """Test job submission with missing UUID in response""" + # Setup + input_file_name = "test_model.nc.gz" mock_response = Mock() - mock_response.json.return_value = { - "gcp_project_id": "test-project-123", - "gcp_service_key": "test-service-key-content", - "input_bucket": "test-input-bucket" - # Missing "solution_bucket" field - } + mock_response.json.return_value = {} # Missing "uuid" field mock_response.raise_for_status.return_value = None - mock_get.return_value = mock_response - - # Execute and verify exception - with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() - - assert "Invalid credentials response format: missing field 'solution_bucket'" in str(exc_info.value) - - def test_get_cloud_provider_credentials_unsupported_provider(self, handler_with_mocked_auth): - """Test cloud provider credentials with unsupported provider""" - # Change to unsupported provider - handler_with_mocked_auth.settings.compute_provider = "AWS" # Not in enum, but for testing + mock_post.return_value = mock_response # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_cloud_provider_credentials() - - assert "Unsupported compute provider: AWS" in str(exc_info.value) + handler_with_auth_setup._submit_job_to_compute_service(input_file_name) - def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_auth): - """Test GCP credentials fetching when JWT token has no user UUID""" - # Create token without 'sub' field - payload = {"iss": "OETC", "email": "test@example.com"} - payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') - token_without_sub = f"header.{payload_encoded}.signature" + assert "Invalid job submission response format: missing field 'uuid'" in str(exc_info.value) - handler_with_mocked_auth.jwt.token = token_without_sub + @patch('linopy.oetc.requests.post') + def test_submit_job_network_error(self, mock_post, handler_with_auth_setup): + """Test job submission with network error""" + # Setup + input_file_name = "test_model.nc.gz" + mock_post.side_effect = RequestException("Connection timeout") # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + handler_with_auth_setup._submit_job_to_compute_service(input_file_name) - assert "User UUID not found in JWT token" in str(exc_info.value) - - -class TestGcpCredentials: - - def test_gcp_credentials_creation(self): - """Test GcpCredentials dataclass creation""" - credentials = GcpCredentials( - gcp_project_id="test-project-123", - gcp_service_key="test-service-key-content", - input_bucket="test-input-bucket", - solution_bucket="test-solution-bucket" - ) + assert "Failed to submit job to compute service" in str(exc_info.value) - assert credentials.gcp_project_id == "test-project-123" - assert credentials.gcp_service_key == "test-service-key-content" - assert credentials.input_bucket == "test-input-bucket" - assert credentials.solution_bucket == "test-solution-bucket" +class TestSolveOnOetc: -class TestComputeProvider: + @pytest.fixture + def handler_with_complete_setup(self, mock_gcp_credentials_response): + """Create handler with complete setup for testing solve functionality""" + with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP + ) - def test_compute_provider_enum(self): - """Test ComputeProvider enum values""" - assert ComputeProvider.GCP == "GCP" - assert ComputeProvider.GCP.value == "GCP" + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket" + ) - @patch('linopy.oetc.requests.post') - def test_authentication_network_error(self, mock_post, mock_settings): - """Test authentication failure with network error""" - # Setup mock to raise network error - mock_post.side_effect = RequestException("Connection timeout") + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds - # Execute and verify exception - with pytest.raises(Exception) as exc_info: - OetcHandler(mock_settings) + return handler - assert "Authentication request failed" in str(exc_info.value) + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_file_upload(self, mock_tempfile, handler_with_complete_setup): + """Test solve_on_oetc method file upload flow""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file - @patch('linopy.oetc.requests.post') - def test_authentication_invalid_response_missing_token(self, mock_post, mock_settings): - """Test authentication failure with missing token in response""" - # Setup mock with invalid response - mock_response = Mock() - mock_response.json.return_value = { - "token_type": "Bearer", - "expires_in": 3600 - # Missing "token" field - } - mock_response.raise_for_status.return_value = None - mock_post.return_value = mock_response + # Mock file upload and job submission + with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', + return_value="uploaded_file.nc.gz") as mock_upload: + with patch.object(handler_with_complete_setup, '_submit_job_to_compute_service', + return_value=CreateComputeJobResult(uuid="test-job-uuid")) as mock_submit: + # Execute + result = handler_with_complete_setup.solve_on_oetc(mock_model) - # Execute and verify exception - with pytest.raises(Exception) as exc_info: - OetcHandler(mock_settings) + # Verify + assert result == mock_model # Currently returns the input model + mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_submit.assert_called_once_with("uploaded_file.nc.gz") - assert "Invalid response format: missing field 'token'" in str(exc_info.value) + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_upload_failure(self, mock_tempfile, handler_with_complete_setup): + """Test solve_on_oetc method with upload failure""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file - @patch('linopy.oetc.requests.post') - def test_authentication_invalid_response_missing_expires_in(self, mock_post, mock_settings): - """Test authentication failure with missing expires_in in response""" - # Setup mock with invalid response - mock_response = Mock() - mock_response.json.return_value = { - "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", - "token_type": "Bearer" - # Missing "expires_in" field - } - mock_response.raise_for_status.return_value = None - mock_post.return_value = mock_response + # Mock upload failure + with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', side_effect=Exception("Upload failed")): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_complete_setup.solve_on_oetc(mock_model) - # Execute and verify exception - with pytest.raises(Exception) as exc_info: - OetcHandler(mock_settings) + assert "Upload failed" in str(exc_info.value) - assert "Invalid response format: missing field 'expires_in'" in str(exc_info.value) - @patch('linopy.oetc.requests.post') - def test_authentication_timeout_error(self, mock_post, mock_settings): - """Test authentication failure with timeout""" - # Setup mock to raise timeout error - mock_post.side_effect = requests.Timeout("Request timeout") +class TestSolveOnOetcWithJobSubmission: - # Execute and verify exception - with pytest.raises(Exception) as exc_info: - OetcHandler(mock_settings) + @pytest.fixture + def handler_with_full_setup(self): + """Create handler with full setup for testing complete solve flow""" + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + name="Linopy Solve Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + solver="highs", + cpu_cores=2, + disk_space_gb=15 + ) - assert "Authentication request failed" in str(exc_info.value) + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket" + ) + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds -class TestAuthenticationResult: + return handler - @pytest.fixture - def auth_result(self): - """Create an AuthenticationResult for testing""" - return AuthenticationResult( - token="test_token", - token_type="Bearer", - expires_in=3600, # 1 hour - authenticated_at=datetime(2024, 1, 15, 12, 0, 0) - ) + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_with_job_submission(self, mock_tempfile, handler_with_full_setup): + """Test solve_on_oetc method including job submission""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file - def test_expires_at_calculation(self, auth_result): - """Test that expires_at correctly calculates expiration time""" - expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later - assert auth_result.expires_at == expected_expiry + uploaded_file_name = "model_file.nc.gz" + job_uuid = "job-uuid-456" - @patch('linopy.oetc.datetime') - def test_is_expired_false_when_not_expired(self, mock_datetime, auth_result): - """Test is_expired returns False when token is still valid""" - # Set current time to before expiration - mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 30, 0) + # Mock file upload and job submission + with patch.object(handler_with_full_setup, '_upload_file_to_gcp', + return_value=uploaded_file_name) as mock_upload: + with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', + return_value=CreateComputeJobResult(uuid=job_uuid)) as mock_submit: + # Execute + result = handler_with_full_setup.solve_on_oetc(mock_model) - assert auth_result.is_expired is False + # Verify + assert result == mock_model # Currently returns the input model + mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_submit.assert_called_once_with(uploaded_file_name) - @patch('linopy.oetc.datetime') - def test_is_expired_true_when_expired(self, mock_datetime, auth_result): - """Test is_expired returns True when token has expired""" - # Set current time to after expiration - mock_datetime.now.return_value = datetime(2024, 1, 15, 14, 0, 0) + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_job_submission_failure(self, mock_tempfile, handler_with_full_setup): + """Test solve_on_oetc method with job submission failure""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file - assert auth_result.is_expired is True + uploaded_file_name = "model_file.nc.gz" - @patch('linopy.oetc.datetime') - def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): - """Test is_expired returns True when token expires exactly now""" - # Set current time to exact expiration time - mock_datetime.now.return_value = datetime(2024, 1, 15, 13, 0, 0) + # Mock successful upload but failed job submission + with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name): + with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', + side_effect=Exception("Job submission failed")): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_full_setup.solve_on_oetc(mock_model) - assert auth_result.is_expired is True + assert "Job submission failed" in str(exc_info.value) # Additional integration-style test @@ -696,7 +933,9 @@ def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): ) settings = OetcSettings( credentials=credentials, + name="Integration Test Job", authentication_server_url="https://api.company.com/auth", + orchestrator_server_url="https://api.company.com/orchestrator", compute_provider=ComputeProvider.GCP ) @@ -769,4 +1008,3 @@ def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): }, timeout=30 ) - From a6585facc370c7f419ce436ad04429cb96e57910 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Tue, 10 Jun 2025 15:22:40 +0200 Subject: [PATCH 05/28] Implement waiting for OETC job completion --- linopy/oetc.py | 132 +++++++++++++++++++++++++++++++++++++++++++--- test/test_oetc.py | 81 +++++++++++++++++++--------- 2 files changed, 181 insertions(+), 32 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index fb63ba9e..7dffe095 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -1,3 +1,4 @@ +import time from dataclasses import dataclass, field from datetime import datetime, timedelta from enum import Enum @@ -65,8 +66,17 @@ def is_expired(self) -> bool: @dataclass -class CreateComputeJobResult: +class JobResult: uuid: str + status: str + name: str = None + owner: str = None + solver: str = None + duration_in_seconds: int = None + solving_duration_in_seconds: int = None + input_files: list = None + output_files: list = None + created_at: str = None class OetcHandler: @@ -194,7 +204,7 @@ def __get_gcp_credentials(self) -> GcpCredentials: except Exception as e: raise Exception(f"Error fetching GCP credentials: {e}") - def _submit_job_to_compute_service(self, input_file_name: str) -> CreateComputeJobResult: + def _submit_job_to_compute_service(self, input_file_name: str) -> str: """ Submit a job to the compute service. @@ -232,7 +242,7 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> CreateComputeJ response.raise_for_status() job_result = response.json() - return CreateComputeJobResult(uuid=job_result["uuid"]) + return job_result["uuid"] except RequestException as e: raise Exception(f"Failed to submit job to compute service: {e}") @@ -241,6 +251,114 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> CreateComputeJ except Exception as e: raise Exception(f"Error submitting job to compute service: {e}") + def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, + max_poll_interval: int = 300) -> JobResult: + """ + Wait for job completion and get job data by polling the orchestrator service. + + This method will poll indefinitely until the job finishes (FINISHED) or + encounters an error (SETUP_ERROR, RUNTIME_ERROR). + + Args: + job_uuid: UUID of the job to wait for + initial_poll_interval: Initial polling interval in seconds (default: 30) + max_poll_interval: Maximum polling interval in seconds (default: 300) + + Returns: + JobResult: The job result when complete + + Raises: + Exception: If job encounters errors or network requests consistently fail + """ + poll_interval = initial_poll_interval + consecutive_failures = 0 + max_network_retries = 10 + + print(f"Waiting for job {job_uuid} to complete...") + + while True: + try: + response = requests.get( + f"{self.settings.orchestrator_server_url}/jobs/{job_uuid}", + headers={ + "Authorization": f"{self.jwt.token_type} {self.jwt.token}", + "Content-Type": "application/json" + }, + timeout=30 + ) + + response.raise_for_status() + job_data_dict = response.json() + + job_result = JobResult( + uuid=job_data_dict["uuid"], + status=job_data_dict["status"], + name=job_data_dict.get("name"), + owner=job_data_dict.get("owner"), + solver=job_data_dict.get("solver"), + duration_in_seconds=job_data_dict.get("duration_in_seconds"), + solving_duration_in_seconds=job_data_dict.get("solving_duration_in_seconds"), + input_files=job_data_dict.get("input_files", []), + output_files=job_data_dict.get("output_files", []), + created_at=job_data_dict.get("created_at") + ) + + consecutive_failures = 0 + + if job_result.status == "FINISHED": + print(f"Job {job_uuid} completed successfully!") + if not job_result.output_files: + print("Warning: Job completed but no output files found") + return job_result + + elif job_result.status == "SETUP_ERROR": + error_msg = f"Job failed during setup phase (status: {job_result.status}). Please check the OETC logs for details." + print(f"Error: {error_msg}") + raise Exception(error_msg) + + elif job_result.status == "RUNTIME_ERROR": + error_msg = f"Job failed during execution (status: {job_result.status}). Please check the OETC logs for details." + print(f"Error: {error_msg}") + raise Exception(error_msg) + + elif job_result.status in ["PENDING", "STARTING", "RUNNING"]: + status_msg = f"Job {job_uuid} status: {job_result.status}" + if job_result.duration_in_seconds: + status_msg += f" (running for {job_result.duration_in_seconds}s)" + status_msg += f", checking again in {poll_interval} seconds..." + print(status_msg) + + time.sleep(poll_interval) + + # Exponential backoff for polling interval, capped at max_poll_interval + poll_interval = min(int(poll_interval * 1.5), max_poll_interval) + + else: + # Unknown status + error_msg = f"Unknown job status: {job_result.status}. Please check the OETC logs for details." + print(f"Error: {error_msg}") + raise Exception(error_msg) + + except RequestException as e: + consecutive_failures += 1 + + if consecutive_failures >= max_network_retries: + raise Exception(f"Failed to get job status after {max_network_retries} network retries: {e}") + + # Wait before retrying network request + retry_wait = min(consecutive_failures * 10, 60) + print(f"Network error getting job status (attempt {consecutive_failures}/{max_network_retries}), " + f"retrying in {retry_wait} seconds: {e}") + time.sleep(retry_wait) + + except KeyError as e: + raise Exception(f"Invalid job status response format: missing field {e}") + except Exception as e: + if "status:" in str(e) or "OETC logs" in str(e): + raise + else: + raise Exception(f"Error getting job status: {e}") + def solve_on_oetc(self, model): """ Solve a linopy model on the OET Cloud compute app. @@ -258,11 +376,11 @@ def solve_on_oetc(self, model): model.to_netcdf(fn.name) input_file_name = self._upload_file_to_gcp(fn.name) - job_result = self._submit_job_to_compute_service(input_file_name) + job_uuid = self._submit_job_to_compute_service(input_file_name) + job_result = self.wait_and_get_job_data(job_uuid) - # TODO: Wait for job completion - # TODO: Download result from GCP - # TODO: Return solved model + # TODO: Download result from GCP using job_data.output_files + # TODO: Load result into model and return solved model return model diff --git a/test/test_oetc.py b/test/test_oetc.py index 56df6ee7..28671bbb 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -9,7 +9,7 @@ from linopy.oetc import ( OetcHandler, OetcSettings, OetcCredentials, AuthenticationResult, - ComputeProvider, GcpCredentials, CreateComputeJobResult + ComputeProvider, GcpCredentials, JobResult ) @@ -718,8 +718,7 @@ def test_submit_job_success(self, mock_post, handler_with_auth_setup): ) # Verify result - assert isinstance(result, CreateComputeJobResult) - assert result.uuid == expected_job_uuid + assert result == expected_job_uuid @patch('linopy.oetc.requests.post') def test_submit_job_http_error(self, mock_post, handler_with_auth_setup): @@ -804,19 +803,23 @@ def test_solve_on_oetc_file_upload(self, mock_tempfile, handler_with_complete_se mock_temp_file.name = "/tmp/linopy-abc123.nc" mock_tempfile.return_value.__enter__.return_value = mock_temp_file - # Mock file upload and job submission + # Mock file upload, job submission, and job waiting with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', return_value="uploaded_file.nc.gz") as mock_upload: with patch.object(handler_with_complete_setup, '_submit_job_to_compute_service', - return_value=CreateComputeJobResult(uuid="test-job-uuid")) as mock_submit: - # Execute - result = handler_with_complete_setup.solve_on_oetc(mock_model) - - # Verify - assert result == mock_model # Currently returns the input model - mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_submit.assert_called_once_with("uploaded_file.nc.gz") + return_value="test-job-uuid") as mock_submit: + with patch.object(handler_with_complete_setup, 'wait_and_get_job_data', + return_value=JobResult(uuid="test-job-uuid", status="FINISHED", + output_files=[{"name": "result.nc.gz"}])) as mock_wait: + # Execute + result = handler_with_complete_setup.solve_on_oetc(mock_model) + + # Verify + assert result == mock_model # Currently returns the input model + mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_submit.assert_called_once_with("uploaded_file.nc.gz") + mock_wait.assert_called_once_with("test-job-uuid") @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_upload_failure(self, mock_tempfile, handler_with_complete_setup): @@ -869,7 +872,7 @@ def handler_with_full_setup(self): @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_with_job_submission(self, mock_tempfile, handler_with_full_setup): - """Test solve_on_oetc method including job submission""" + """Test solve_on_oetc method including job submission and waiting""" # Setup mock_model = Mock() mock_temp_file = Mock() @@ -879,19 +882,23 @@ def test_solve_on_oetc_with_job_submission(self, mock_tempfile, handler_with_ful uploaded_file_name = "model_file.nc.gz" job_uuid = "job-uuid-456" - # Mock file upload and job submission + # Mock file upload, job submission, and job waiting with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name) as mock_upload: with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', - return_value=CreateComputeJobResult(uuid=job_uuid)) as mock_submit: - # Execute - result = handler_with_full_setup.solve_on_oetc(mock_model) - - # Verify - assert result == mock_model # Currently returns the input model - mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_submit.assert_called_once_with(uploaded_file_name) + return_value=job_uuid) as mock_submit: + with patch.object(handler_with_full_setup, 'wait_and_get_job_data', + return_value=JobResult(uuid=job_uuid, status="FINISHED", + output_files=[{"name": "result.nc.gz"}])) as mock_wait: + # Execute + result = handler_with_full_setup.solve_on_oetc(mock_model) + + # Verify + assert result == mock_model # Currently returns the input model + mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_submit.assert_called_once_with(uploaded_file_name) + mock_wait.assert_called_once_with(job_uuid) @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_job_submission_failure(self, mock_tempfile, handler_with_full_setup): @@ -904,7 +911,7 @@ def test_solve_on_oetc_job_submission_failure(self, mock_tempfile, handler_with_ uploaded_file_name = "model_file.nc.gz" - # Mock successful upload but failed job submission + # Mock successful upload but failed job submission - no need to mock waiting since submission fails with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name): with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', side_effect=Exception("Job submission failed")): @@ -914,6 +921,30 @@ def test_solve_on_oetc_job_submission_failure(self, mock_tempfile, handler_with_ assert "Job submission failed" in str(exc_info.value) + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile, handler_with_full_setup): + """Test solve_on_oetc method with job waiting failure""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + uploaded_file_name = "model_file.nc.gz" + job_uuid = "job-uuid-failed" + + # Mock successful upload and job submission but failed job waiting + with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name): + with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', + return_value=job_uuid): + with patch.object(handler_with_full_setup, 'wait_and_get_job_data', + side_effect=Exception("Job failed: solver error")): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_full_setup.solve_on_oetc(mock_model) + + assert "Job failed: solver error" in str(exc_info.value) + # Additional integration-style test class TestOetcHandlerIntegration: From f966fd6b3cffd1f75d3c878d41b78c8c53548abc Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Wed, 11 Jun 2025 14:03:15 +0200 Subject: [PATCH 06/28] Implement solution download and decompression --- linopy/oetc.py | 114 ++++++++++++- test/test_oetc.py | 399 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 483 insertions(+), 30 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 7dffe095..1a7d65df 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -13,6 +13,7 @@ from requests import RequestException from google.cloud import storage from google.oauth2 import service_account +import linopy class ComputeProvider(str, Enum): @@ -359,6 +360,81 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, else: raise Exception(f"Error getting job status: {e}") + def _gzip_decompress(self, input_path: str) -> str: + """ + Decompress a gzip-compressed file. + + Args: + input_path: Path to the compressed file + + Returns: + str: Path to the decompressed file + + Raises: + Exception: If decompression fails + """ + try: + output_path = input_path[:-3] + chunk_size = 1024 * 1024 + + with gzip.open(input_path, "rb") as f_in: + with open(output_path, "wb") as f_out: + while True: + chunk = f_in.read(chunk_size) + if not chunk: + break + f_out.write(chunk) + + return output_path + except Exception as e: + raise Exception(f"Failed to decompress file: {e}") + + def _download_file_from_gcp(self, file_name: str) -> str: + """ + Download a file from GCP storage bucket. + + Args: + file_name: Name of the file to download from the solution bucket + + Returns: + str: Path to the downloaded and decompressed file + + Raises: + Exception: If download or decompression fails + """ + try: + # Create GCP credentials from service key + service_key_dict = json.loads(self.cloud_provider_credentials.gcp_service_key) + credentials = service_account.Credentials.from_service_account_info( + service_key_dict, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + # Download from GCP solution bucket + storage_client = storage.Client( + credentials=credentials, + project=self.cloud_provider_credentials.gcp_project_id + ) + bucket = storage_client.bucket(self.cloud_provider_credentials.solution_bucket) + blob = bucket.blob(file_name) + + # Create temporary file for download + with tempfile.NamedTemporaryFile(delete=False, suffix=".gz") as temp_file: + compressed_file_path = temp_file.name + + blob.download_to_filename(compressed_file_path) + + # Decompress the downloaded file + decompressed_file_path = self._gzip_decompress(compressed_file_path) + + # Clean up compressed file + os.remove(compressed_file_path) + + return decompressed_file_path + + except Exception as e: + raise Exception(f"Failed to download file from GCP: {e}") + def solve_on_oetc(self, model): """ Solve a linopy model on the OET Cloud compute app. @@ -371,18 +447,44 @@ def solve_on_oetc(self, model): ------- linopy.model.Model Solved model. + + Raises: + Exception: If solving fails at any stage """ - with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: - model.to_netcdf(fn.name) - input_file_name = self._upload_file_to_gcp(fn.name) + try: + # Save model to temporary file and upload + with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: + model.to_netcdf(fn.name) + input_file_name = self._upload_file_to_gcp(fn.name) + # Submit job and wait for completion job_uuid = self._submit_job_to_compute_service(input_file_name) job_result = self.wait_and_get_job_data(job_uuid) - # TODO: Download result from GCP using job_data.output_files - # TODO: Load result into model and return solved model + # Download and load the solution + if not job_result.output_files: + raise Exception("No output files found in completed job") - return model + output_file_name = job_result.output_files[0] + if isinstance(output_file_name, dict) and 'name' in output_file_name: + output_file_name = output_file_name['name'] + + solution_file_path = self._download_file_from_gcp(output_file_name) + + # Load the solved model + solved_model = linopy.read_netcdf(solution_file_path) + + # Clean up downloaded file + os.remove(solution_file_path) + + print(f"Model solved successfully. Status: {solved_model.status}") + if hasattr(solved_model, 'objective') and hasattr(solved_model.objective, 'value'): + print(f"Objective value: {solved_model.objective.value:.2e}") + + return solved_model + + except Exception as e: + raise Exception(f"Error solving model on OETC: {e}") def _gzip_compress(self, source_path: str) -> str: """ diff --git a/test/test_oetc.py b/test/test_oetc.py index 28671bbb..d1cd36ed 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -646,6 +646,305 @@ def test_upload_file_to_gcp_upload_error(self, mock_creds_from_info, mock_storag assert "Failed to upload file to GCP" in str(exc_info.value) +class TestFileDecompression: + + @pytest.fixture + def handler_with_mocked_auth(self): + """Create handler with mocked authentication for testing file operations""" + with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = Mock() + + return handler + + @patch('linopy.oetc.gzip.open') + @patch('builtins.open') + def test_gzip_decompress_success(self, mock_open_file, mock_gzip_open, handler_with_mocked_auth): + """Test successful file decompression""" + # Setup + input_path = "/tmp/test_file.nc.gz" + expected_output = "/tmp/test_file.nc" + + # Mock file operations + mock_file_in = Mock() + mock_file_out = Mock() + mock_gzip_open.return_value.__enter__.return_value = mock_file_in + mock_open_file.return_value.__enter__.return_value = mock_file_out + + # Mock file reading - simulate reading compressed data in chunks + mock_file_in.read.side_effect = [b"decompressed_data_chunk", + b""] # First read returns data, second returns empty + + # Execute + result = handler_with_mocked_auth._gzip_decompress(input_path) + + # Verify + assert result == expected_output + mock_gzip_open.assert_called_once_with(input_path, "rb") + mock_open_file.assert_called_once_with(expected_output, "wb") + mock_file_out.write.assert_called_once_with(b"decompressed_data_chunk") + + @patch('linopy.oetc.gzip.open') + def test_gzip_decompress_gzip_open_error(self, mock_gzip_open, handler_with_mocked_auth): + """Test file decompression with gzip open error""" + # Setup + input_path = "/tmp/test_file.nc.gz" + mock_gzip_open.side_effect = IOError("Failed to open gzip file") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._gzip_decompress(input_path) + + assert "Failed to decompress file" in str(exc_info.value) + + @patch('linopy.oetc.gzip.open') + @patch('builtins.open') + def test_gzip_decompress_write_error(self, mock_open_file, mock_gzip_open, handler_with_mocked_auth): + """Test file decompression with write error""" + # Setup + input_path = "/tmp/test_file.nc.gz" + + # Mock file operations + mock_file_in = Mock() + mock_gzip_open.return_value.__enter__.return_value = mock_file_in + mock_open_file.side_effect = IOError("Permission denied") + + # Mock file reading + mock_file_in.read.return_value = b"test_data" + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_mocked_auth._gzip_decompress(input_path) + + assert "Failed to decompress file" in str(exc_info.value) + + def test_gzip_decompress_output_path_generation(self, handler_with_mocked_auth): + """Test correct output path generation for decompression""" + # Test first path + with patch('linopy.oetc.gzip.open') as mock_gzip_open: + with patch('builtins.open') as mock_open_file: + mock_file_in = Mock() + mock_file_out = Mock() + mock_gzip_open.return_value.__enter__.return_value = mock_file_in + mock_open_file.return_value.__enter__.return_value = mock_file_out + mock_file_in.read.side_effect = [b"test", b""] + + result = handler_with_mocked_auth._gzip_decompress("/tmp/file.nc.gz") + assert result == "/tmp/file.nc" + + # Test second path with fresh mocks + with patch('linopy.oetc.gzip.open') as mock_gzip_open: + with patch('builtins.open') as mock_open_file: + mock_file_in = Mock() + mock_file_out = Mock() + mock_gzip_open.return_value.__enter__.return_value = mock_file_in + mock_open_file.return_value.__enter__.return_value = mock_file_out + mock_file_in.read.side_effect = [b"test", b""] + + result = handler_with_mocked_auth._gzip_decompress("/path/to/model.data.gz") + assert result == "/path/to/model.data" + + +class TestGcpDownload: + + @pytest.fixture + def handler_with_gcp_credentials(self, mock_gcp_credentials_response): + """Create handler with GCP credentials for testing download""" + with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): + credentials = OetcCredentials(email="test@example.com", password="test_password") + settings = OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP + ) + + # Create proper GCP credentials + gcp_creds = GcpCredentials( + gcp_project_id="test-project-123", + gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', + input_bucket="test-input-bucket", + solution_bucket="test-solution-bucket" + ) + + handler = OetcHandler.__new__(OetcHandler) + handler.settings = settings + handler.jwt = Mock() + handler.cloud_provider_credentials = gcp_creds + + return handler + + @patch('linopy.oetc.os.remove') + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_download_file_from_gcp_success(self, mock_creds_from_info, mock_storage_client, + mock_tempfile, mock_remove, handler_with_gcp_credentials): + """Test successful file download from GCP""" + # Setup + file_name = "solution_file.nc.gz" + compressed_path = "/tmp/tmpfile.gz" + decompressed_path = "/tmp/tmpfile" + + # Mock temporary file creation + mock_temp_file = Mock() + mock_temp_file.name = compressed_path + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock decompression + with patch.object(handler_with_gcp_credentials, '_gzip_decompress', return_value=decompressed_path): + # Mock GCP components + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_bucket.blob.return_value = mock_blob + + # Execute + result = handler_with_gcp_credentials._download_file_from_gcp(file_name) + + # Verify + assert result == decompressed_path + + # Verify GCP credentials creation + mock_creds_from_info.assert_called_once_with( + {"type": "service_account", "project_id": "test-project-123"}, + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + + # Verify GCP client creation + mock_storage_client.assert_called_once_with( + credentials=mock_credentials, + project="test-project-123" + ) + + # Verify bucket access (solution bucket, not input bucket) + mock_client.bucket.assert_called_once_with("test-solution-bucket") + + # Verify blob operations + mock_bucket.blob.assert_called_once_with(file_name) + mock_blob.download_to_filename.assert_called_once_with(compressed_path) + + # Verify cleanup + mock_remove.assert_called_once_with(compressed_path) + + @patch('linopy.oetc.json.loads') + def test_download_file_from_gcp_invalid_service_key(self, mock_json_loads, handler_with_gcp_credentials): + """Test download failure with invalid service key""" + # Setup + file_name = "solution_file.nc.gz" + mock_json_loads.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._download_file_from_gcp(file_name) + + assert "Failed to download file from GCP" in str(exc_info.value) + + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_download_file_from_gcp_download_error(self, mock_creds_from_info, mock_storage_client, + mock_tempfile, handler_with_gcp_credentials): + """Test download failure during blob download""" + # Setup + file_name = "solution_file.nc.gz" + compressed_path = "/tmp/tmpfile.gz" + + # Mock temporary file creation + mock_temp_file = Mock() + mock_temp_file.name = compressed_path + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock GCP setup + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_blob.download_to_filename.side_effect = Exception("Download failed") + mock_bucket.blob.return_value = mock_blob + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._download_file_from_gcp(file_name) + + assert "Failed to download file from GCP" in str(exc_info.value) + + @patch('linopy.oetc.os.remove') + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_download_file_from_gcp_decompression_error(self, mock_creds_from_info, mock_storage_client, + mock_tempfile, mock_remove, handler_with_gcp_credentials): + """Test download failure during decompression""" + # Setup + file_name = "solution_file.nc.gz" + compressed_path = "/tmp/tmpfile.gz" + + # Mock temporary file creation + mock_temp_file = Mock() + mock_temp_file.name = compressed_path + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + # Mock successful GCP download but failed decompression + mock_credentials = Mock() + mock_creds_from_info.return_value = mock_credentials + + mock_client = Mock() + mock_storage_client.return_value = mock_client + + mock_bucket = Mock() + mock_client.bucket.return_value = mock_bucket + + mock_blob = Mock() + mock_bucket.blob.return_value = mock_blob + + # Mock decompression failure + with patch.object(handler_with_gcp_credentials, '_gzip_decompress', + side_effect=Exception("Decompression failed")): + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._download_file_from_gcp(file_name) + + assert "Failed to download file from GCP" in str(exc_info.value) + + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_download_file_from_gcp_credentials_error(self, mock_creds_from_info, handler_with_gcp_credentials): + """Test download failure during credentials creation""" + # Setup + file_name = "solution_file.nc.gz" + mock_creds_from_info.side_effect = Exception("Invalid credentials") + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_gcp_credentials._download_file_from_gcp(file_name) + + assert "Failed to download file from GCP" in str(exc_info.value) + class TestJobSubmission: @@ -794,16 +1093,24 @@ def handler_with_complete_setup(self, mock_gcp_credentials_response): return handler + @patch('linopy.oetc.linopy.read_netcdf') + @patch('linopy.oetc.os.remove') @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_file_upload(self, mock_tempfile, handler_with_complete_setup): - """Test solve_on_oetc method file upload flow""" + def test_solve_on_oetc_file_upload(self, mock_tempfile, mock_remove, mock_read_netcdf, handler_with_complete_setup): + """Test solve_on_oetc method complete workflow""" # Setup mock_model = Mock() + mock_solved_model = Mock() + mock_solved_model.status = "optimal" + mock_solved_model.objective.value = 42.0 + mock_temp_file = Mock() mock_temp_file.name = "/tmp/linopy-abc123.nc" mock_tempfile.return_value.__enter__.return_value = mock_temp_file - # Mock file upload, job submission, and job waiting + mock_read_netcdf.return_value = mock_solved_model + + # Mock file upload, job submission, job waiting, and download with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', return_value="uploaded_file.nc.gz") as mock_upload: with patch.object(handler_with_complete_setup, '_submit_job_to_compute_service', @@ -811,15 +1118,20 @@ def test_solve_on_oetc_file_upload(self, mock_tempfile, handler_with_complete_se with patch.object(handler_with_complete_setup, 'wait_and_get_job_data', return_value=JobResult(uuid="test-job-uuid", status="FINISHED", output_files=[{"name": "result.nc.gz"}])) as mock_wait: - # Execute - result = handler_with_complete_setup.solve_on_oetc(mock_model) - - # Verify - assert result == mock_model # Currently returns the input model - mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_submit.assert_called_once_with("uploaded_file.nc.gz") - mock_wait.assert_called_once_with("test-job-uuid") + with patch.object(handler_with_complete_setup, '_download_file_from_gcp', + return_value="/tmp/downloaded_result.nc") as mock_download: + # Execute + result = handler_with_complete_setup.solve_on_oetc(mock_model) + + # Verify + assert result == mock_solved_model # Now returns the solved model + mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_submit.assert_called_once_with("uploaded_file.nc.gz") + mock_wait.assert_called_once_with("test-job-uuid") + mock_download.assert_called_once_with("result.nc.gz") + mock_read_netcdf.assert_called_once_with("/tmp/downloaded_result.nc") + mock_remove.assert_called_once_with("/tmp/downloaded_result.nc") @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_upload_failure(self, mock_tempfile, handler_with_complete_setup): @@ -870,19 +1182,28 @@ def handler_with_full_setup(self): return handler + @patch('linopy.oetc.linopy.read_netcdf') + @patch('linopy.oetc.os.remove') @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_with_job_submission(self, mock_tempfile, handler_with_full_setup): - """Test solve_on_oetc method including job submission and waiting""" + def test_solve_on_oetc_with_job_submission(self, mock_tempfile, mock_remove, mock_read_netcdf, + handler_with_full_setup): + """Test solve_on_oetc method including job submission, waiting, and download""" # Setup mock_model = Mock() + mock_solved_model = Mock() + mock_solved_model.status = "optimal" + mock_solved_model.objective.value = 100.5 + mock_temp_file = Mock() mock_temp_file.name = "/tmp/linopy-abc123.nc" mock_tempfile.return_value.__enter__.return_value = mock_temp_file + mock_read_netcdf.return_value = mock_solved_model + uploaded_file_name = "model_file.nc.gz" job_uuid = "job-uuid-456" - # Mock file upload, job submission, and job waiting + # Mock complete workflow with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name) as mock_upload: with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', @@ -890,15 +1211,20 @@ def test_solve_on_oetc_with_job_submission(self, mock_tempfile, handler_with_ful with patch.object(handler_with_full_setup, 'wait_and_get_job_data', return_value=JobResult(uuid=job_uuid, status="FINISHED", output_files=[{"name": "result.nc.gz"}])) as mock_wait: - # Execute - result = handler_with_full_setup.solve_on_oetc(mock_model) - - # Verify - assert result == mock_model # Currently returns the input model - mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") - mock_submit.assert_called_once_with(uploaded_file_name) - mock_wait.assert_called_once_with(job_uuid) + with patch.object(handler_with_full_setup, '_download_file_from_gcp', + return_value="/tmp/solution_file.nc") as mock_download: + # Execute + result = handler_with_full_setup.solve_on_oetc(mock_model) + + # Verify + assert result == mock_solved_model # Now returns the solved model + mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") + mock_submit.assert_called_once_with(uploaded_file_name) + mock_wait.assert_called_once_with(job_uuid) + mock_download.assert_called_once_with("result.nc.gz") + mock_read_netcdf.assert_called_once_with("/tmp/solution_file.nc") + mock_remove.assert_called_once_with("/tmp/solution_file.nc") @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_job_submission_failure(self, mock_tempfile, handler_with_full_setup): @@ -945,6 +1271,31 @@ def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile, handler_with_ful assert "Job failed: solver error" in str(exc_info.value) + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_no_output_files_error(self, mock_tempfile, handler_with_full_setup): + """Test solve_on_oetc method when job completes but has no output files""" + # Setup + mock_model = Mock() + mock_temp_file = Mock() + mock_temp_file.name = "/tmp/linopy-abc123.nc" + mock_tempfile.return_value.__enter__.return_value = mock_temp_file + + uploaded_file_name = "model_file.nc.gz" + job_uuid = "job-uuid-456" + + # Mock successful workflow until job completion with no output files + with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name): + with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', return_value=job_uuid): + with patch.object(handler_with_full_setup, 'wait_and_get_job_data', + return_value=JobResult(uuid=job_uuid, status="FINISHED", + output_files=[])): # No output files + + # Execute and verify exception + with pytest.raises(Exception) as exc_info: + handler_with_full_setup.solve_on_oetc(mock_model) + + assert "No output files found in completed job" in str(exc_info.value) + # Additional integration-style test class TestOetcHandlerIntegration: From e65c80531e83a388ca6f7bf03669386f4f362f83 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Wed, 11 Jun 2025 14:18:22 +0200 Subject: [PATCH 07/28] Add proper logging --- linopy/oetc.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 1a7d65df..0586f704 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -1,3 +1,4 @@ +import logging import time from dataclasses import dataclass, field from datetime import datetime, timedelta @@ -16,6 +17,8 @@ import linopy +logger = logging.getLogger(__name__) + class ComputeProvider(str, Enum): GCP = "GCP" @@ -97,6 +100,7 @@ def __sign_in(self) -> AuthenticationResult: Exception: If authentication fails or response is invalid """ try: + logger.info("OETC - Signing in...") payload = { "email": self.settings.credentials.email, "password": self.settings.credentials.password @@ -112,6 +116,8 @@ def __sign_in(self) -> AuthenticationResult: response.raise_for_status() jwt_result = response.json() + logger.info("OETC - Signed in") + return AuthenticationResult( token=jwt_result["token"], token_type=jwt_result["token_type"], @@ -173,6 +179,7 @@ def __get_gcp_credentials(self) -> GcpCredentials: Exception: If credentials fetching fails or response is invalid """ try: + logger.info("OETC - Fetching user GCP credentials...") payload = self._decode_jwt_payload(self.jwt.token) user_uuid = payload.get('sub') @@ -191,6 +198,8 @@ def __get_gcp_credentials(self) -> GcpCredentials: response.raise_for_status() credentials_data = response.json() + logger.info("OETC - Fetched user GCP credentials") + return GcpCredentials( gcp_project_id=credentials_data["gcp_project_id"], gcp_service_key=credentials_data["gcp_service_key"], @@ -219,6 +228,7 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> str: Exception: If job submission fails """ try: + logger.info("OETC - Submitting compute job...") payload = { "name": self.settings.name, "solver": self.settings.solver, @@ -243,6 +253,8 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> str: response.raise_for_status() job_result = response.json() + logger.info(f"OETC - Compute job {job_result['uuid']} started") + return job_result["uuid"] except RequestException as e: @@ -275,7 +287,7 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, consecutive_failures = 0 max_network_retries = 10 - print(f"Waiting for job {job_uuid} to complete...") + logger.info(f"OETC - Waiting for job {job_uuid} to complete...") while True: try: @@ -307,19 +319,19 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, consecutive_failures = 0 if job_result.status == "FINISHED": - print(f"Job {job_uuid} completed successfully!") + logger.info(f"OETC - Job {job_uuid} completed successfully!") if not job_result.output_files: - print("Warning: Job completed but no output files found") + logger.warning("OETC - Warning: Job completed but no output files found") return job_result elif job_result.status == "SETUP_ERROR": error_msg = f"Job failed during setup phase (status: {job_result.status}). Please check the OETC logs for details." - print(f"Error: {error_msg}") + logger.error(f"OETC Error: {error_msg}") raise Exception(error_msg) elif job_result.status == "RUNTIME_ERROR": error_msg = f"Job failed during execution (status: {job_result.status}). Please check the OETC logs for details." - print(f"Error: {error_msg}") + logger.error(f"OETC Error: {error_msg}") raise Exception(error_msg) elif job_result.status in ["PENDING", "STARTING", "RUNNING"]: @@ -327,7 +339,7 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, if job_result.duration_in_seconds: status_msg += f" (running for {job_result.duration_in_seconds}s)" status_msg += f", checking again in {poll_interval} seconds..." - print(status_msg) + logger.info(f"OETC - {status_msg}") time.sleep(poll_interval) @@ -337,7 +349,7 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, else: # Unknown status error_msg = f"Unknown job status: {job_result.status}. Please check the OETC logs for details." - print(f"Error: {error_msg}") + logger.error(f"OETC Error: {error_msg}") raise Exception(error_msg) except RequestException as e: @@ -348,7 +360,7 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, # Wait before retrying network request retry_wait = min(consecutive_failures * 10, 60) - print(f"Network error getting job status (attempt {consecutive_failures}/{max_network_retries}), " + logger.error(f"OETC - Network error getting job status (attempt {consecutive_failures}/{max_network_retries}), " f"retrying in {retry_wait} seconds: {e}") time.sleep(retry_wait) @@ -477,9 +489,9 @@ def solve_on_oetc(self, model): # Clean up downloaded file os.remove(solution_file_path) - print(f"Model solved successfully. Status: {solved_model.status}") + logger.info(f"OETC - Model solved successfully. Status: {solved_model.status}") if hasattr(solved_model, 'objective') and hasattr(solved_model.objective, 'value'): - print(f"Objective value: {solved_model.objective.value:.2e}") + logger.info(f"OETC - Objective value: {solved_model.objective.value:.2e}") return solved_model From bb341b6448b842de1fdd8bcb76c4a9eabac54802 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Thu, 12 Jun 2025 13:38:29 +0200 Subject: [PATCH 08/28] Add oetc call to linopy model --- linopy/model.py | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index 5767d653..f0a7b57e 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -59,6 +59,7 @@ from linopy.matrices import MatrixAccessor from linopy.objective import Objective from linopy.solvers import IO_APIS, available_solvers, quadratic_solvers +from linopy.oetc import OetcHandler, OetcSettings from linopy.types import ( ConstantLike, ConstraintLike, @@ -1049,7 +1050,7 @@ def solve( slice_size: int = 2_000_000, remote: Any = None, progress: bool | None = None, - oetc_settings: dict | None = {}, + oetc_settings: OetcSettings = None, **solver_options: Any, ) -> tuple[str, str]: """ @@ -1117,6 +1118,9 @@ def solve( Whether to show a progress bar of writing the lp file. The default is None, which means that the progress bar is shown if the model has more than 10000 variables and constraints. + oetc_settings : dict, optional + Settings for the solving on the OETC platform. If a value is provided + solving will be attempted on OETC, otherwise it will be done locally. **solver_options : kwargs Options passed to the solver. @@ -1135,20 +1139,25 @@ def solve( f"Keyword argument `io_api` has to be one of {IO_APIS} or None" ) - if remote: - solved = remote.solve_on_remote( - self, - solver_name=solver_name, - io_api=io_api, - problem_fn=problem_fn, - solution_fn=solution_fn, - log_fn=log_fn, - basis_fn=basis_fn, - warmstart_fn=warmstart_fn, - keep_files=keep_files, - sanitize_zeros=sanitize_zeros, - **solver_options, - ) + if remote or oetc_settings: + if remote and oetc_settings: + raise ValueError("Remote and OETC can't be active at the same time") + if remote: + solved = remote.solve_on_remote( + self, + solver_name=solver_name, + io_api=io_api, + problem_fn=problem_fn, + solution_fn=solution_fn, + log_fn=log_fn, + basis_fn=basis_fn, + warmstart_fn=warmstart_fn, + keep_files=keep_files, + sanitize_zeros=sanitize_zeros, + **solver_options, + ) + else: + solved = OetcHandler(oetc_settings).solve_on_oetc(self) self.objective.set_value(solved.objective.value) self.status = solved.status From df6835dbb549c2a00787ba737893393ee1ee4e43 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Fri, 1 Aug 2025 13:52:31 +0200 Subject: [PATCH 09/28] Add requests dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index eb23115f..59a823a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "tqdm", "deprecation", "google-cloud-storage", + "requests" ] [project.urls] From 0195b8824c509a6c735c2886aaa131b4bacf246c Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Mon, 4 Aug 2025 12:24:09 +0200 Subject: [PATCH 10/28] Fix wrong orchestrator endpoint values --- linopy/oetc.py | 4 ++-- test/test_oetc.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 0586f704..f7dcece8 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -241,7 +241,7 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> str: } response = requests.post( - f"{self.settings.orchestrator_server_url}/create", + f"{self.settings.orchestrator_server_url}/compute-job/create", json=payload, headers={ "Authorization": f"{self.jwt.token_type} {self.jwt.token}", @@ -292,7 +292,7 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, while True: try: response = requests.get( - f"{self.settings.orchestrator_server_url}/jobs/{job_uuid}", + f"{self.settings.orchestrator_server_url}/compute-job/{job_uuid}", headers={ "Authorization": f"{self.jwt.token_type} {self.jwt.token}", "Content-Type": "application/json" diff --git a/test/test_oetc.py b/test/test_oetc.py index d1cd36ed..71822596 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -1007,7 +1007,7 @@ def test_submit_job_success(self, mock_post, handler_with_auth_setup): } mock_post.assert_called_once_with( - "https://orchestrator.example.com/create", + "https://orchestrator.example.com/compute-job/create", json=expected_payload, headers={ "Authorization": f"Bearer {handler_with_auth_setup.jwt.token}", From 866f421f74c4bf32fc5f987e32c8ba966ac22ee5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Aug 2025 13:08:47 +0000 Subject: [PATCH 11/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/model.py | 2 +- linopy/oetc.py | 129 ++++++--- test/test_oetc.py | 722 ++++++++++++++++++++++++++++++---------------- 3 files changed, 548 insertions(+), 305 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index f0a7b57e..81b9bb09 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -58,8 +58,8 @@ ) from linopy.matrices import MatrixAccessor from linopy.objective import Objective -from linopy.solvers import IO_APIS, available_solvers, quadratic_solvers from linopy.oetc import OetcHandler, OetcSettings +from linopy.solvers import IO_APIS, available_solvers, quadratic_solvers from linopy.types import ( ConstantLike, ConstraintLike, diff --git a/linopy/oetc.py b/linopy/oetc.py index f7dcece8..7a55986a 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -1,24 +1,24 @@ +import base64 +import gzip +import json import logging +import os +import tempfile import time from dataclasses import dataclass, field from datetime import datetime, timedelta from enum import Enum -from typing import Union -import json -import base64 -import gzip -import os -import tempfile import requests -from requests import RequestException from google.cloud import storage from google.oauth2 import service_account -import linopy +from requests import RequestException +import linopy logger = logging.getLogger(__name__) + class ComputeProvider(str, Enum): GCP = "GCP" @@ -103,14 +103,14 @@ def __sign_in(self) -> AuthenticationResult: logger.info("OETC - Signing in...") payload = { "email": self.settings.credentials.email, - "password": self.settings.credentials.password + "password": self.settings.credentials.password, } response = requests.post( f"{self.settings.authentication_server_url}/sign-in", json=payload, headers={"Content-Type": "application/json"}, - timeout=30 + timeout=30, ) response.raise_for_status() @@ -122,7 +122,7 @@ def __sign_in(self) -> AuthenticationResult: token=jwt_result["token"], token_type=jwt_result["token_type"], expires_in=jwt_result["expires_in"], - authenticated_at=datetime.now() + authenticated_at=datetime.now(), ) except RequestException as e: @@ -146,14 +146,14 @@ def _decode_jwt_payload(self, token: str) -> dict: Exception: If token cannot be decoded """ try: - payload_part = token.split('.')[1] - payload_part += '=' * (4 - len(payload_part) % 4) + payload_part = token.split(".")[1] + payload_part += "=" * (4 - len(payload_part) % 4) payload_bytes = base64.urlsafe_b64decode(payload_part) - return json.loads(payload_bytes.decode('utf-8')) + return json.loads(payload_bytes.decode("utf-8")) except (IndexError, json.JSONDecodeError, Exception) as e: raise Exception(f"Failed to decode JWT payload: {e}") - def __get_cloud_provider_credentials(self) -> Union[GcpCredentials, None]: + def __get_cloud_provider_credentials(self) -> GcpCredentials | None: """ Fetch cloud provider credentials based on the configured provider. @@ -166,7 +166,9 @@ def __get_cloud_provider_credentials(self) -> Union[GcpCredentials, None]: if self.settings.compute_provider == ComputeProvider.GCP: return self.__get_gcp_credentials() else: - raise Exception(f"Unsupported compute provider: {self.settings.compute_provider}") + raise Exception( + f"Unsupported compute provider: {self.settings.compute_provider}" + ) def __get_gcp_credentials(self) -> GcpCredentials: """ @@ -181,7 +183,7 @@ def __get_gcp_credentials(self) -> GcpCredentials: try: logger.info("OETC - Fetching user GCP credentials...") payload = self._decode_jwt_payload(self.jwt.token) - user_uuid = payload.get('sub') + user_uuid = payload.get("sub") if not user_uuid: raise Exception("User UUID not found in JWT token") @@ -190,9 +192,9 @@ def __get_gcp_credentials(self) -> GcpCredentials: f"{self.settings.authentication_server_url}/users/{user_uuid}/gcp-credentials", headers={ "Authorization": f"{self.jwt.token_type} {self.jwt.token}", - "Content-Type": "application/json" + "Content-Type": "application/json", }, - timeout=30 + timeout=30, ) response.raise_for_status() @@ -204,7 +206,7 @@ def __get_gcp_credentials(self) -> GcpCredentials: gcp_project_id=credentials_data["gcp_project_id"], gcp_service_key=credentials_data["gcp_service_key"], input_bucket=credentials_data["input_bucket"], - solution_bucket=credentials_data["solution_bucket"] + solution_bucket=credentials_data["solution_bucket"], ) except RequestException as e: @@ -237,7 +239,7 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> str: "cpu_cores": self.settings.cpu_cores, "disk_space_gb": self.settings.disk_space_gb, "input_file_name": input_file_name, - "delete_worker_on_error": self.settings.delete_worker_on_error + "delete_worker_on_error": self.settings.delete_worker_on_error, } response = requests.post( @@ -245,9 +247,9 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> str: json=payload, headers={ "Authorization": f"{self.jwt.token_type} {self.jwt.token}", - "Content-Type": "application/json" + "Content-Type": "application/json", }, - timeout=30 + timeout=30, ) response.raise_for_status() @@ -260,12 +262,18 @@ def _submit_job_to_compute_service(self, input_file_name: str) -> str: except RequestException as e: raise Exception(f"Failed to submit job to compute service: {e}") except KeyError as e: - raise Exception(f"Invalid job submission response format: missing field {e}") + raise Exception( + f"Invalid job submission response format: missing field {e}" + ) except Exception as e: raise Exception(f"Error submitting job to compute service: {e}") - def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, - max_poll_interval: int = 300) -> JobResult: + def wait_and_get_job_data( + self, + job_uuid: str, + initial_poll_interval: int = 30, + max_poll_interval: int = 300, + ) -> JobResult: """ Wait for job completion and get job data by polling the orchestrator service. @@ -295,9 +303,9 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, f"{self.settings.orchestrator_server_url}/compute-job/{job_uuid}", headers={ "Authorization": f"{self.jwt.token_type} {self.jwt.token}", - "Content-Type": "application/json" + "Content-Type": "application/json", }, - timeout=30 + timeout=30, ) response.raise_for_status() @@ -310,10 +318,12 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, owner=job_data_dict.get("owner"), solver=job_data_dict.get("solver"), duration_in_seconds=job_data_dict.get("duration_in_seconds"), - solving_duration_in_seconds=job_data_dict.get("solving_duration_in_seconds"), + solving_duration_in_seconds=job_data_dict.get( + "solving_duration_in_seconds" + ), input_files=job_data_dict.get("input_files", []), output_files=job_data_dict.get("output_files", []), - created_at=job_data_dict.get("created_at") + created_at=job_data_dict.get("created_at"), ) consecutive_failures = 0 @@ -321,7 +331,9 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, if job_result.status == "FINISHED": logger.info(f"OETC - Job {job_uuid} completed successfully!") if not job_result.output_files: - logger.warning("OETC - Warning: Job completed but no output files found") + logger.warning( + "OETC - Warning: Job completed but no output files found" + ) return job_result elif job_result.status == "SETUP_ERROR": @@ -337,7 +349,9 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, elif job_result.status in ["PENDING", "STARTING", "RUNNING"]: status_msg = f"Job {job_uuid} status: {job_result.status}" if job_result.duration_in_seconds: - status_msg += f" (running for {job_result.duration_in_seconds}s)" + status_msg += ( + f" (running for {job_result.duration_in_seconds}s)" + ) status_msg += f", checking again in {poll_interval} seconds..." logger.info(f"OETC - {status_msg}") @@ -356,16 +370,22 @@ def wait_and_get_job_data(self, job_uuid: str, initial_poll_interval: int = 30, consecutive_failures += 1 if consecutive_failures >= max_network_retries: - raise Exception(f"Failed to get job status after {max_network_retries} network retries: {e}") + raise Exception( + f"Failed to get job status after {max_network_retries} network retries: {e}" + ) # Wait before retrying network request retry_wait = min(consecutive_failures * 10, 60) - logger.error(f"OETC - Network error getting job status (attempt {consecutive_failures}/{max_network_retries}), " - f"retrying in {retry_wait} seconds: {e}") + logger.error( + f"OETC - Network error getting job status (attempt {consecutive_failures}/{max_network_retries}), " + f"retrying in {retry_wait} seconds: {e}" + ) time.sleep(retry_wait) except KeyError as e: - raise Exception(f"Invalid job status response format: missing field {e}") + raise Exception( + f"Invalid job status response format: missing field {e}" + ) except Exception as e: if "status:" in str(e) or "OETC logs" in str(e): raise @@ -416,7 +436,9 @@ def _download_file_from_gcp(self, file_name: str) -> str: """ try: # Create GCP credentials from service key - service_key_dict = json.loads(self.cloud_provider_credentials.gcp_service_key) + service_key_dict = json.loads( + self.cloud_provider_credentials.gcp_service_key + ) credentials = service_account.Credentials.from_service_account_info( service_key_dict, scopes=["https://www.googleapis.com/auth/cloud-platform"], @@ -425,9 +447,11 @@ def _download_file_from_gcp(self, file_name: str) -> str: # Download from GCP solution bucket storage_client = storage.Client( credentials=credentials, - project=self.cloud_provider_credentials.gcp_project_id + project=self.cloud_provider_credentials.gcp_project_id, + ) + bucket = storage_client.bucket( + self.cloud_provider_credentials.solution_bucket ) - bucket = storage_client.bucket(self.cloud_provider_credentials.solution_bucket) blob = bucket.blob(file_name) # Create temporary file for download @@ -460,7 +484,8 @@ def solve_on_oetc(self, model): linopy.model.Model Solved model. - Raises: + Raises + ------ Exception: If solving fails at any stage """ try: @@ -478,8 +503,8 @@ def solve_on_oetc(self, model): raise Exception("No output files found in completed job") output_file_name = job_result.output_files[0] - if isinstance(output_file_name, dict) and 'name' in output_file_name: - output_file_name = output_file_name['name'] + if isinstance(output_file_name, dict) and "name" in output_file_name: + output_file_name = output_file_name["name"] solution_file_path = self._download_file_from_gcp(output_file_name) @@ -489,9 +514,15 @@ def solve_on_oetc(self, model): # Clean up downloaded file os.remove(solution_file_path) - logger.info(f"OETC - Model solved successfully. Status: {solved_model.status}") - if hasattr(solved_model, 'objective') and hasattr(solved_model.objective, 'value'): - logger.info(f"OETC - Objective value: {solved_model.objective.value:.2e}") + logger.info( + f"OETC - Model solved successfully. Status: {solved_model.status}" + ) + if hasattr(solved_model, "objective") and hasattr( + solved_model.objective, "value" + ): + logger.info( + f"OETC - Objective value: {solved_model.objective.value:.2e}" + ) return solved_model @@ -545,7 +576,9 @@ def _upload_file_to_gcp(self, file_path: str) -> str: compressed_file_name = os.path.basename(compressed_file_path) # Create GCP credentials from service key - service_key_dict = json.loads(self.cloud_provider_credentials.gcp_service_key) + service_key_dict = json.loads( + self.cloud_provider_credentials.gcp_service_key + ) credentials = service_account.Credentials.from_service_account_info( service_key_dict, scopes=["https://www.googleapis.com/auth/cloud-platform"], @@ -554,7 +587,7 @@ def _upload_file_to_gcp(self, file_path: str) -> str: # Upload to GCP bucket storage_client = storage.Client( credentials=credentials, - project=self.cloud_provider_credentials.gcp_project_id + project=self.cloud_provider_credentials.gcp_project_id, ) bucket = storage_client.bucket(self.cloud_provider_credentials.input_bucket) blob = bucket.blob(compressed_file_name) @@ -567,4 +600,4 @@ def _upload_file_to_gcp(self, file_path: str) -> str: return compressed_file_name except Exception as e: - raise Exception(f"Failed to upload file to GCP: {e}") \ No newline at end of file + raise Exception(f"Failed to upload file to GCP: {e}") diff --git a/test/test_oetc.py b/test/test_oetc.py index 71822596..9fabf9fa 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -1,15 +1,20 @@ -import pytest -from datetime import datetime -from unittest.mock import patch, Mock import base64 import json +from datetime import datetime +from unittest.mock import Mock, patch +import pytest import requests from requests import RequestException from linopy.oetc import ( - OetcHandler, OetcSettings, OetcCredentials, AuthenticationResult, - ComputeProvider, GcpCredentials, JobResult + AuthenticationResult, + ComputeProvider, + GcpCredentials, + JobResult, + OetcCredentials, + OetcHandler, + OetcSettings, ) @@ -23,12 +28,18 @@ def sample_jwt_token(): "jti": "jwt-id-456", "email": "test@example.com", "firstname": "Test", - "lastname": "User" + "lastname": "User", } # Create a simple JWT-like token (header.payload.signature) - header = base64.urlsafe_b64encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode()).decode().rstrip('=') - payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + header = ( + base64.urlsafe_b64encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode()) + .decode() + .rstrip("=") + ) + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) signature = "fake_signature" return f"{header}.{payload_encoded}.{signature}" @@ -41,42 +52,46 @@ def mock_gcp_credentials_response(): "gcp_project_id": "test-project-123", "gcp_service_key": "test-service-key-content", "input_bucket": "test-input-bucket", - "solution_bucket": "test-solution-bucket" + "solution_bucket": "test-solution-bucket", } @pytest.fixture def mock_settings(): """Create mock settings for testing""" - credentials = OetcCredentials( - email="test@example.com", - password="test_password" - ) + credentials = OetcCredentials(email="test@example.com", password="test_password") return OetcSettings( credentials=credentials, name="Test Job", authentication_server_url="https://auth.example.com", orchestrator_server_url="https://orchestrator.example.com", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) class TestOetcHandler: - @pytest.fixture def mock_jwt_response(self): """Create a mock JWT response""" return { "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", "token_type": "Bearer", - "expires_in": 3600 + "expires_in": 3600, } - @patch('linopy.oetc.requests.post') - @patch('linopy.oetc.requests.get') - @patch('linopy.oetc.datetime') - def test_successful_authentication(self, mock_datetime, mock_get, mock_post, mock_settings, mock_jwt_response, - mock_gcp_credentials_response, sample_jwt_token): + @patch("linopy.oetc.requests.post") + @patch("linopy.oetc.requests.get") + @patch("linopy.oetc.datetime") + def test_successful_authentication( + self, + mock_datetime, + mock_get, + mock_post, + mock_settings, + mock_jwt_response, + mock_gcp_credentials_response, + sample_jwt_token, + ): """Test successful authentication flow""" # Setup mocks fixed_time = datetime(2024, 1, 15, 12, 0, 0) @@ -101,12 +116,9 @@ def test_successful_authentication(self, mock_datetime, mock_get, mock_post, moc # Verify authentication request mock_post.assert_called_once_with( "https://auth.example.com/sign-in", - json={ - "email": "test@example.com", - "password": "test_password" - }, + json={"email": "test@example.com", "password": "test_password"}, headers={"Content-Type": "application/json"}, - timeout=30 + timeout=30, ) # Verify GCP credentials request @@ -114,9 +126,9 @@ def test_successful_authentication(self, mock_datetime, mock_get, mock_post, moc "https://auth.example.com/users/user-uuid-123/gcp-credentials", headers={ "Authorization": f"Bearer {sample_jwt_token}", - "Content-Type": "application/json" + "Content-Type": "application/json", }, - timeout=30 + timeout=30, ) # Verify AuthenticationResult @@ -129,16 +141,23 @@ def test_successful_authentication(self, mock_datetime, mock_get, mock_post, moc # Verify GcpCredentials assert isinstance(handler.cloud_provider_credentials, GcpCredentials) assert handler.cloud_provider_credentials.gcp_project_id == "test-project-123" - assert handler.cloud_provider_credentials.gcp_service_key == "test-service-key-content" + assert ( + handler.cloud_provider_credentials.gcp_service_key + == "test-service-key-content" + ) assert handler.cloud_provider_credentials.input_bucket == "test-input-bucket" - assert handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" + assert ( + handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" + ) - @patch('linopy.oetc.requests.post') + @patch("linopy.oetc.requests.post") def test_authentication_http_error(self, mock_post, mock_settings): """Test authentication failure with HTTP error""" # Setup mock to raise HTTP error mock_response = Mock() - mock_response.raise_for_status.side_effect = requests.HTTPError("401 Unauthorized") + mock_response.raise_for_status.side_effect = requests.HTTPError( + "401 Unauthorized" + ) mock_post.return_value = mock_response # Execute and verify exception @@ -149,18 +168,19 @@ def test_authentication_http_error(self, mock_post, mock_settings): class TestJwtDecoding: - @pytest.fixture def handler_with_mocked_auth(self): """Create handler with mocked authentication for testing JWT decoding""" - with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): - credentials = OetcCredentials(email="test@example.com", password="test_password") + with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Test Job", authentication_server_url="https://auth.example.com", orchestrator_server_url="https://orchestrator.example.com", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) # Mock the authentication and credentials fetching @@ -168,7 +188,7 @@ def handler_with_mocked_auth(self): token="fake.token.here", token_type="Bearer", expires_in=3600, - authenticated_at=datetime.now() + authenticated_at=datetime.now(), ) handler = OetcHandler.__new__(OetcHandler) @@ -178,7 +198,9 @@ def handler_with_mocked_auth(self): return handler - def test_decode_jwt_payload_success(self, handler_with_mocked_auth, sample_jwt_token): + def test_decode_jwt_payload_success( + self, handler_with_mocked_auth, sample_jwt_token + ): """Test successful JWT payload decoding""" result = handler_with_mocked_auth._decode_jwt_payload(sample_jwt_token) @@ -204,17 +226,18 @@ def test_decode_jwt_payload_malformed_token(self, handler_with_mocked_auth): class TestCloudProviderCredentials: - @pytest.fixture def handler_with_mocked_auth(self, sample_jwt_token): """Create handler with mocked authentication for testing credentials fetching""" - credentials = OetcCredentials(email="test@example.com", password="test_password") + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Test Job", authentication_server_url="https://auth.example.com", orchestrator_server_url="https://orchestrator.example.com", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) # Mock the authentication result @@ -222,7 +245,7 @@ def handler_with_mocked_auth(self, sample_jwt_token): token=sample_jwt_token, token_type="Bearer", expires_in=3600, - authenticated_at=datetime.now() + authenticated_at=datetime.now(), ) handler = OetcHandler.__new__(OetcHandler) @@ -231,8 +254,10 @@ def handler_with_mocked_auth(self, sample_jwt_token): return handler - @patch('linopy.oetc.requests.get') - def test_get_gcp_credentials_success(self, mock_get, handler_with_mocked_auth, mock_gcp_credentials_response): + @patch("linopy.oetc.requests.get") + def test_get_gcp_credentials_success( + self, mock_get, handler_with_mocked_auth, mock_gcp_credentials_response + ): """Test successful GCP credentials fetching""" # Setup mock mock_response = Mock() @@ -248,9 +273,9 @@ def test_get_gcp_credentials_success(self, mock_get, handler_with_mocked_auth, m "https://auth.example.com/users/user-uuid-123/gcp-credentials", headers={ "Authorization": f"Bearer {handler_with_mocked_auth.jwt.token}", - "Content-Type": "application/json" + "Content-Type": "application/json", }, - timeout=30 + timeout=30, ) # Verify result @@ -260,7 +285,7 @@ def test_get_gcp_credentials_success(self, mock_get, handler_with_mocked_auth, m assert result.input_bucket == "test-input-bucket" assert result.solution_bucket == "test-solution-bucket" - @patch('linopy.oetc.requests.get') + @patch("linopy.oetc.requests.get") def test_get_gcp_credentials_http_error(self, mock_get, handler_with_mocked_auth): """Test GCP credentials fetching with HTTP error""" # Setup mock to raise HTTP error @@ -274,15 +299,17 @@ def test_get_gcp_credentials_http_error(self, mock_get, handler_with_mocked_auth assert "Failed to fetch GCP credentials" in str(exc_info.value) - @patch('linopy.oetc.requests.get') - def test_get_gcp_credentials_missing_field(self, mock_get, handler_with_mocked_auth): + @patch("linopy.oetc.requests.get") + def test_get_gcp_credentials_missing_field( + self, mock_get, handler_with_mocked_auth + ): """Test GCP credentials fetching with missing response field""" # Setup mock with invalid response mock_response = Mock() mock_response.json.return_value = { "gcp_project_id": "test-project-123", "gcp_service_key": "test-service-key-content", - "input_bucket": "test-input-bucket" + "input_bucket": "test-input-bucket", # Missing "solution_bucket" field } mock_response.raise_for_status.return_value = None @@ -292,12 +319,19 @@ def test_get_gcp_credentials_missing_field(self, mock_get, handler_with_mocked_a with pytest.raises(Exception) as exc_info: handler_with_mocked_auth._OetcHandler__get_gcp_credentials() - assert "Invalid credentials response format: missing field 'solution_bucket'" in str(exc_info.value) + assert ( + "Invalid credentials response format: missing field 'solution_bucket'" + in str(exc_info.value) + ) - def test_get_cloud_provider_credentials_unsupported_provider(self, handler_with_mocked_auth): + def test_get_cloud_provider_credentials_unsupported_provider( + self, handler_with_mocked_auth + ): """Test cloud provider credentials with unsupported provider""" # Change to unsupported provider - handler_with_mocked_auth.settings.compute_provider = "AWS" # Not in enum, but for testing + handler_with_mocked_auth.settings.compute_provider = ( + "AWS" # Not in enum, but for testing + ) # Execute and verify exception with pytest.raises(Exception) as exc_info: @@ -309,7 +343,9 @@ def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_aut """Test GCP credentials fetching when JWT token has no user UUID""" # Create token without 'sub' field payload = {"iss": "OETC", "email": "test@example.com"} - payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) token_without_sub = f"header.{payload_encoded}.signature" handler_with_mocked_auth.jwt.token = token_without_sub @@ -322,14 +358,13 @@ def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_aut class TestGcpCredentials: - def test_gcp_credentials_creation(self): """Test GcpCredentials dataclass creation""" credentials = GcpCredentials( gcp_project_id="test-project-123", gcp_service_key="test-service-key-content", input_bucket="test-input-bucket", - solution_bucket="test-solution-bucket" + solution_bucket="test-solution-bucket", ) assert credentials.gcp_project_id == "test-project-123" @@ -339,13 +374,12 @@ def test_gcp_credentials_creation(self): class TestComputeProvider: - def test_compute_provider_enum(self): """Test ComputeProvider enum values""" assert ComputeProvider.GCP == "GCP" assert ComputeProvider.GCP.value == "GCP" - @patch('linopy.oetc.requests.post') + @patch("linopy.oetc.requests.post") def test_authentication_network_error(self, mock_post, mock_settings): """Test authentication failure with network error""" # Setup mock to raise network error @@ -357,14 +391,16 @@ def test_authentication_network_error(self, mock_post, mock_settings): assert "Authentication request failed" in str(exc_info.value) - @patch('linopy.oetc.requests.post') - def test_authentication_invalid_response_missing_token(self, mock_post, mock_settings): + @patch("linopy.oetc.requests.post") + def test_authentication_invalid_response_missing_token( + self, mock_post, mock_settings + ): """Test authentication failure with missing token in response""" # Setup mock with invalid response mock_response = Mock() mock_response.json.return_value = { "token_type": "Bearer", - "expires_in": 3600 + "expires_in": 3600, # Missing "token" field } mock_response.raise_for_status.return_value = None @@ -376,14 +412,16 @@ def test_authentication_invalid_response_missing_token(self, mock_post, mock_set assert "Invalid response format: missing field 'token'" in str(exc_info.value) - @patch('linopy.oetc.requests.post') - def test_authentication_invalid_response_missing_expires_in(self, mock_post, mock_settings): + @patch("linopy.oetc.requests.post") + def test_authentication_invalid_response_missing_expires_in( + self, mock_post, mock_settings + ): """Test authentication failure with missing expires_in in response""" # Setup mock with invalid response mock_response = Mock() mock_response.json.return_value = { "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", - "token_type": "Bearer" + "token_type": "Bearer", # Missing "expires_in" field } mock_response.raise_for_status.return_value = None @@ -393,9 +431,11 @@ def test_authentication_invalid_response_missing_expires_in(self, mock_post, moc with pytest.raises(Exception) as exc_info: OetcHandler(mock_settings) - assert "Invalid response format: missing field 'expires_in'" in str(exc_info.value) + assert "Invalid response format: missing field 'expires_in'" in str( + exc_info.value + ) - @patch('linopy.oetc.requests.post') + @patch("linopy.oetc.requests.post") def test_authentication_timeout_error(self, mock_post, mock_settings): """Test authentication failure with timeout""" # Setup mock to raise timeout error @@ -409,7 +449,6 @@ def test_authentication_timeout_error(self, mock_post, mock_settings): class TestAuthenticationResult: - @pytest.fixture def auth_result(self): """Create an AuthenticationResult for testing""" @@ -417,7 +456,7 @@ def auth_result(self): token="test_token", token_type="Bearer", expires_in=3600, # 1 hour - authenticated_at=datetime(2024, 1, 15, 12, 0, 0) + authenticated_at=datetime(2024, 1, 15, 12, 0, 0), ) def test_expires_at_calculation(self, auth_result): @@ -425,7 +464,7 @@ def test_expires_at_calculation(self, auth_result): expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later assert auth_result.expires_at == expected_expiry - @patch('linopy.oetc.datetime') + @patch("linopy.oetc.datetime") def test_is_expired_false_when_not_expired(self, mock_datetime, auth_result): """Test is_expired returns False when token is still valid""" # Set current time to before expiration @@ -433,7 +472,7 @@ def test_is_expired_false_when_not_expired(self, mock_datetime, auth_result): assert auth_result.is_expired is False - @patch('linopy.oetc.datetime') + @patch("linopy.oetc.datetime") def test_is_expired_true_when_expired(self, mock_datetime, auth_result): """Test is_expired returns True when token has expired""" # Set current time to after expiration @@ -441,7 +480,7 @@ def test_is_expired_true_when_expired(self, mock_datetime, auth_result): assert auth_result.is_expired is True - @patch('linopy.oetc.datetime') + @patch("linopy.oetc.datetime") def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): """Test is_expired returns True when token expires exactly now""" # Set current time to exact expiration time @@ -451,18 +490,19 @@ def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): class TestFileCompression: - @pytest.fixture def handler_with_mocked_auth(self): """Create handler with mocked authentication for testing file operations""" - with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): - credentials = OetcCredentials(email="test@example.com", password="test_password") + with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Test Job", authentication_server_url="https://auth.example.com", orchestrator_server_url="https://orchestrator.example.com", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) handler = OetcHandler.__new__(OetcHandler) @@ -472,10 +512,12 @@ def handler_with_mocked_auth(self): return handler - @patch('linopy.oetc.gzip.open') - @patch('linopy.oetc.os.path.exists') - @patch('builtins.open') - def test_gzip_compress_success(self, mock_open, mock_exists, mock_gzip_open, handler_with_mocked_auth): + @patch("linopy.oetc.gzip.open") + @patch("linopy.oetc.os.path.exists") + @patch("builtins.open") + def test_gzip_compress_success( + self, mock_open, mock_exists, mock_gzip_open, handler_with_mocked_auth + ): """Test successful file compression""" # Setup source_path = "/tmp/test_file.nc" @@ -489,7 +531,10 @@ def test_gzip_compress_success(self, mock_open, mock_exists, mock_gzip_open, han mock_gzip_open.return_value.__enter__.return_value = mock_file_out # Mock file reading - mock_file_in.read.side_effect = [b"test_data_chunk", b""] # First read returns data, second returns empty + mock_file_in.read.side_effect = [ + b"test_data_chunk", + b"", + ] # First read returns data, second returns empty # Execute result = handler_with_mocked_auth._gzip_compress(source_path) @@ -500,12 +545,12 @@ def test_gzip_compress_success(self, mock_open, mock_exists, mock_gzip_open, han mock_gzip_open.assert_called_once_with(expected_output, "wb", compresslevel=9) mock_file_out.write.assert_called_once_with(b"test_data_chunk") - @patch('builtins.open') + @patch("builtins.open") def test_gzip_compress_file_read_error(self, mock_open, handler_with_mocked_auth): """Test file compression with read error""" # Setup source_path = "/tmp/test_file.nc" - mock_open.side_effect = IOError("File not found") + mock_open.side_effect = OSError("File not found") # Execute and verify exception with pytest.raises(Exception) as exc_info: @@ -515,18 +560,19 @@ def test_gzip_compress_file_read_error(self, mock_open, handler_with_mocked_auth class TestGcpUpload: - @pytest.fixture def handler_with_gcp_credentials(self, mock_gcp_credentials_response): """Create handler with GCP credentials for testing upload""" - with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): - credentials = OetcCredentials(email="test@example.com", password="test_password") + with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Test Job", authentication_server_url="https://auth.example.com", orchestrator_server_url="https://orchestrator.example.com", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) # Create proper GCP credentials @@ -534,7 +580,7 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response): gcp_project_id="test-project-123", gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', input_bucket="test-input-bucket", - solution_bucket="test-solution-bucket" + solution_bucket="test-solution-bucket", ) handler = OetcHandler.__new__(OetcHandler) @@ -544,12 +590,18 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response): return handler - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.os.path.basename') - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_upload_file_to_gcp_success(self, mock_creds_from_info, mock_storage_client, mock_basename, mock_remove, - handler_with_gcp_credentials): + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.os.path.basename") + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_upload_file_to_gcp_success( + self, + mock_creds_from_info, + mock_storage_client, + mock_basename, + mock_remove, + handler_with_gcp_credentials, + ): """Test successful file upload to GCP""" # Setup file_path = "/tmp/test_file.nc" @@ -557,7 +609,9 @@ def test_upload_file_to_gcp_success(self, mock_creds_from_info, mock_storage_cli compressed_name = "test_file.nc.gz" # Mock compression - with patch.object(handler_with_gcp_credentials, '_gzip_compress', return_value=compressed_path): + with patch.object( + handler_with_gcp_credentials, "_gzip_compress", return_value=compressed_path + ): # Mock path operations mock_basename.return_value = compressed_name @@ -583,13 +637,12 @@ def test_upload_file_to_gcp_success(self, mock_creds_from_info, mock_storage_cli # Verify GCP credentials creation mock_creds_from_info.assert_called_once_with( {"type": "service_account", "project_id": "test-project-123"}, - scopes=["https://www.googleapis.com/auth/cloud-platform"] + scopes=["https://www.googleapis.com/auth/cloud-platform"], ) # Verify GCP client creation mock_storage_client.assert_called_once_with( - credentials=mock_credentials, - project="test-project-123" + credentials=mock_credentials, project="test-project-123" ) # Verify bucket access @@ -602,8 +655,10 @@ def test_upload_file_to_gcp_success(self, mock_creds_from_info, mock_storage_cli # Verify cleanup mock_remove.assert_called_once_with(compressed_path) - @patch('linopy.oetc.json.loads') - def test_upload_file_to_gcp_invalid_service_key(self, mock_json_loads, handler_with_gcp_credentials): + @patch("linopy.oetc.json.loads") + def test_upload_file_to_gcp_invalid_service_key( + self, mock_json_loads, handler_with_gcp_credentials + ): """Test upload failure with invalid service key""" # Setup file_path = "/tmp/test_file.nc" @@ -615,17 +670,20 @@ def test_upload_file_to_gcp_invalid_service_key(self, mock_json_loads, handler_w assert "Failed to upload file to GCP" in str(exc_info.value) - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_upload_file_to_gcp_upload_error(self, mock_creds_from_info, mock_storage_client, - handler_with_gcp_credentials): + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_upload_file_to_gcp_upload_error( + self, mock_creds_from_info, mock_storage_client, handler_with_gcp_credentials + ): """Test upload failure during blob upload""" # Setup file_path = "/tmp/test_file.nc" compressed_path = "/tmp/test_file.nc.gz" # Mock compression - with patch.object(handler_with_gcp_credentials, '_gzip_compress', return_value=compressed_path): + with patch.object( + handler_with_gcp_credentials, "_gzip_compress", return_value=compressed_path + ): # Mock GCP setup mock_credentials = Mock() mock_creds_from_info.return_value = mock_credentials @@ -646,19 +704,21 @@ def test_upload_file_to_gcp_upload_error(self, mock_creds_from_info, mock_storag assert "Failed to upload file to GCP" in str(exc_info.value) -class TestFileDecompression: +class TestFileDecompression: @pytest.fixture def handler_with_mocked_auth(self): """Create handler with mocked authentication for testing file operations""" - with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): - credentials = OetcCredentials(email="test@example.com", password="test_password") + with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Test Job", authentication_server_url="https://auth.example.com", orchestrator_server_url="https://orchestrator.example.com", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) handler = OetcHandler.__new__(OetcHandler) @@ -668,9 +728,11 @@ def handler_with_mocked_auth(self): return handler - @patch('linopy.oetc.gzip.open') - @patch('builtins.open') - def test_gzip_decompress_success(self, mock_open_file, mock_gzip_open, handler_with_mocked_auth): + @patch("linopy.oetc.gzip.open") + @patch("builtins.open") + def test_gzip_decompress_success( + self, mock_open_file, mock_gzip_open, handler_with_mocked_auth + ): """Test successful file decompression""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -683,8 +745,10 @@ def test_gzip_decompress_success(self, mock_open_file, mock_gzip_open, handler_w mock_open_file.return_value.__enter__.return_value = mock_file_out # Mock file reading - simulate reading compressed data in chunks - mock_file_in.read.side_effect = [b"decompressed_data_chunk", - b""] # First read returns data, second returns empty + mock_file_in.read.side_effect = [ + b"decompressed_data_chunk", + b"", + ] # First read returns data, second returns empty # Execute result = handler_with_mocked_auth._gzip_decompress(input_path) @@ -695,12 +759,14 @@ def test_gzip_decompress_success(self, mock_open_file, mock_gzip_open, handler_w mock_open_file.assert_called_once_with(expected_output, "wb") mock_file_out.write.assert_called_once_with(b"decompressed_data_chunk") - @patch('linopy.oetc.gzip.open') - def test_gzip_decompress_gzip_open_error(self, mock_gzip_open, handler_with_mocked_auth): + @patch("linopy.oetc.gzip.open") + def test_gzip_decompress_gzip_open_error( + self, mock_gzip_open, handler_with_mocked_auth + ): """Test file decompression with gzip open error""" # Setup input_path = "/tmp/test_file.nc.gz" - mock_gzip_open.side_effect = IOError("Failed to open gzip file") + mock_gzip_open.side_effect = OSError("Failed to open gzip file") # Execute and verify exception with pytest.raises(Exception) as exc_info: @@ -708,9 +774,11 @@ def test_gzip_decompress_gzip_open_error(self, mock_gzip_open, handler_with_mock assert "Failed to decompress file" in str(exc_info.value) - @patch('linopy.oetc.gzip.open') - @patch('builtins.open') - def test_gzip_decompress_write_error(self, mock_open_file, mock_gzip_open, handler_with_mocked_auth): + @patch("linopy.oetc.gzip.open") + @patch("builtins.open") + def test_gzip_decompress_write_error( + self, mock_open_file, mock_gzip_open, handler_with_mocked_auth + ): """Test file decompression with write error""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -718,7 +786,7 @@ def test_gzip_decompress_write_error(self, mock_open_file, mock_gzip_open, handl # Mock file operations mock_file_in = Mock() mock_gzip_open.return_value.__enter__.return_value = mock_file_in - mock_open_file.side_effect = IOError("Permission denied") + mock_open_file.side_effect = OSError("Permission denied") # Mock file reading mock_file_in.read.return_value = b"test_data" @@ -732,8 +800,8 @@ def test_gzip_decompress_write_error(self, mock_open_file, mock_gzip_open, handl def test_gzip_decompress_output_path_generation(self, handler_with_mocked_auth): """Test correct output path generation for decompression""" # Test first path - with patch('linopy.oetc.gzip.open') as mock_gzip_open: - with patch('builtins.open') as mock_open_file: + with patch("linopy.oetc.gzip.open") as mock_gzip_open: + with patch("builtins.open") as mock_open_file: mock_file_in = Mock() mock_file_out = Mock() mock_gzip_open.return_value.__enter__.return_value = mock_file_in @@ -744,31 +812,34 @@ def test_gzip_decompress_output_path_generation(self, handler_with_mocked_auth): assert result == "/tmp/file.nc" # Test second path with fresh mocks - with patch('linopy.oetc.gzip.open') as mock_gzip_open: - with patch('builtins.open') as mock_open_file: + with patch("linopy.oetc.gzip.open") as mock_gzip_open: + with patch("builtins.open") as mock_open_file: mock_file_in = Mock() mock_file_out = Mock() mock_gzip_open.return_value.__enter__.return_value = mock_file_in mock_open_file.return_value.__enter__.return_value = mock_file_out mock_file_in.read.side_effect = [b"test", b""] - result = handler_with_mocked_auth._gzip_decompress("/path/to/model.data.gz") + result = handler_with_mocked_auth._gzip_decompress( + "/path/to/model.data.gz" + ) assert result == "/path/to/model.data" class TestGcpDownload: - @pytest.fixture def handler_with_gcp_credentials(self, mock_gcp_credentials_response): """Create handler with GCP credentials for testing download""" - with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): - credentials = OetcCredentials(email="test@example.com", password="test_password") + with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Test Job", authentication_server_url="https://auth.example.com", orchestrator_server_url="https://orchestrator.example.com", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) # Create proper GCP credentials @@ -776,7 +847,7 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response): gcp_project_id="test-project-123", gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', input_bucket="test-input-bucket", - solution_bucket="test-solution-bucket" + solution_bucket="test-solution-bucket", ) handler = OetcHandler.__new__(OetcHandler) @@ -786,12 +857,18 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response): return handler - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_download_file_from_gcp_success(self, mock_creds_from_info, mock_storage_client, - mock_tempfile, mock_remove, handler_with_gcp_credentials): + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_success( + self, + mock_creds_from_info, + mock_storage_client, + mock_tempfile, + mock_remove, + handler_with_gcp_credentials, + ): """Test successful file download from GCP""" # Setup file_name = "solution_file.nc.gz" @@ -804,7 +881,11 @@ def test_download_file_from_gcp_success(self, mock_creds_from_info, mock_storage mock_tempfile.return_value.__enter__.return_value = mock_temp_file # Mock decompression - with patch.object(handler_with_gcp_credentials, '_gzip_decompress', return_value=decompressed_path): + with patch.object( + handler_with_gcp_credentials, + "_gzip_decompress", + return_value=decompressed_path, + ): # Mock GCP components mock_credentials = Mock() mock_creds_from_info.return_value = mock_credentials @@ -827,13 +908,12 @@ def test_download_file_from_gcp_success(self, mock_creds_from_info, mock_storage # Verify GCP credentials creation mock_creds_from_info.assert_called_once_with( {"type": "service_account", "project_id": "test-project-123"}, - scopes=["https://www.googleapis.com/auth/cloud-platform"] + scopes=["https://www.googleapis.com/auth/cloud-platform"], ) # Verify GCP client creation mock_storage_client.assert_called_once_with( - credentials=mock_credentials, - project="test-project-123" + credentials=mock_credentials, project="test-project-123" ) # Verify bucket access (solution bucket, not input bucket) @@ -846,8 +926,10 @@ def test_download_file_from_gcp_success(self, mock_creds_from_info, mock_storage # Verify cleanup mock_remove.assert_called_once_with(compressed_path) - @patch('linopy.oetc.json.loads') - def test_download_file_from_gcp_invalid_service_key(self, mock_json_loads, handler_with_gcp_credentials): + @patch("linopy.oetc.json.loads") + def test_download_file_from_gcp_invalid_service_key( + self, mock_json_loads, handler_with_gcp_credentials + ): """Test download failure with invalid service key""" # Setup file_name = "solution_file.nc.gz" @@ -859,11 +941,16 @@ def test_download_file_from_gcp_invalid_service_key(self, mock_json_loads, handl assert "Failed to download file from GCP" in str(exc_info.value) - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_download_file_from_gcp_download_error(self, mock_creds_from_info, mock_storage_client, - mock_tempfile, handler_with_gcp_credentials): + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_download_error( + self, + mock_creds_from_info, + mock_storage_client, + mock_tempfile, + handler_with_gcp_credentials, + ): """Test download failure during blob download""" # Setup file_name = "solution_file.nc.gz" @@ -894,12 +981,18 @@ def test_download_file_from_gcp_download_error(self, mock_creds_from_info, mock_ assert "Failed to download file from GCP" in str(exc_info.value) - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_download_file_from_gcp_decompression_error(self, mock_creds_from_info, mock_storage_client, - mock_tempfile, mock_remove, handler_with_gcp_credentials): + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_decompression_error( + self, + mock_creds_from_info, + mock_storage_client, + mock_tempfile, + mock_remove, + handler_with_gcp_credentials, + ): """Test download failure during decompression""" # Setup file_name = "solution_file.nc.gz" @@ -924,16 +1017,21 @@ def test_download_file_from_gcp_decompression_error(self, mock_creds_from_info, mock_bucket.blob.return_value = mock_blob # Mock decompression failure - with patch.object(handler_with_gcp_credentials, '_gzip_decompress', - side_effect=Exception("Decompression failed")): + with patch.object( + handler_with_gcp_credentials, + "_gzip_decompress", + side_effect=Exception("Decompression failed"), + ): # Execute and verify exception with pytest.raises(Exception) as exc_info: handler_with_gcp_credentials._download_file_from_gcp(file_name) assert "Failed to download file from GCP" in str(exc_info.value) - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_download_file_from_gcp_credentials_error(self, mock_creds_from_info, handler_with_gcp_credentials): + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_credentials_error( + self, mock_creds_from_info, handler_with_gcp_credentials + ): """Test download failure during credentials creation""" # Setup file_name = "solution_file.nc.gz" @@ -947,11 +1045,12 @@ def test_download_file_from_gcp_credentials_error(self, mock_creds_from_info, ha class TestJobSubmission: - @pytest.fixture def handler_with_auth_setup(self, sample_jwt_token): """Create handler with authentication setup for testing job submission""" - credentials = OetcCredentials(email="test@example.com", password="test_password") + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Test Optimization Job", @@ -960,7 +1059,7 @@ def handler_with_auth_setup(self, sample_jwt_token): compute_provider=ComputeProvider.GCP, solver="gurobi", cpu_cores=4, - disk_space_gb=20 + disk_space_gb=20, ) # Mock the authentication result @@ -968,7 +1067,7 @@ def handler_with_auth_setup(self, sample_jwt_token): token=sample_jwt_token, token_type="Bearer", expires_in=3600, - authenticated_at=datetime.now() + authenticated_at=datetime.now(), ) handler = OetcHandler.__new__(OetcHandler) @@ -978,7 +1077,7 @@ def handler_with_auth_setup(self, sample_jwt_token): return handler - @patch('linopy.oetc.requests.post') + @patch("linopy.oetc.requests.post") def test_submit_job_success(self, mock_post, handler_with_auth_setup): """Test successful job submission to compute service""" # Setup @@ -1003,7 +1102,7 @@ def test_submit_job_success(self, mock_post, handler_with_auth_setup): "cpu_cores": 4, "disk_space_gb": 20, "input_file_name": input_file_name, - "delete_worker_on_error": False + "delete_worker_on_error": False, } mock_post.assert_called_once_with( @@ -1011,21 +1110,23 @@ def test_submit_job_success(self, mock_post, handler_with_auth_setup): json=expected_payload, headers={ "Authorization": f"Bearer {handler_with_auth_setup.jwt.token}", - "Content-Type": "application/json" + "Content-Type": "application/json", }, - timeout=30 + timeout=30, ) # Verify result assert result == expected_job_uuid - @patch('linopy.oetc.requests.post') + @patch("linopy.oetc.requests.post") def test_submit_job_http_error(self, mock_post, handler_with_auth_setup): """Test job submission with HTTP error""" # Setup input_file_name = "test_model.nc.gz" mock_response = Mock() - mock_response.raise_for_status.side_effect = requests.HTTPError("400 Bad Request") + mock_response.raise_for_status.side_effect = requests.HTTPError( + "400 Bad Request" + ) mock_post.return_value = mock_response # Execute and verify exception @@ -1034,8 +1135,10 @@ def test_submit_job_http_error(self, mock_post, handler_with_auth_setup): assert "Failed to submit job to compute service" in str(exc_info.value) - @patch('linopy.oetc.requests.post') - def test_submit_job_missing_uuid_in_response(self, mock_post, handler_with_auth_setup): + @patch("linopy.oetc.requests.post") + def test_submit_job_missing_uuid_in_response( + self, mock_post, handler_with_auth_setup + ): """Test job submission with missing UUID in response""" # Setup input_file_name = "test_model.nc.gz" @@ -1048,9 +1151,11 @@ def test_submit_job_missing_uuid_in_response(self, mock_post, handler_with_auth_ with pytest.raises(Exception) as exc_info: handler_with_auth_setup._submit_job_to_compute_service(input_file_name) - assert "Invalid job submission response format: missing field 'uuid'" in str(exc_info.value) + assert "Invalid job submission response format: missing field 'uuid'" in str( + exc_info.value + ) - @patch('linopy.oetc.requests.post') + @patch("linopy.oetc.requests.post") def test_submit_job_network_error(self, mock_post, handler_with_auth_setup): """Test job submission with network error""" # Setup @@ -1065,25 +1170,26 @@ def test_submit_job_network_error(self, mock_post, handler_with_auth_setup): class TestSolveOnOetc: - @pytest.fixture def handler_with_complete_setup(self, mock_gcp_credentials_response): """Create handler with complete setup for testing solve functionality""" - with patch('linopy.oetc.requests.post'), patch('linopy.oetc.requests.get'): - credentials = OetcCredentials(email="test@example.com", password="test_password") + with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Test Job", authentication_server_url="https://auth.example.com", orchestrator_server_url="https://orchestrator.example.com", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) gcp_creds = GcpCredentials( gcp_project_id="test-project-123", gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', input_bucket="test-input-bucket", - solution_bucket="test-solution-bucket" + solution_bucket="test-solution-bucket", ) handler = OetcHandler.__new__(OetcHandler) @@ -1093,10 +1199,12 @@ def handler_with_complete_setup(self, mock_gcp_credentials_response): return handler - @patch('linopy.oetc.linopy.read_netcdf') - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_file_upload(self, mock_tempfile, mock_remove, mock_read_netcdf, handler_with_complete_setup): + @patch("linopy.oetc.linopy.read_netcdf") + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_file_upload( + self, mock_tempfile, mock_remove, mock_read_netcdf, handler_with_complete_setup + ): """Test solve_on_oetc method complete workflow""" # Setup mock_model = Mock() @@ -1111,30 +1219,53 @@ def test_solve_on_oetc_file_upload(self, mock_tempfile, mock_remove, mock_read_n mock_read_netcdf.return_value = mock_solved_model # Mock file upload, job submission, job waiting, and download - with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', - return_value="uploaded_file.nc.gz") as mock_upload: - with patch.object(handler_with_complete_setup, '_submit_job_to_compute_service', - return_value="test-job-uuid") as mock_submit: - with patch.object(handler_with_complete_setup, 'wait_and_get_job_data', - return_value=JobResult(uuid="test-job-uuid", status="FINISHED", - output_files=[{"name": "result.nc.gz"}])) as mock_wait: - with patch.object(handler_with_complete_setup, '_download_file_from_gcp', - return_value="/tmp/downloaded_result.nc") as mock_download: + with patch.object( + handler_with_complete_setup, + "_upload_file_to_gcp", + return_value="uploaded_file.nc.gz", + ) as mock_upload: + with patch.object( + handler_with_complete_setup, + "_submit_job_to_compute_service", + return_value="test-job-uuid", + ) as mock_submit: + with patch.object( + handler_with_complete_setup, + "wait_and_get_job_data", + return_value=JobResult( + uuid="test-job-uuid", + status="FINISHED", + output_files=[{"name": "result.nc.gz"}], + ), + ) as mock_wait: + with patch.object( + handler_with_complete_setup, + "_download_file_from_gcp", + return_value="/tmp/downloaded_result.nc", + ) as mock_download: # Execute result = handler_with_complete_setup.solve_on_oetc(mock_model) # Verify - assert result == mock_solved_model # Now returns the solved model - mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + assert ( + result == mock_solved_model + ) # Now returns the solved model + mock_model.to_netcdf.assert_called_once_with( + "/tmp/linopy-abc123.nc" + ) mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") mock_submit.assert_called_once_with("uploaded_file.nc.gz") mock_wait.assert_called_once_with("test-job-uuid") mock_download.assert_called_once_with("result.nc.gz") - mock_read_netcdf.assert_called_once_with("/tmp/downloaded_result.nc") + mock_read_netcdf.assert_called_once_with( + "/tmp/downloaded_result.nc" + ) mock_remove.assert_called_once_with("/tmp/downloaded_result.nc") - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_upload_failure(self, mock_tempfile, handler_with_complete_setup): + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_upload_failure( + self, mock_tempfile, handler_with_complete_setup + ): """Test solve_on_oetc method with upload failure""" # Setup mock_model = Mock() @@ -1143,7 +1274,11 @@ def test_solve_on_oetc_upload_failure(self, mock_tempfile, handler_with_complete mock_tempfile.return_value.__enter__.return_value = mock_temp_file # Mock upload failure - with patch.object(handler_with_complete_setup, '_upload_file_to_gcp', side_effect=Exception("Upload failed")): + with patch.object( + handler_with_complete_setup, + "_upload_file_to_gcp", + side_effect=Exception("Upload failed"), + ): # Execute and verify exception with pytest.raises(Exception) as exc_info: handler_with_complete_setup.solve_on_oetc(mock_model) @@ -1152,11 +1287,12 @@ def test_solve_on_oetc_upload_failure(self, mock_tempfile, handler_with_complete class TestSolveOnOetcWithJobSubmission: - @pytest.fixture def handler_with_full_setup(self): """Create handler with full setup for testing complete solve flow""" - credentials = OetcCredentials(email="test@example.com", password="test_password") + credentials = OetcCredentials( + email="test@example.com", password="test_password" + ) settings = OetcSettings( credentials=credentials, name="Linopy Solve Job", @@ -1165,14 +1301,14 @@ def handler_with_full_setup(self): compute_provider=ComputeProvider.GCP, solver="highs", cpu_cores=2, - disk_space_gb=15 + disk_space_gb=15, ) gcp_creds = GcpCredentials( gcp_project_id="test-project-123", gcp_service_key='{"type": "service_account", "project_id": "test-project-123"}', input_bucket="test-input-bucket", - solution_bucket="test-solution-bucket" + solution_bucket="test-solution-bucket", ) handler = OetcHandler.__new__(OetcHandler) @@ -1182,11 +1318,12 @@ def handler_with_full_setup(self): return handler - @patch('linopy.oetc.linopy.read_netcdf') - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_with_job_submission(self, mock_tempfile, mock_remove, mock_read_netcdf, - handler_with_full_setup): + @patch("linopy.oetc.linopy.read_netcdf") + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_with_job_submission( + self, mock_tempfile, mock_remove, mock_read_netcdf, handler_with_full_setup + ): """Test solve_on_oetc method including job submission, waiting, and download""" # Setup mock_model = Mock() @@ -1204,30 +1341,53 @@ def test_solve_on_oetc_with_job_submission(self, mock_tempfile, mock_remove, moc job_uuid = "job-uuid-456" # Mock complete workflow - with patch.object(handler_with_full_setup, '_upload_file_to_gcp', - return_value=uploaded_file_name) as mock_upload: - with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', - return_value=job_uuid) as mock_submit: - with patch.object(handler_with_full_setup, 'wait_and_get_job_data', - return_value=JobResult(uuid=job_uuid, status="FINISHED", - output_files=[{"name": "result.nc.gz"}])) as mock_wait: - with patch.object(handler_with_full_setup, '_download_file_from_gcp', - return_value="/tmp/solution_file.nc") as mock_download: + with patch.object( + handler_with_full_setup, + "_upload_file_to_gcp", + return_value=uploaded_file_name, + ) as mock_upload: + with patch.object( + handler_with_full_setup, + "_submit_job_to_compute_service", + return_value=job_uuid, + ) as mock_submit: + with patch.object( + handler_with_full_setup, + "wait_and_get_job_data", + return_value=JobResult( + uuid=job_uuid, + status="FINISHED", + output_files=[{"name": "result.nc.gz"}], + ), + ) as mock_wait: + with patch.object( + handler_with_full_setup, + "_download_file_from_gcp", + return_value="/tmp/solution_file.nc", + ) as mock_download: # Execute result = handler_with_full_setup.solve_on_oetc(mock_model) # Verify - assert result == mock_solved_model # Now returns the solved model - mock_model.to_netcdf.assert_called_once_with("/tmp/linopy-abc123.nc") + assert ( + result == mock_solved_model + ) # Now returns the solved model + mock_model.to_netcdf.assert_called_once_with( + "/tmp/linopy-abc123.nc" + ) mock_upload.assert_called_once_with("/tmp/linopy-abc123.nc") mock_submit.assert_called_once_with(uploaded_file_name) mock_wait.assert_called_once_with(job_uuid) mock_download.assert_called_once_with("result.nc.gz") - mock_read_netcdf.assert_called_once_with("/tmp/solution_file.nc") + mock_read_netcdf.assert_called_once_with( + "/tmp/solution_file.nc" + ) mock_remove.assert_called_once_with("/tmp/solution_file.nc") - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_job_submission_failure(self, mock_tempfile, handler_with_full_setup): + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_job_submission_failure( + self, mock_tempfile, handler_with_full_setup + ): """Test solve_on_oetc method with job submission failure""" # Setup mock_model = Mock() @@ -1238,17 +1398,26 @@ def test_solve_on_oetc_job_submission_failure(self, mock_tempfile, handler_with_ uploaded_file_name = "model_file.nc.gz" # Mock successful upload but failed job submission - no need to mock waiting since submission fails - with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name): - with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', - side_effect=Exception("Job submission failed")): + with patch.object( + handler_with_full_setup, + "_upload_file_to_gcp", + return_value=uploaded_file_name, + ): + with patch.object( + handler_with_full_setup, + "_submit_job_to_compute_service", + side_effect=Exception("Job submission failed"), + ): # Execute and verify exception with pytest.raises(Exception) as exc_info: handler_with_full_setup.solve_on_oetc(mock_model) assert "Job submission failed" in str(exc_info.value) - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile, handler_with_full_setup): + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_job_waiting_failure( + self, mock_tempfile, handler_with_full_setup + ): """Test solve_on_oetc method with job waiting failure""" # Setup mock_model = Mock() @@ -1260,19 +1429,31 @@ def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile, handler_with_ful job_uuid = "job-uuid-failed" # Mock successful upload and job submission but failed job waiting - with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name): - with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', - return_value=job_uuid): - with patch.object(handler_with_full_setup, 'wait_and_get_job_data', - side_effect=Exception("Job failed: solver error")): + with patch.object( + handler_with_full_setup, + "_upload_file_to_gcp", + return_value=uploaded_file_name, + ): + with patch.object( + handler_with_full_setup, + "_submit_job_to_compute_service", + return_value=job_uuid, + ): + with patch.object( + handler_with_full_setup, + "wait_and_get_job_data", + side_effect=Exception("Job failed: solver error"), + ): # Execute and verify exception with pytest.raises(Exception) as exc_info: handler_with_full_setup.solve_on_oetc(mock_model) assert "Job failed: solver error" in str(exc_info.value) - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_no_output_files_error(self, mock_tempfile, handler_with_full_setup): + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_no_output_files_error( + self, mock_tempfile, handler_with_full_setup + ): """Test solve_on_oetc method when job completes but has no output files""" # Setup mock_model = Mock() @@ -1284,25 +1465,37 @@ def test_solve_on_oetc_no_output_files_error(self, mock_tempfile, handler_with_f job_uuid = "job-uuid-456" # Mock successful workflow until job completion with no output files - with patch.object(handler_with_full_setup, '_upload_file_to_gcp', return_value=uploaded_file_name): - with patch.object(handler_with_full_setup, '_submit_job_to_compute_service', return_value=job_uuid): - with patch.object(handler_with_full_setup, 'wait_and_get_job_data', - return_value=JobResult(uuid=job_uuid, status="FINISHED", - output_files=[])): # No output files - + with patch.object( + handler_with_full_setup, + "_upload_file_to_gcp", + return_value=uploaded_file_name, + ): + with patch.object( + handler_with_full_setup, + "_submit_job_to_compute_service", + return_value=job_uuid, + ): + with patch.object( + handler_with_full_setup, + "wait_and_get_job_data", + return_value=JobResult( + uuid=job_uuid, status="FINISHED", output_files=[] + ), + ): # No output files # Execute and verify exception with pytest.raises(Exception) as exc_info: handler_with_full_setup.solve_on_oetc(mock_model) - assert "No output files found in completed job" in str(exc_info.value) + assert "No output files found in completed job" in str( + exc_info.value + ) # Additional integration-style test class TestOetcHandlerIntegration: - - @patch('linopy.oetc.requests.post') - @patch('linopy.oetc.requests.get') - @patch('linopy.oetc.datetime') + @patch("linopy.oetc.requests.post") + @patch("linopy.oetc.requests.get") + @patch("linopy.oetc.datetime") def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): """Test complete authentication and credentials flow with realistic data""" # Setup @@ -1310,15 +1503,14 @@ def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): mock_datetime.now.return_value = fixed_time credentials = OetcCredentials( - email="user@company.com", - password="secure_password_123" + email="user@company.com", password="secure_password_123" ) settings = OetcSettings( credentials=credentials, name="Integration Test Job", authentication_server_url="https://api.company.com/auth", orchestrator_server_url="https://api.company.com/orchestrator", - compute_provider=ComputeProvider.GCP + compute_provider=ComputeProvider.GCP, ) # Create realistic JWT token @@ -1329,10 +1521,18 @@ def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): "jti": "jwt-id-789", "email": "user@company.com", "firstname": "John", - "lastname": "Doe" + "lastname": "Doe", } - header = base64.urlsafe_b64encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode()).decode().rstrip('=') - payload_encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') + header = ( + base64.urlsafe_b64encode( + json.dumps({"alg": "HS256", "typ": "JWT"}).encode() + ) + .decode() + .rstrip("=") + ) + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) realistic_token = f"{header}.{payload_encoded}.realistic_signature" # Mock authentication response @@ -1340,7 +1540,7 @@ def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): mock_auth_response.json.return_value = { "token": realistic_token, "token_type": "Bearer", - "expires_in": 7200 # 2 hours + "expires_in": 7200, # 2 hours } mock_auth_response.raise_for_status.return_value = None mock_post.return_value = mock_auth_response @@ -1351,7 +1551,7 @@ def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): "gcp_project_id": "production-project-456", "gcp_service_key": "production-service-key-content", "input_bucket": "prod-input-bucket", - "solution_bucket": "prod-solution-bucket" + "solution_bucket": "prod-solution-bucket", } mock_gcp_response.raise_for_status.return_value = None mock_get.return_value = mock_gcp_response @@ -1364,29 +1564,39 @@ def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): assert handler.jwt.token_type == "Bearer" assert handler.jwt.expires_in == 7200 assert handler.jwt.authenticated_at == fixed_time - assert handler.jwt.expires_at == datetime(2024, 1, 15, 14, 0, 0) # 2 hours later + assert handler.jwt.expires_at == datetime( + 2024, 1, 15, 14, 0, 0 + ) # 2 hours later assert handler.jwt.is_expired is False # Verify GCP credentials assert isinstance(handler.cloud_provider_credentials, GcpCredentials) - assert handler.cloud_provider_credentials.gcp_project_id == "production-project-456" - assert handler.cloud_provider_credentials.gcp_service_key == "production-service-key-content" + assert ( + handler.cloud_provider_credentials.gcp_project_id + == "production-project-456" + ) + assert ( + handler.cloud_provider_credentials.gcp_service_key + == "production-service-key-content" + ) assert handler.cloud_provider_credentials.input_bucket == "prod-input-bucket" - assert handler.cloud_provider_credentials.solution_bucket == "prod-solution-bucket" + assert ( + handler.cloud_provider_credentials.solution_bucket == "prod-solution-bucket" + ) # Verify correct API calls were made mock_post.assert_called_once_with( "https://api.company.com/auth/sign-in", json={"email": "user@company.com", "password": "secure_password_123"}, headers={"Content-Type": "application/json"}, - timeout=30 + timeout=30, ) mock_get.assert_called_once_with( "https://api.company.com/auth/users/user-uuid-456/gcp-credentials", headers={ "Authorization": f"Bearer {realistic_token}", - "Content-Type": "application/json" + "Content-Type": "application/json", }, - timeout=30 + timeout=30, ) From 26f87869162d692a36b8ae523b0c26b49db79091 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Mon, 4 Aug 2025 16:26:18 +0200 Subject: [PATCH 12/28] Fix type hint error caught by mypy --- linopy/model.py | 2 +- linopy/oetc.py | 28 +++-- pyproject.toml | 1 + test/test_oetc.py | 287 ++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 281 insertions(+), 37 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index 81b9bb09..1a8a35fa 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1050,7 +1050,7 @@ def solve( slice_size: int = 2_000_000, remote: Any = None, progress: bool | None = None, - oetc_settings: OetcSettings = None, + oetc_settings: OetcSettings | None = None, **solver_options: Any, ) -> tuple[str, str]: """ diff --git a/linopy/oetc.py b/linopy/oetc.py index 7a55986a..3cda18da 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -14,8 +14,11 @@ from google.oauth2 import service_account from requests import RequestException +<<<<<<< Updated upstream import linopy +======= +>>>>>>> Stashed changes logger = logging.getLogger(__name__) @@ -73,15 +76,14 @@ def is_expired(self) -> bool: class JobResult: uuid: str status: str - name: str = None - owner: str = None - solver: str = None - duration_in_seconds: int = None - solving_duration_in_seconds: int = None - input_files: list = None - output_files: list = None - created_at: str = None - + name: str | None = None + owner: str | None = None + solver: str | None = None + duration_in_seconds: int | None = None + solving_duration_in_seconds: int | None = None + input_files: list | None = None + output_files: list | None = None + created_at: str | None = None class OetcHandler: def __init__(self, settings: OetcSettings) -> None: @@ -153,12 +155,16 @@ def _decode_jwt_payload(self, token: str) -> dict: except (IndexError, json.JSONDecodeError, Exception) as e: raise Exception(f"Failed to decode JWT payload: {e}") +<<<<<<< Updated upstream def __get_cloud_provider_credentials(self) -> GcpCredentials | None: +======= + def __get_cloud_provider_credentials(self) -> GcpCredentials: +>>>>>>> Stashed changes """ Fetch cloud provider credentials based on the configured provider. Returns: - Union[GcpCredentials, None]: The cloud provider credentials + GcpCredentials: The cloud provider credentials Raises: Exception: If the compute provider is not supported @@ -471,7 +477,7 @@ def _download_file_from_gcp(self, file_name: str) -> str: except Exception as e: raise Exception(f"Failed to download file from GCP: {e}") - def solve_on_oetc(self, model): + def solve_on_oetc(self, model): # type: ignore """ Solve a linopy model on the OET Cloud compute app. diff --git a/pyproject.toml b/pyproject.toml index 59a823a6..08f14d42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ dev = [ "netcdf4", "paramiko", "types-paramiko", + "types-requests", "gurobipy", "highspy", ] diff --git a/test/test_oetc.py b/test/test_oetc.py index 9fabf9fa..d4a1e909 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -1,7 +1,11 @@ import base64 import json +<<<<<<< Updated upstream from datetime import datetime from unittest.mock import Mock, patch +======= +from typing import Any +>>>>>>> Stashed changes import pytest import requests @@ -19,7 +23,7 @@ @pytest.fixture -def sample_jwt_token(): +def sample_jwt_token() -> str: """Create a sample JWT token with test payload""" payload = { "iss": "OETC", @@ -46,7 +50,7 @@ def sample_jwt_token(): @pytest.fixture -def mock_gcp_credentials_response(): +def mock_gcp_credentials_response() -> dict: """Create a mock GCP credentials response""" return { "gcp_project_id": "test-project-123", @@ -57,7 +61,7 @@ def mock_gcp_credentials_response(): @pytest.fixture -def mock_settings(): +def mock_settings() -> OetcSettings: """Create mock settings for testing""" credentials = OetcCredentials(email="test@example.com", password="test_password") return OetcSettings( @@ -71,7 +75,7 @@ def mock_settings(): class TestOetcHandler: @pytest.fixture - def mock_jwt_response(self): + def mock_jwt_response(self) -> dict: """Create a mock JWT response""" return { "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", @@ -79,6 +83,7 @@ def mock_jwt_response(self): "expires_in": 3600, } +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") @patch("linopy.oetc.requests.get") @patch("linopy.oetc.datetime") @@ -92,6 +97,13 @@ def test_successful_authentication( mock_gcp_credentials_response, sample_jwt_token, ): +======= + @patch('linopy.oetc.requests.post') + @patch('linopy.oetc.requests.get') + @patch('linopy.oetc.datetime') + def test_successful_authentication(self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock, mock_settings: OetcSettings, mock_jwt_response: dict, + mock_gcp_credentials_response: dict, sample_jwt_token: str) -> None: +>>>>>>> Stashed changes """Test successful authentication flow""" # Setup mocks fixed_time = datetime(2024, 1, 15, 12, 0, 0) @@ -150,8 +162,13 @@ def test_successful_authentication( handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" ) +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_authentication_http_error(self, mock_post, mock_settings): +======= + @patch('linopy.oetc.requests.post') + def test_authentication_http_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: +>>>>>>> Stashed changes """Test authentication failure with HTTP error""" # Setup mock to raise HTTP error mock_response = Mock() @@ -169,7 +186,7 @@ def test_authentication_http_error(self, mock_post, mock_settings): class TestJwtDecoding: @pytest.fixture - def handler_with_mocked_auth(self): + def handler_with_mocked_auth(self) -> OetcHandler: """Create handler with mocked authentication for testing JWT decoding""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -194,13 +211,17 @@ def handler_with_mocked_auth(self): handler = OetcHandler.__new__(OetcHandler) handler.settings = settings handler.jwt = mock_auth_result - handler.cloud_provider_credentials = None + handler.cloud_provider_credentials = None # type: ignore return handler +<<<<<<< Updated upstream def test_decode_jwt_payload_success( self, handler_with_mocked_auth, sample_jwt_token ): +======= + def test_decode_jwt_payload_success(self, handler_with_mocked_auth: OetcHandler, sample_jwt_token: str) -> None: +>>>>>>> Stashed changes """Test successful JWT payload decoding""" result = handler_with_mocked_auth._decode_jwt_payload(sample_jwt_token) @@ -210,14 +231,14 @@ def test_decode_jwt_payload_success( assert result["firstname"] == "Test" assert result["lastname"] == "User" - def test_decode_jwt_payload_invalid_token(self, handler_with_mocked_auth): + def test_decode_jwt_payload_invalid_token(self, handler_with_mocked_auth: OetcHandler) -> None: """Test JWT payload decoding with invalid token""" with pytest.raises(Exception) as exc_info: handler_with_mocked_auth._decode_jwt_payload("invalid.token") assert "Failed to decode JWT payload" in str(exc_info.value) - def test_decode_jwt_payload_malformed_token(self, handler_with_mocked_auth): + def test_decode_jwt_payload_malformed_token(self, handler_with_mocked_auth: OetcHandler) -> None: """Test JWT payload decoding with malformed token""" with pytest.raises(Exception) as exc_info: handler_with_mocked_auth._decode_jwt_payload("not_a_jwt_token") @@ -227,7 +248,7 @@ def test_decode_jwt_payload_malformed_token(self, handler_with_mocked_auth): class TestCloudProviderCredentials: @pytest.fixture - def handler_with_mocked_auth(self, sample_jwt_token): + def handler_with_mocked_auth(self, sample_jwt_token: str) -> OetcHandler: """Create handler with mocked authentication for testing credentials fetching""" credentials = OetcCredentials( email="test@example.com", password="test_password" @@ -254,10 +275,15 @@ def handler_with_mocked_auth(self, sample_jwt_token): return handler +<<<<<<< Updated upstream @patch("linopy.oetc.requests.get") def test_get_gcp_credentials_success( self, mock_get, handler_with_mocked_auth, mock_gcp_credentials_response ): +======= + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_success(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler, mock_gcp_credentials_response: dict) -> None: +>>>>>>> Stashed changes """Test successful GCP credentials fetching""" # Setup mock mock_response = Mock() @@ -266,7 +292,7 @@ def test_get_gcp_credentials_success( mock_get.return_value = mock_response # Execute - result = handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + result = handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] # Verify request mock_get.assert_called_once_with( @@ -285,8 +311,13 @@ def test_get_gcp_credentials_success( assert result.input_bucket == "test-input-bucket" assert result.solution_bucket == "test-solution-bucket" +<<<<<<< Updated upstream @patch("linopy.oetc.requests.get") def test_get_gcp_credentials_http_error(self, mock_get, handler_with_mocked_auth): +======= + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_http_error(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler) -> None: +>>>>>>> Stashed changes """Test GCP credentials fetching with HTTP error""" # Setup mock to raise HTTP error mock_response = Mock() @@ -295,14 +326,19 @@ def test_get_gcp_credentials_http_error(self, mock_get, handler_with_mocked_auth # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] assert "Failed to fetch GCP credentials" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.requests.get") def test_get_gcp_credentials_missing_field( self, mock_get, handler_with_mocked_auth ): +======= + @patch('linopy.oetc.requests.get') + def test_get_gcp_credentials_missing_field(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler) -> None: +>>>>>>> Stashed changes """Test GCP credentials fetching with missing response field""" # Setup mock with invalid response mock_response = Mock() @@ -317,13 +353,14 @@ def test_get_gcp_credentials_missing_field( # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] assert ( "Invalid credentials response format: missing field 'solution_bucket'" in str(exc_info.value) ) +<<<<<<< Updated upstream def test_get_cloud_provider_credentials_unsupported_provider( self, handler_with_mocked_auth ): @@ -332,14 +369,20 @@ def test_get_cloud_provider_credentials_unsupported_provider( handler_with_mocked_auth.settings.compute_provider = ( "AWS" # Not in enum, but for testing ) +======= + def test_get_cloud_provider_credentials_unsupported_provider(self, handler_with_mocked_auth: OetcHandler) -> None: + """Test cloud provider credentials with unsupported provider""" + # Change to unsupported provider + handler_with_mocked_auth.settings.compute_provider = "AWS" # type: ignore # Not in enum, but for testing +>>>>>>> Stashed changes # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_cloud_provider_credentials() + handler_with_mocked_auth._OetcHandler__get_cloud_provider_credentials() # type: ignore[attr-defined] assert "Unsupported compute provider: AWS" in str(exc_info.value) - def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_auth): + def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_auth: OetcHandler) -> None: """Test GCP credentials fetching when JWT token has no user UUID""" # Create token without 'sub' field payload = {"iss": "OETC", "email": "test@example.com"} @@ -352,13 +395,18 @@ def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_aut # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] assert "User UUID not found in JWT token" in str(exc_info.value) class TestGcpCredentials: +<<<<<<< Updated upstream def test_gcp_credentials_creation(self): +======= + + def test_gcp_credentials_creation(self) -> None: +>>>>>>> Stashed changes """Test GcpCredentials dataclass creation""" credentials = GcpCredentials( gcp_project_id="test-project-123", @@ -374,13 +422,23 @@ def test_gcp_credentials_creation(self): class TestComputeProvider: +<<<<<<< Updated upstream def test_compute_provider_enum(self): +======= + + def test_compute_provider_enum(self) -> None: +>>>>>>> Stashed changes """Test ComputeProvider enum values""" assert ComputeProvider.GCP == "GCP" assert ComputeProvider.GCP.value == "GCP" +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_authentication_network_error(self, mock_post, mock_settings): +======= + @patch('linopy.oetc.requests.post') + def test_authentication_network_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: +>>>>>>> Stashed changes """Test authentication failure with network error""" # Setup mock to raise network error mock_post.side_effect = RequestException("Connection timeout") @@ -391,10 +449,15 @@ def test_authentication_network_error(self, mock_post, mock_settings): assert "Authentication request failed" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_authentication_invalid_response_missing_token( self, mock_post, mock_settings ): +======= + @patch('linopy.oetc.requests.post') + def test_authentication_invalid_response_missing_token(self, mock_post: Mock, mock_settings: OetcSettings) -> None: +>>>>>>> Stashed changes """Test authentication failure with missing token in response""" # Setup mock with invalid response mock_response = Mock() @@ -412,10 +475,15 @@ def test_authentication_invalid_response_missing_token( assert "Invalid response format: missing field 'token'" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_authentication_invalid_response_missing_expires_in( self, mock_post, mock_settings ): +======= + @patch('linopy.oetc.requests.post') + def test_authentication_invalid_response_missing_expires_in(self, mock_post: Mock, mock_settings: OetcSettings) -> None: +>>>>>>> Stashed changes """Test authentication failure with missing expires_in in response""" # Setup mock with invalid response mock_response = Mock() @@ -435,8 +503,13 @@ def test_authentication_invalid_response_missing_expires_in( exc_info.value ) +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_authentication_timeout_error(self, mock_post, mock_settings): +======= + @patch('linopy.oetc.requests.post') + def test_authentication_timeout_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: +>>>>>>> Stashed changes """Test authentication failure with timeout""" # Setup mock to raise timeout error mock_post.side_effect = requests.Timeout("Request timeout") @@ -450,7 +523,7 @@ def test_authentication_timeout_error(self, mock_post, mock_settings): class TestAuthenticationResult: @pytest.fixture - def auth_result(self): + def auth_result(self) -> AuthenticationResult: """Create an AuthenticationResult for testing""" return AuthenticationResult( token="test_token", @@ -459,29 +532,44 @@ def auth_result(self): authenticated_at=datetime(2024, 1, 15, 12, 0, 0), ) - def test_expires_at_calculation(self, auth_result): + def test_expires_at_calculation(self, auth_result: AuthenticationResult) -> None: """Test that expires_at correctly calculates expiration time""" expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later assert auth_result.expires_at == expected_expiry +<<<<<<< Updated upstream @patch("linopy.oetc.datetime") def test_is_expired_false_when_not_expired(self, mock_datetime, auth_result): +======= + @patch('linopy.oetc.datetime') + def test_is_expired_false_when_not_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: +>>>>>>> Stashed changes """Test is_expired returns False when token is still valid""" # Set current time to before expiration mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 30, 0) assert auth_result.is_expired is False +<<<<<<< Updated upstream @patch("linopy.oetc.datetime") def test_is_expired_true_when_expired(self, mock_datetime, auth_result): +======= + @patch('linopy.oetc.datetime') + def test_is_expired_true_when_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: +>>>>>>> Stashed changes """Test is_expired returns True when token has expired""" # Set current time to after expiration mock_datetime.now.return_value = datetime(2024, 1, 15, 14, 0, 0) assert auth_result.is_expired is True +<<<<<<< Updated upstream @patch("linopy.oetc.datetime") def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): +======= + @patch('linopy.oetc.datetime') + def test_is_expired_true_when_exactly_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: +>>>>>>> Stashed changes """Test is_expired returns True when token expires exactly now""" # Set current time to exact expiration time mock_datetime.now.return_value = datetime(2024, 1, 15, 13, 0, 0) @@ -491,7 +579,7 @@ def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): class TestFileCompression: @pytest.fixture - def handler_with_mocked_auth(self): + def handler_with_mocked_auth(self) -> OetcHandler: """Create handler with mocked authentication for testing file operations""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -512,12 +600,19 @@ def handler_with_mocked_auth(self): return handler +<<<<<<< Updated upstream @patch("linopy.oetc.gzip.open") @patch("linopy.oetc.os.path.exists") @patch("builtins.open") def test_gzip_compress_success( self, mock_open, mock_exists, mock_gzip_open, handler_with_mocked_auth ): +======= + @patch('linopy.oetc.gzip.open') + @patch('linopy.oetc.os.path.exists') + @patch('builtins.open') + def test_gzip_compress_success(self, mock_open: Mock, mock_exists: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: +>>>>>>> Stashed changes """Test successful file compression""" # Setup source_path = "/tmp/test_file.nc" @@ -545,8 +640,13 @@ def test_gzip_compress_success( mock_gzip_open.assert_called_once_with(expected_output, "wb", compresslevel=9) mock_file_out.write.assert_called_once_with(b"test_data_chunk") +<<<<<<< Updated upstream @patch("builtins.open") def test_gzip_compress_file_read_error(self, mock_open, handler_with_mocked_auth): +======= + @patch('builtins.open') + def test_gzip_compress_file_read_error(self, mock_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: +>>>>>>> Stashed changes """Test file compression with read error""" # Setup source_path = "/tmp/test_file.nc" @@ -561,7 +661,7 @@ def test_gzip_compress_file_read_error(self, mock_open, handler_with_mocked_auth class TestGcpUpload: @pytest.fixture - def handler_with_gcp_credentials(self, mock_gcp_credentials_response): + def handler_with_gcp_credentials(self, mock_gcp_credentials_response: dict) -> OetcHandler: """Create handler with GCP credentials for testing upload""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -590,6 +690,7 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response): return handler +<<<<<<< Updated upstream @patch("linopy.oetc.os.remove") @patch("linopy.oetc.os.path.basename") @patch("linopy.oetc.storage.Client") @@ -602,6 +703,14 @@ def test_upload_file_to_gcp_success( mock_remove, handler_with_gcp_credentials, ): +======= + @patch('linopy.oetc.os.remove') + @patch('linopy.oetc.os.path.basename') + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_upload_file_to_gcp_success(self, mock_creds_from_info: Mock, mock_storage_client: Mock, mock_basename: Mock, mock_remove: Mock, + handler_with_gcp_credentials: OetcHandler) -> None: +>>>>>>> Stashed changes """Test successful file upload to GCP""" # Setup file_path = "/tmp/test_file.nc" @@ -655,10 +764,15 @@ def test_upload_file_to_gcp_success( # Verify cleanup mock_remove.assert_called_once_with(compressed_path) +<<<<<<< Updated upstream @patch("linopy.oetc.json.loads") def test_upload_file_to_gcp_invalid_service_key( self, mock_json_loads, handler_with_gcp_credentials ): +======= + @patch('linopy.oetc.json.loads') + def test_upload_file_to_gcp_invalid_service_key(self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler) -> None: +>>>>>>> Stashed changes """Test upload failure with invalid service key""" # Setup file_path = "/tmp/test_file.nc" @@ -670,11 +784,18 @@ def test_upload_file_to_gcp_invalid_service_key( assert "Failed to upload file to GCP" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.storage.Client") @patch("linopy.oetc.service_account.Credentials.from_service_account_info") def test_upload_file_to_gcp_upload_error( self, mock_creds_from_info, mock_storage_client, handler_with_gcp_credentials ): +======= + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_upload_file_to_gcp_upload_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, + handler_with_gcp_credentials: OetcHandler) -> None: +>>>>>>> Stashed changes """Test upload failure during blob upload""" # Setup file_path = "/tmp/test_file.nc" @@ -707,7 +828,7 @@ def test_upload_file_to_gcp_upload_error( class TestFileDecompression: @pytest.fixture - def handler_with_mocked_auth(self): + def handler_with_mocked_auth(self) -> OetcHandler: """Create handler with mocked authentication for testing file operations""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -728,11 +849,17 @@ def handler_with_mocked_auth(self): return handler +<<<<<<< Updated upstream @patch("linopy.oetc.gzip.open") @patch("builtins.open") def test_gzip_decompress_success( self, mock_open_file, mock_gzip_open, handler_with_mocked_auth ): +======= + @patch('linopy.oetc.gzip.open') + @patch('builtins.open') + def test_gzip_decompress_success(self, mock_open_file: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: +>>>>>>> Stashed changes """Test successful file decompression""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -759,10 +886,15 @@ def test_gzip_decompress_success( mock_open_file.assert_called_once_with(expected_output, "wb") mock_file_out.write.assert_called_once_with(b"decompressed_data_chunk") +<<<<<<< Updated upstream @patch("linopy.oetc.gzip.open") def test_gzip_decompress_gzip_open_error( self, mock_gzip_open, handler_with_mocked_auth ): +======= + @patch('linopy.oetc.gzip.open') + def test_gzip_decompress_gzip_open_error(self, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: +>>>>>>> Stashed changes """Test file decompression with gzip open error""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -774,11 +906,17 @@ def test_gzip_decompress_gzip_open_error( assert "Failed to decompress file" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.gzip.open") @patch("builtins.open") def test_gzip_decompress_write_error( self, mock_open_file, mock_gzip_open, handler_with_mocked_auth ): +======= + @patch('linopy.oetc.gzip.open') + @patch('builtins.open') + def test_gzip_decompress_write_error(self, mock_open_file: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: +>>>>>>> Stashed changes """Test file decompression with write error""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -797,7 +935,7 @@ def test_gzip_decompress_write_error( assert "Failed to decompress file" in str(exc_info.value) - def test_gzip_decompress_output_path_generation(self, handler_with_mocked_auth): + def test_gzip_decompress_output_path_generation(self, handler_with_mocked_auth: OetcHandler) -> None: """Test correct output path generation for decompression""" # Test first path with patch("linopy.oetc.gzip.open") as mock_gzip_open: @@ -828,7 +966,7 @@ def test_gzip_decompress_output_path_generation(self, handler_with_mocked_auth): class TestGcpDownload: @pytest.fixture - def handler_with_gcp_credentials(self, mock_gcp_credentials_response): + def handler_with_gcp_credentials(self, mock_gcp_credentials_response: dict) -> OetcHandler: """Create handler with GCP credentials for testing download""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -857,6 +995,7 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response): return handler +<<<<<<< Updated upstream @patch("linopy.oetc.os.remove") @patch("linopy.oetc.tempfile.NamedTemporaryFile") @patch("linopy.oetc.storage.Client") @@ -869,6 +1008,14 @@ def test_download_file_from_gcp_success( mock_remove, handler_with_gcp_credentials, ): +======= + @patch('linopy.oetc.os.remove') + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_download_file_from_gcp_success(self, mock_creds_from_info: Mock, mock_storage_client: Mock, + mock_tempfile: Mock, mock_remove: Mock, handler_with_gcp_credentials: OetcHandler) -> None: +>>>>>>> Stashed changes """Test successful file download from GCP""" # Setup file_name = "solution_file.nc.gz" @@ -926,10 +1073,15 @@ def test_download_file_from_gcp_success( # Verify cleanup mock_remove.assert_called_once_with(compressed_path) +<<<<<<< Updated upstream @patch("linopy.oetc.json.loads") def test_download_file_from_gcp_invalid_service_key( self, mock_json_loads, handler_with_gcp_credentials ): +======= + @patch('linopy.oetc.json.loads') + def test_download_file_from_gcp_invalid_service_key(self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler) -> None: +>>>>>>> Stashed changes """Test download failure with invalid service key""" # Setup file_name = "solution_file.nc.gz" @@ -941,6 +1093,7 @@ def test_download_file_from_gcp_invalid_service_key( assert "Failed to download file from GCP" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.tempfile.NamedTemporaryFile") @patch("linopy.oetc.storage.Client") @patch("linopy.oetc.service_account.Credentials.from_service_account_info") @@ -951,6 +1104,13 @@ def test_download_file_from_gcp_download_error( mock_tempfile, handler_with_gcp_credentials, ): +======= + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_download_file_from_gcp_download_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, + mock_tempfile: Mock, handler_with_gcp_credentials: OetcHandler) -> None: +>>>>>>> Stashed changes """Test download failure during blob download""" # Setup file_name = "solution_file.nc.gz" @@ -981,6 +1141,7 @@ def test_download_file_from_gcp_download_error( assert "Failed to download file from GCP" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.os.remove") @patch("linopy.oetc.tempfile.NamedTemporaryFile") @patch("linopy.oetc.storage.Client") @@ -993,6 +1154,14 @@ def test_download_file_from_gcp_decompression_error( mock_remove, handler_with_gcp_credentials, ): +======= + @patch('linopy.oetc.os.remove') + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + @patch('linopy.oetc.storage.Client') + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_download_file_from_gcp_decompression_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, + mock_tempfile: Mock, mock_remove: Mock, handler_with_gcp_credentials: OetcHandler) -> None: +>>>>>>> Stashed changes """Test download failure during decompression""" # Setup file_name = "solution_file.nc.gz" @@ -1028,10 +1197,15 @@ def test_download_file_from_gcp_decompression_error( assert "Failed to download file from GCP" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.service_account.Credentials.from_service_account_info") def test_download_file_from_gcp_credentials_error( self, mock_creds_from_info, handler_with_gcp_credentials ): +======= + @patch('linopy.oetc.service_account.Credentials.from_service_account_info') + def test_download_file_from_gcp_credentials_error(self, mock_creds_from_info: Mock, handler_with_gcp_credentials: OetcHandler) -> None: +>>>>>>> Stashed changes """Test download failure during credentials creation""" # Setup file_name = "solution_file.nc.gz" @@ -1046,7 +1220,7 @@ def test_download_file_from_gcp_credentials_error( class TestJobSubmission: @pytest.fixture - def handler_with_auth_setup(self, sample_jwt_token): + def handler_with_auth_setup(self, sample_jwt_token: str) -> OetcHandler: """Create handler with authentication setup for testing job submission""" credentials = OetcCredentials( email="test@example.com", password="test_password" @@ -1077,8 +1251,13 @@ def handler_with_auth_setup(self, sample_jwt_token): return handler +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_submit_job_success(self, mock_post, handler_with_auth_setup): +======= + @patch('linopy.oetc.requests.post') + def test_submit_job_success(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test successful job submission to compute service""" # Setup input_file_name = "test_model.nc.gz" @@ -1118,8 +1297,13 @@ def test_submit_job_success(self, mock_post, handler_with_auth_setup): # Verify result assert result == expected_job_uuid +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_submit_job_http_error(self, mock_post, handler_with_auth_setup): +======= + @patch('linopy.oetc.requests.post') + def test_submit_job_http_error(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test job submission with HTTP error""" # Setup input_file_name = "test_model.nc.gz" @@ -1135,10 +1319,15 @@ def test_submit_job_http_error(self, mock_post, handler_with_auth_setup): assert "Failed to submit job to compute service" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_submit_job_missing_uuid_in_response( self, mock_post, handler_with_auth_setup ): +======= + @patch('linopy.oetc.requests.post') + def test_submit_job_missing_uuid_in_response(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test job submission with missing UUID in response""" # Setup input_file_name = "test_model.nc.gz" @@ -1155,8 +1344,13 @@ def test_submit_job_missing_uuid_in_response( exc_info.value ) +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") def test_submit_job_network_error(self, mock_post, handler_with_auth_setup): +======= + @patch('linopy.oetc.requests.post') + def test_submit_job_network_error(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test job submission with network error""" # Setup input_file_name = "test_model.nc.gz" @@ -1171,7 +1365,7 @@ def test_submit_job_network_error(self, mock_post, handler_with_auth_setup): class TestSolveOnOetc: @pytest.fixture - def handler_with_complete_setup(self, mock_gcp_credentials_response): + def handler_with_complete_setup(self, mock_gcp_credentials_response: dict) -> OetcHandler: """Create handler with complete setup for testing solve functionality""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -1199,12 +1393,19 @@ def handler_with_complete_setup(self, mock_gcp_credentials_response): return handler +<<<<<<< Updated upstream @patch("linopy.oetc.linopy.read_netcdf") @patch("linopy.oetc.os.remove") @patch("linopy.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_file_upload( self, mock_tempfile, mock_remove, mock_read_netcdf, handler_with_complete_setup ): +======= + @patch('linopy.oetc.linopy.read_netcdf') + @patch('linopy.oetc.os.remove') + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_file_upload(self, mock_tempfile: Mock, mock_remove: Mock, mock_read_netcdf: Mock, handler_with_complete_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test solve_on_oetc method complete workflow""" # Setup mock_model = Mock() @@ -1262,10 +1463,15 @@ def test_solve_on_oetc_file_upload( ) mock_remove.assert_called_once_with("/tmp/downloaded_result.nc") +<<<<<<< Updated upstream @patch("linopy.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_upload_failure( self, mock_tempfile, handler_with_complete_setup ): +======= + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_upload_failure(self, mock_tempfile: Mock, handler_with_complete_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test solve_on_oetc method with upload failure""" # Setup mock_model = Mock() @@ -1288,7 +1494,7 @@ def test_solve_on_oetc_upload_failure( class TestSolveOnOetcWithJobSubmission: @pytest.fixture - def handler_with_full_setup(self): + def handler_with_full_setup(self) -> OetcHandler: """Create handler with full setup for testing complete solve flow""" credentials = OetcCredentials( email="test@example.com", password="test_password" @@ -1318,12 +1524,20 @@ def handler_with_full_setup(self): return handler +<<<<<<< Updated upstream @patch("linopy.oetc.linopy.read_netcdf") @patch("linopy.oetc.os.remove") @patch("linopy.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_with_job_submission( self, mock_tempfile, mock_remove, mock_read_netcdf, handler_with_full_setup ): +======= + @patch('linopy.oetc.linopy.read_netcdf') + @patch('linopy.oetc.os.remove') + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_with_job_submission(self, mock_tempfile: Mock, mock_remove: Mock, mock_read_netcdf: Mock, + handler_with_full_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test solve_on_oetc method including job submission, waiting, and download""" # Setup mock_model = Mock() @@ -1384,10 +1598,15 @@ def test_solve_on_oetc_with_job_submission( ) mock_remove.assert_called_once_with("/tmp/solution_file.nc") +<<<<<<< Updated upstream @patch("linopy.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_job_submission_failure( self, mock_tempfile, handler_with_full_setup ): +======= + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_job_submission_failure(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test solve_on_oetc method with job submission failure""" # Setup mock_model = Mock() @@ -1414,10 +1633,15 @@ def test_solve_on_oetc_job_submission_failure( assert "Job submission failed" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_job_waiting_failure( self, mock_tempfile, handler_with_full_setup ): +======= + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test solve_on_oetc method with job waiting failure""" # Setup mock_model = Mock() @@ -1450,10 +1674,15 @@ def test_solve_on_oetc_job_waiting_failure( assert "Job failed: solver error" in str(exc_info.value) +<<<<<<< Updated upstream @patch("linopy.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_no_output_files_error( self, mock_tempfile, handler_with_full_setup ): +======= + @patch('linopy.oetc.tempfile.NamedTemporaryFile') + def test_solve_on_oetc_no_output_files_error(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: +>>>>>>> Stashed changes """Test solve_on_oetc method when job completes but has no output files""" # Setup mock_model = Mock() @@ -1493,10 +1722,18 @@ def test_solve_on_oetc_no_output_files_error( # Additional integration-style test class TestOetcHandlerIntegration: +<<<<<<< Updated upstream @patch("linopy.oetc.requests.post") @patch("linopy.oetc.requests.get") @patch("linopy.oetc.datetime") def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): +======= + + @patch('linopy.oetc.requests.post') + @patch('linopy.oetc.requests.get') + @patch('linopy.oetc.datetime') + def test_complete_authentication_flow(self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock) -> None: +>>>>>>> Stashed changes """Test complete authentication and credentials flow with realistic data""" # Setup fixed_time = datetime(2024, 1, 15, 12, 0, 0) From 467786f5d0b26b23f70fd2edd37244b4751226ca Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Mon, 4 Aug 2025 16:34:54 +0200 Subject: [PATCH 13/28] Fix conflicts --- linopy/oetc.py | 7 -- test/test_oetc.py | 310 +--------------------------------------------- 2 files changed, 1 insertion(+), 316 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 3cda18da..79b06fd6 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -14,11 +14,8 @@ from google.oauth2 import service_account from requests import RequestException -<<<<<<< Updated upstream import linopy -======= ->>>>>>> Stashed changes logger = logging.getLogger(__name__) @@ -155,11 +152,7 @@ def _decode_jwt_payload(self, token: str) -> dict: except (IndexError, json.JSONDecodeError, Exception) as e: raise Exception(f"Failed to decode JWT payload: {e}") -<<<<<<< Updated upstream - def __get_cloud_provider_credentials(self) -> GcpCredentials | None: -======= def __get_cloud_provider_credentials(self) -> GcpCredentials: ->>>>>>> Stashed changes """ Fetch cloud provider credentials based on the configured provider. diff --git a/test/test_oetc.py b/test/test_oetc.py index d4a1e909..fe670c8e 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -1,11 +1,8 @@ import base64 import json -<<<<<<< Updated upstream from datetime import datetime from unittest.mock import Mock, patch -======= from typing import Any ->>>>>>> Stashed changes import pytest import requests @@ -83,27 +80,11 @@ def mock_jwt_response(self) -> dict: "expires_in": 3600, } -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - @patch("linopy.oetc.requests.get") - @patch("linopy.oetc.datetime") - def test_successful_authentication( - self, - mock_datetime, - mock_get, - mock_post, - mock_settings, - mock_jwt_response, - mock_gcp_credentials_response, - sample_jwt_token, - ): -======= @patch('linopy.oetc.requests.post') @patch('linopy.oetc.requests.get') @patch('linopy.oetc.datetime') def test_successful_authentication(self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock, mock_settings: OetcSettings, mock_jwt_response: dict, mock_gcp_credentials_response: dict, sample_jwt_token: str) -> None: ->>>>>>> Stashed changes """Test successful authentication flow""" # Setup mocks fixed_time = datetime(2024, 1, 15, 12, 0, 0) @@ -162,13 +143,9 @@ def test_successful_authentication(self, mock_datetime: Mock, mock_get: Mock, mo handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" ) -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_authentication_http_error(self, mock_post, mock_settings): -======= + @patch('linopy.oetc.requests.post') def test_authentication_http_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: ->>>>>>> Stashed changes """Test authentication failure with HTTP error""" # Setup mock to raise HTTP error mock_response = Mock() @@ -215,13 +192,7 @@ def handler_with_mocked_auth(self) -> OetcHandler: return handler -<<<<<<< Updated upstream - def test_decode_jwt_payload_success( - self, handler_with_mocked_auth, sample_jwt_token - ): -======= def test_decode_jwt_payload_success(self, handler_with_mocked_auth: OetcHandler, sample_jwt_token: str) -> None: ->>>>>>> Stashed changes """Test successful JWT payload decoding""" result = handler_with_mocked_auth._decode_jwt_payload(sample_jwt_token) @@ -275,15 +246,8 @@ def handler_with_mocked_auth(self, sample_jwt_token: str) -> OetcHandler: return handler -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.get") - def test_get_gcp_credentials_success( - self, mock_get, handler_with_mocked_auth, mock_gcp_credentials_response - ): -======= @patch('linopy.oetc.requests.get') def test_get_gcp_credentials_success(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler, mock_gcp_credentials_response: dict) -> None: ->>>>>>> Stashed changes """Test successful GCP credentials fetching""" # Setup mock mock_response = Mock() @@ -311,13 +275,8 @@ def test_get_gcp_credentials_success(self, mock_get: Mock, handler_with_mocked_a assert result.input_bucket == "test-input-bucket" assert result.solution_bucket == "test-solution-bucket" -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.get") - def test_get_gcp_credentials_http_error(self, mock_get, handler_with_mocked_auth): -======= @patch('linopy.oetc.requests.get') def test_get_gcp_credentials_http_error(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler) -> None: ->>>>>>> Stashed changes """Test GCP credentials fetching with HTTP error""" # Setup mock to raise HTTP error mock_response = Mock() @@ -330,15 +289,8 @@ def test_get_gcp_credentials_http_error(self, mock_get: Mock, handler_with_mocke assert "Failed to fetch GCP credentials" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.get") - def test_get_gcp_credentials_missing_field( - self, mock_get, handler_with_mocked_auth - ): -======= @patch('linopy.oetc.requests.get') def test_get_gcp_credentials_missing_field(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler) -> None: ->>>>>>> Stashed changes """Test GCP credentials fetching with missing response field""" # Setup mock with invalid response mock_response = Mock() @@ -360,21 +312,10 @@ def test_get_gcp_credentials_missing_field(self, mock_get: Mock, handler_with_mo in str(exc_info.value) ) -<<<<<<< Updated upstream - def test_get_cloud_provider_credentials_unsupported_provider( - self, handler_with_mocked_auth - ): - """Test cloud provider credentials with unsupported provider""" - # Change to unsupported provider - handler_with_mocked_auth.settings.compute_provider = ( - "AWS" # Not in enum, but for testing - ) -======= def test_get_cloud_provider_credentials_unsupported_provider(self, handler_with_mocked_auth: OetcHandler) -> None: """Test cloud provider credentials with unsupported provider""" # Change to unsupported provider handler_with_mocked_auth.settings.compute_provider = "AWS" # type: ignore # Not in enum, but for testing ->>>>>>> Stashed changes # Execute and verify exception with pytest.raises(Exception) as exc_info: @@ -401,12 +342,7 @@ def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_aut class TestGcpCredentials: -<<<<<<< Updated upstream - def test_gcp_credentials_creation(self): -======= - def test_gcp_credentials_creation(self) -> None: ->>>>>>> Stashed changes """Test GcpCredentials dataclass creation""" credentials = GcpCredentials( gcp_project_id="test-project-123", @@ -422,23 +358,13 @@ def test_gcp_credentials_creation(self) -> None: class TestComputeProvider: -<<<<<<< Updated upstream - def test_compute_provider_enum(self): -======= - def test_compute_provider_enum(self) -> None: ->>>>>>> Stashed changes """Test ComputeProvider enum values""" assert ComputeProvider.GCP == "GCP" assert ComputeProvider.GCP.value == "GCP" -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_authentication_network_error(self, mock_post, mock_settings): -======= @patch('linopy.oetc.requests.post') def test_authentication_network_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: ->>>>>>> Stashed changes """Test authentication failure with network error""" # Setup mock to raise network error mock_post.side_effect = RequestException("Connection timeout") @@ -449,15 +375,8 @@ def test_authentication_network_error(self, mock_post: Mock, mock_settings: Oetc assert "Authentication request failed" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_authentication_invalid_response_missing_token( - self, mock_post, mock_settings - ): -======= @patch('linopy.oetc.requests.post') def test_authentication_invalid_response_missing_token(self, mock_post: Mock, mock_settings: OetcSettings) -> None: ->>>>>>> Stashed changes """Test authentication failure with missing token in response""" # Setup mock with invalid response mock_response = Mock() @@ -475,15 +394,8 @@ def test_authentication_invalid_response_missing_token(self, mock_post: Mock, mo assert "Invalid response format: missing field 'token'" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_authentication_invalid_response_missing_expires_in( - self, mock_post, mock_settings - ): -======= @patch('linopy.oetc.requests.post') def test_authentication_invalid_response_missing_expires_in(self, mock_post: Mock, mock_settings: OetcSettings) -> None: ->>>>>>> Stashed changes """Test authentication failure with missing expires_in in response""" # Setup mock with invalid response mock_response = Mock() @@ -503,13 +415,8 @@ def test_authentication_invalid_response_missing_expires_in(self, mock_post: Moc exc_info.value ) -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_authentication_timeout_error(self, mock_post, mock_settings): -======= @patch('linopy.oetc.requests.post') def test_authentication_timeout_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: ->>>>>>> Stashed changes """Test authentication failure with timeout""" # Setup mock to raise timeout error mock_post.side_effect = requests.Timeout("Request timeout") @@ -537,39 +444,24 @@ def test_expires_at_calculation(self, auth_result: AuthenticationResult) -> None expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later assert auth_result.expires_at == expected_expiry -<<<<<<< Updated upstream - @patch("linopy.oetc.datetime") - def test_is_expired_false_when_not_expired(self, mock_datetime, auth_result): -======= @patch('linopy.oetc.datetime') def test_is_expired_false_when_not_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: ->>>>>>> Stashed changes """Test is_expired returns False when token is still valid""" # Set current time to before expiration mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 30, 0) assert auth_result.is_expired is False -<<<<<<< Updated upstream - @patch("linopy.oetc.datetime") - def test_is_expired_true_when_expired(self, mock_datetime, auth_result): -======= @patch('linopy.oetc.datetime') def test_is_expired_true_when_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: ->>>>>>> Stashed changes """Test is_expired returns True when token has expired""" # Set current time to after expiration mock_datetime.now.return_value = datetime(2024, 1, 15, 14, 0, 0) assert auth_result.is_expired is True -<<<<<<< Updated upstream - @patch("linopy.oetc.datetime") - def test_is_expired_true_when_exactly_expired(self, mock_datetime, auth_result): -======= @patch('linopy.oetc.datetime') def test_is_expired_true_when_exactly_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: ->>>>>>> Stashed changes """Test is_expired returns True when token expires exactly now""" # Set current time to exact expiration time mock_datetime.now.return_value = datetime(2024, 1, 15, 13, 0, 0) @@ -600,19 +492,10 @@ def handler_with_mocked_auth(self) -> OetcHandler: return handler -<<<<<<< Updated upstream - @patch("linopy.oetc.gzip.open") - @patch("linopy.oetc.os.path.exists") - @patch("builtins.open") - def test_gzip_compress_success( - self, mock_open, mock_exists, mock_gzip_open, handler_with_mocked_auth - ): -======= @patch('linopy.oetc.gzip.open') @patch('linopy.oetc.os.path.exists') @patch('builtins.open') def test_gzip_compress_success(self, mock_open: Mock, mock_exists: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: ->>>>>>> Stashed changes """Test successful file compression""" # Setup source_path = "/tmp/test_file.nc" @@ -640,13 +523,8 @@ def test_gzip_compress_success(self, mock_open: Mock, mock_exists: Mock, mock_gz mock_gzip_open.assert_called_once_with(expected_output, "wb", compresslevel=9) mock_file_out.write.assert_called_once_with(b"test_data_chunk") -<<<<<<< Updated upstream - @patch("builtins.open") - def test_gzip_compress_file_read_error(self, mock_open, handler_with_mocked_auth): -======= @patch('builtins.open') def test_gzip_compress_file_read_error(self, mock_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: ->>>>>>> Stashed changes """Test file compression with read error""" # Setup source_path = "/tmp/test_file.nc" @@ -690,27 +568,12 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response: dict) -> O return handler -<<<<<<< Updated upstream - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.os.path.basename") - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") - def test_upload_file_to_gcp_success( - self, - mock_creds_from_info, - mock_storage_client, - mock_basename, - mock_remove, - handler_with_gcp_credentials, - ): -======= @patch('linopy.oetc.os.remove') @patch('linopy.oetc.os.path.basename') @patch('linopy.oetc.storage.Client') @patch('linopy.oetc.service_account.Credentials.from_service_account_info') def test_upload_file_to_gcp_success(self, mock_creds_from_info: Mock, mock_storage_client: Mock, mock_basename: Mock, mock_remove: Mock, handler_with_gcp_credentials: OetcHandler) -> None: ->>>>>>> Stashed changes """Test successful file upload to GCP""" # Setup file_path = "/tmp/test_file.nc" @@ -764,15 +627,8 @@ def test_upload_file_to_gcp_success(self, mock_creds_from_info: Mock, mock_stora # Verify cleanup mock_remove.assert_called_once_with(compressed_path) -<<<<<<< Updated upstream - @patch("linopy.oetc.json.loads") - def test_upload_file_to_gcp_invalid_service_key( - self, mock_json_loads, handler_with_gcp_credentials - ): -======= @patch('linopy.oetc.json.loads') def test_upload_file_to_gcp_invalid_service_key(self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler) -> None: ->>>>>>> Stashed changes """Test upload failure with invalid service key""" # Setup file_path = "/tmp/test_file.nc" @@ -784,18 +640,10 @@ def test_upload_file_to_gcp_invalid_service_key(self, mock_json_loads: Mock, han assert "Failed to upload file to GCP" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") - def test_upload_file_to_gcp_upload_error( - self, mock_creds_from_info, mock_storage_client, handler_with_gcp_credentials - ): -======= @patch('linopy.oetc.storage.Client') @patch('linopy.oetc.service_account.Credentials.from_service_account_info') def test_upload_file_to_gcp_upload_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, handler_with_gcp_credentials: OetcHandler) -> None: ->>>>>>> Stashed changes """Test upload failure during blob upload""" # Setup file_path = "/tmp/test_file.nc" @@ -849,17 +697,9 @@ def handler_with_mocked_auth(self) -> OetcHandler: return handler -<<<<<<< Updated upstream - @patch("linopy.oetc.gzip.open") - @patch("builtins.open") - def test_gzip_decompress_success( - self, mock_open_file, mock_gzip_open, handler_with_mocked_auth - ): -======= @patch('linopy.oetc.gzip.open') @patch('builtins.open') def test_gzip_decompress_success(self, mock_open_file: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: ->>>>>>> Stashed changes """Test successful file decompression""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -886,15 +726,8 @@ def test_gzip_decompress_success(self, mock_open_file: Mock, mock_gzip_open: Moc mock_open_file.assert_called_once_with(expected_output, "wb") mock_file_out.write.assert_called_once_with(b"decompressed_data_chunk") -<<<<<<< Updated upstream - @patch("linopy.oetc.gzip.open") - def test_gzip_decompress_gzip_open_error( - self, mock_gzip_open, handler_with_mocked_auth - ): -======= @patch('linopy.oetc.gzip.open') def test_gzip_decompress_gzip_open_error(self, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: ->>>>>>> Stashed changes """Test file decompression with gzip open error""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -906,17 +739,9 @@ def test_gzip_decompress_gzip_open_error(self, mock_gzip_open: Mock, handler_wit assert "Failed to decompress file" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.gzip.open") - @patch("builtins.open") - def test_gzip_decompress_write_error( - self, mock_open_file, mock_gzip_open, handler_with_mocked_auth - ): -======= @patch('linopy.oetc.gzip.open') @patch('builtins.open') def test_gzip_decompress_write_error(self, mock_open_file: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: ->>>>>>> Stashed changes """Test file decompression with write error""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -995,27 +820,12 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response: dict) -> O return handler -<<<<<<< Updated upstream - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") - def test_download_file_from_gcp_success( - self, - mock_creds_from_info, - mock_storage_client, - mock_tempfile, - mock_remove, - handler_with_gcp_credentials, - ): -======= @patch('linopy.oetc.os.remove') @patch('linopy.oetc.tempfile.NamedTemporaryFile') @patch('linopy.oetc.storage.Client') @patch('linopy.oetc.service_account.Credentials.from_service_account_info') def test_download_file_from_gcp_success(self, mock_creds_from_info: Mock, mock_storage_client: Mock, mock_tempfile: Mock, mock_remove: Mock, handler_with_gcp_credentials: OetcHandler) -> None: ->>>>>>> Stashed changes """Test successful file download from GCP""" # Setup file_name = "solution_file.nc.gz" @@ -1073,15 +883,8 @@ def test_download_file_from_gcp_success(self, mock_creds_from_info: Mock, mock_s # Verify cleanup mock_remove.assert_called_once_with(compressed_path) -<<<<<<< Updated upstream - @patch("linopy.oetc.json.loads") - def test_download_file_from_gcp_invalid_service_key( - self, mock_json_loads, handler_with_gcp_credentials - ): -======= @patch('linopy.oetc.json.loads') def test_download_file_from_gcp_invalid_service_key(self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler) -> None: ->>>>>>> Stashed changes """Test download failure with invalid service key""" # Setup file_name = "solution_file.nc.gz" @@ -1093,24 +896,11 @@ def test_download_file_from_gcp_invalid_service_key(self, mock_json_loads: Mock, assert "Failed to download file from GCP" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") - def test_download_file_from_gcp_download_error( - self, - mock_creds_from_info, - mock_storage_client, - mock_tempfile, - handler_with_gcp_credentials, - ): -======= @patch('linopy.oetc.tempfile.NamedTemporaryFile') @patch('linopy.oetc.storage.Client') @patch('linopy.oetc.service_account.Credentials.from_service_account_info') def test_download_file_from_gcp_download_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, mock_tempfile: Mock, handler_with_gcp_credentials: OetcHandler) -> None: ->>>>>>> Stashed changes """Test download failure during blob download""" # Setup file_name = "solution_file.nc.gz" @@ -1141,27 +931,12 @@ def test_download_file_from_gcp_download_error(self, mock_creds_from_info: Mock, assert "Failed to download file from GCP" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") - def test_download_file_from_gcp_decompression_error( - self, - mock_creds_from_info, - mock_storage_client, - mock_tempfile, - mock_remove, - handler_with_gcp_credentials, - ): -======= @patch('linopy.oetc.os.remove') @patch('linopy.oetc.tempfile.NamedTemporaryFile') @patch('linopy.oetc.storage.Client') @patch('linopy.oetc.service_account.Credentials.from_service_account_info') def test_download_file_from_gcp_decompression_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, mock_tempfile: Mock, mock_remove: Mock, handler_with_gcp_credentials: OetcHandler) -> None: ->>>>>>> Stashed changes """Test download failure during decompression""" # Setup file_name = "solution_file.nc.gz" @@ -1197,15 +972,8 @@ def test_download_file_from_gcp_decompression_error(self, mock_creds_from_info: assert "Failed to download file from GCP" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") - def test_download_file_from_gcp_credentials_error( - self, mock_creds_from_info, handler_with_gcp_credentials - ): -======= @patch('linopy.oetc.service_account.Credentials.from_service_account_info') def test_download_file_from_gcp_credentials_error(self, mock_creds_from_info: Mock, handler_with_gcp_credentials: OetcHandler) -> None: ->>>>>>> Stashed changes """Test download failure during credentials creation""" # Setup file_name = "solution_file.nc.gz" @@ -1251,13 +1019,8 @@ def handler_with_auth_setup(self, sample_jwt_token: str) -> OetcHandler: return handler -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_submit_job_success(self, mock_post, handler_with_auth_setup): -======= @patch('linopy.oetc.requests.post') def test_submit_job_success(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test successful job submission to compute service""" # Setup input_file_name = "test_model.nc.gz" @@ -1297,13 +1060,8 @@ def test_submit_job_success(self, mock_post: Mock, handler_with_auth_setup: Oetc # Verify result assert result == expected_job_uuid -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_submit_job_http_error(self, mock_post, handler_with_auth_setup): -======= @patch('linopy.oetc.requests.post') def test_submit_job_http_error(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test job submission with HTTP error""" # Setup input_file_name = "test_model.nc.gz" @@ -1319,15 +1077,8 @@ def test_submit_job_http_error(self, mock_post: Mock, handler_with_auth_setup: O assert "Failed to submit job to compute service" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_submit_job_missing_uuid_in_response( - self, mock_post, handler_with_auth_setup - ): -======= @patch('linopy.oetc.requests.post') def test_submit_job_missing_uuid_in_response(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test job submission with missing UUID in response""" # Setup input_file_name = "test_model.nc.gz" @@ -1344,13 +1095,8 @@ def test_submit_job_missing_uuid_in_response(self, mock_post: Mock, handler_with exc_info.value ) -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - def test_submit_job_network_error(self, mock_post, handler_with_auth_setup): -======= @patch('linopy.oetc.requests.post') def test_submit_job_network_error(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test job submission with network error""" # Setup input_file_name = "test_model.nc.gz" @@ -1393,19 +1139,10 @@ def handler_with_complete_setup(self, mock_gcp_credentials_response: dict) -> Oe return handler -<<<<<<< Updated upstream - @patch("linopy.oetc.linopy.read_netcdf") - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - def test_solve_on_oetc_file_upload( - self, mock_tempfile, mock_remove, mock_read_netcdf, handler_with_complete_setup - ): -======= @patch('linopy.oetc.linopy.read_netcdf') @patch('linopy.oetc.os.remove') @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_file_upload(self, mock_tempfile: Mock, mock_remove: Mock, mock_read_netcdf: Mock, handler_with_complete_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test solve_on_oetc method complete workflow""" # Setup mock_model = Mock() @@ -1463,15 +1200,8 @@ def test_solve_on_oetc_file_upload(self, mock_tempfile: Mock, mock_remove: Mock, ) mock_remove.assert_called_once_with("/tmp/downloaded_result.nc") -<<<<<<< Updated upstream - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - def test_solve_on_oetc_upload_failure( - self, mock_tempfile, handler_with_complete_setup - ): -======= @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_upload_failure(self, mock_tempfile: Mock, handler_with_complete_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test solve_on_oetc method with upload failure""" # Setup mock_model = Mock() @@ -1524,20 +1254,11 @@ def handler_with_full_setup(self) -> OetcHandler: return handler -<<<<<<< Updated upstream - @patch("linopy.oetc.linopy.read_netcdf") - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - def test_solve_on_oetc_with_job_submission( - self, mock_tempfile, mock_remove, mock_read_netcdf, handler_with_full_setup - ): -======= @patch('linopy.oetc.linopy.read_netcdf') @patch('linopy.oetc.os.remove') @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_with_job_submission(self, mock_tempfile: Mock, mock_remove: Mock, mock_read_netcdf: Mock, handler_with_full_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test solve_on_oetc method including job submission, waiting, and download""" # Setup mock_model = Mock() @@ -1598,15 +1319,8 @@ def test_solve_on_oetc_with_job_submission(self, mock_tempfile: Mock, mock_remov ) mock_remove.assert_called_once_with("/tmp/solution_file.nc") -<<<<<<< Updated upstream - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - def test_solve_on_oetc_job_submission_failure( - self, mock_tempfile, handler_with_full_setup - ): -======= @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_job_submission_failure(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test solve_on_oetc method with job submission failure""" # Setup mock_model = Mock() @@ -1633,15 +1347,8 @@ def test_solve_on_oetc_job_submission_failure(self, mock_tempfile: Mock, handler assert "Job submission failed" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - def test_solve_on_oetc_job_waiting_failure( - self, mock_tempfile, handler_with_full_setup - ): -======= @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test solve_on_oetc method with job waiting failure""" # Setup mock_model = Mock() @@ -1674,15 +1381,8 @@ def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile: Mock, handler_wi assert "Job failed: solver error" in str(exc_info.value) -<<<<<<< Updated upstream - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - def test_solve_on_oetc_no_output_files_error( - self, mock_tempfile, handler_with_full_setup - ): -======= @patch('linopy.oetc.tempfile.NamedTemporaryFile') def test_solve_on_oetc_no_output_files_error(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: ->>>>>>> Stashed changes """Test solve_on_oetc method when job completes but has no output files""" # Setup mock_model = Mock() @@ -1722,18 +1422,10 @@ def test_solve_on_oetc_no_output_files_error(self, mock_tempfile: Mock, handler_ # Additional integration-style test class TestOetcHandlerIntegration: -<<<<<<< Updated upstream - @patch("linopy.oetc.requests.post") - @patch("linopy.oetc.requests.get") - @patch("linopy.oetc.datetime") - def test_complete_authentication_flow(self, mock_datetime, mock_get, mock_post): -======= - @patch('linopy.oetc.requests.post') @patch('linopy.oetc.requests.get') @patch('linopy.oetc.datetime') def test_complete_authentication_flow(self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock) -> None: ->>>>>>> Stashed changes """Test complete authentication and credentials flow with realistic data""" # Setup fixed_time = datetime(2024, 1, 15, 12, 0, 0) From 4fb5be7e36a4239fcef46eb0a263254544492828 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Aug 2025 14:35:05 +0000 Subject: [PATCH 14/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/oetc.py | 17 ++- test/test_oetc.py | 369 +++++++++++++++++++++++++++++++--------------- 2 files changed, 259 insertions(+), 127 deletions(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 79b06fd6..11353bbd 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -74,13 +74,14 @@ class JobResult: uuid: str status: str name: str | None = None - owner: str | None = None - solver: str | None = None - duration_in_seconds: int | None = None - solving_duration_in_seconds: int | None = None - input_files: list | None = None - output_files: list | None = None - created_at: str | None = None + owner: str | None = None + solver: str | None = None + duration_in_seconds: int | None = None + solving_duration_in_seconds: int | None = None + input_files: list | None = None + output_files: list | None = None + created_at: str | None = None + class OetcHandler: def __init__(self, settings: OetcSettings) -> None: @@ -470,7 +471,7 @@ def _download_file_from_gcp(self, file_name: str) -> str: except Exception as e: raise Exception(f"Failed to download file from GCP: {e}") - def solve_on_oetc(self, model): # type: ignore + def solve_on_oetc(self, model): # type: ignore """ Solve a linopy model on the OET Cloud compute app. diff --git a/test/test_oetc.py b/test/test_oetc.py index fe670c8e..9d5410ae 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -2,7 +2,6 @@ import json from datetime import datetime from unittest.mock import Mock, patch -from typing import Any import pytest import requests @@ -80,11 +79,19 @@ def mock_jwt_response(self) -> dict: "expires_in": 3600, } - @patch('linopy.oetc.requests.post') - @patch('linopy.oetc.requests.get') - @patch('linopy.oetc.datetime') - def test_successful_authentication(self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock, mock_settings: OetcSettings, mock_jwt_response: dict, - mock_gcp_credentials_response: dict, sample_jwt_token: str) -> None: + @patch("linopy.oetc.requests.post") + @patch("linopy.oetc.requests.get") + @patch("linopy.oetc.datetime") + def test_successful_authentication( + self, + mock_datetime: Mock, + mock_get: Mock, + mock_post: Mock, + mock_settings: OetcSettings, + mock_jwt_response: dict, + mock_gcp_credentials_response: dict, + sample_jwt_token: str, + ) -> None: """Test successful authentication flow""" # Setup mocks fixed_time = datetime(2024, 1, 15, 12, 0, 0) @@ -143,9 +150,10 @@ def test_successful_authentication(self, mock_datetime: Mock, mock_get: Mock, mo handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" ) - - @patch('linopy.oetc.requests.post') - def test_authentication_http_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: + @patch("linopy.oetc.requests.post") + def test_authentication_http_error( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: """Test authentication failure with HTTP error""" # Setup mock to raise HTTP error mock_response = Mock() @@ -192,7 +200,9 @@ def handler_with_mocked_auth(self) -> OetcHandler: return handler - def test_decode_jwt_payload_success(self, handler_with_mocked_auth: OetcHandler, sample_jwt_token: str) -> None: + def test_decode_jwt_payload_success( + self, handler_with_mocked_auth: OetcHandler, sample_jwt_token: str + ) -> None: """Test successful JWT payload decoding""" result = handler_with_mocked_auth._decode_jwt_payload(sample_jwt_token) @@ -202,14 +212,18 @@ def test_decode_jwt_payload_success(self, handler_with_mocked_auth: OetcHandler, assert result["firstname"] == "Test" assert result["lastname"] == "User" - def test_decode_jwt_payload_invalid_token(self, handler_with_mocked_auth: OetcHandler) -> None: + def test_decode_jwt_payload_invalid_token( + self, handler_with_mocked_auth: OetcHandler + ) -> None: """Test JWT payload decoding with invalid token""" with pytest.raises(Exception) as exc_info: handler_with_mocked_auth._decode_jwt_payload("invalid.token") assert "Failed to decode JWT payload" in str(exc_info.value) - def test_decode_jwt_payload_malformed_token(self, handler_with_mocked_auth: OetcHandler) -> None: + def test_decode_jwt_payload_malformed_token( + self, handler_with_mocked_auth: OetcHandler + ) -> None: """Test JWT payload decoding with malformed token""" with pytest.raises(Exception) as exc_info: handler_with_mocked_auth._decode_jwt_payload("not_a_jwt_token") @@ -246,8 +260,13 @@ def handler_with_mocked_auth(self, sample_jwt_token: str) -> OetcHandler: return handler - @patch('linopy.oetc.requests.get') - def test_get_gcp_credentials_success(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler, mock_gcp_credentials_response: dict) -> None: + @patch("linopy.oetc.requests.get") + def test_get_gcp_credentials_success( + self, + mock_get: Mock, + handler_with_mocked_auth: OetcHandler, + mock_gcp_credentials_response: dict, + ) -> None: """Test successful GCP credentials fetching""" # Setup mock mock_response = Mock() @@ -256,7 +275,7 @@ def test_get_gcp_credentials_success(self, mock_get: Mock, handler_with_mocked_a mock_get.return_value = mock_response # Execute - result = handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] + result = handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] # Verify request mock_get.assert_called_once_with( @@ -275,8 +294,10 @@ def test_get_gcp_credentials_success(self, mock_get: Mock, handler_with_mocked_a assert result.input_bucket == "test-input-bucket" assert result.solution_bucket == "test-solution-bucket" - @patch('linopy.oetc.requests.get') - def test_get_gcp_credentials_http_error(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler) -> None: + @patch("linopy.oetc.requests.get") + def test_get_gcp_credentials_http_error( + self, mock_get: Mock, handler_with_mocked_auth: OetcHandler + ) -> None: """Test GCP credentials fetching with HTTP error""" # Setup mock to raise HTTP error mock_response = Mock() @@ -285,12 +306,14 @@ def test_get_gcp_credentials_http_error(self, mock_get: Mock, handler_with_mocke # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] assert "Failed to fetch GCP credentials" in str(exc_info.value) - @patch('linopy.oetc.requests.get') - def test_get_gcp_credentials_missing_field(self, mock_get: Mock, handler_with_mocked_auth: OetcHandler) -> None: + @patch("linopy.oetc.requests.get") + def test_get_gcp_credentials_missing_field( + self, mock_get: Mock, handler_with_mocked_auth: OetcHandler + ) -> None: """Test GCP credentials fetching with missing response field""" # Setup mock with invalid response mock_response = Mock() @@ -305,25 +328,29 @@ def test_get_gcp_credentials_missing_field(self, mock_get: Mock, handler_with_mo # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] assert ( "Invalid credentials response format: missing field 'solution_bucket'" in str(exc_info.value) ) - def test_get_cloud_provider_credentials_unsupported_provider(self, handler_with_mocked_auth: OetcHandler) -> None: + def test_get_cloud_provider_credentials_unsupported_provider( + self, handler_with_mocked_auth: OetcHandler + ) -> None: """Test cloud provider credentials with unsupported provider""" # Change to unsupported provider handler_with_mocked_auth.settings.compute_provider = "AWS" # type: ignore # Not in enum, but for testing # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_cloud_provider_credentials() # type: ignore[attr-defined] + handler_with_mocked_auth._OetcHandler__get_cloud_provider_credentials() # type: ignore[attr-defined] assert "Unsupported compute provider: AWS" in str(exc_info.value) - def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_auth: OetcHandler) -> None: + def test_get_gcp_credentials_no_user_uuid_in_token( + self, handler_with_mocked_auth: OetcHandler + ) -> None: """Test GCP credentials fetching when JWT token has no user UUID""" # Create token without 'sub' field payload = {"iss": "OETC", "email": "test@example.com"} @@ -336,7 +363,7 @@ def test_get_gcp_credentials_no_user_uuid_in_token(self, handler_with_mocked_aut # Execute and verify exception with pytest.raises(Exception) as exc_info: - handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] + handler_with_mocked_auth._OetcHandler__get_gcp_credentials() # type: ignore[attr-defined] assert "User UUID not found in JWT token" in str(exc_info.value) @@ -363,8 +390,10 @@ def test_compute_provider_enum(self) -> None: assert ComputeProvider.GCP == "GCP" assert ComputeProvider.GCP.value == "GCP" - @patch('linopy.oetc.requests.post') - def test_authentication_network_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: + @patch("linopy.oetc.requests.post") + def test_authentication_network_error( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: """Test authentication failure with network error""" # Setup mock to raise network error mock_post.side_effect = RequestException("Connection timeout") @@ -375,8 +404,10 @@ def test_authentication_network_error(self, mock_post: Mock, mock_settings: Oetc assert "Authentication request failed" in str(exc_info.value) - @patch('linopy.oetc.requests.post') - def test_authentication_invalid_response_missing_token(self, mock_post: Mock, mock_settings: OetcSettings) -> None: + @patch("linopy.oetc.requests.post") + def test_authentication_invalid_response_missing_token( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: """Test authentication failure with missing token in response""" # Setup mock with invalid response mock_response = Mock() @@ -394,8 +425,10 @@ def test_authentication_invalid_response_missing_token(self, mock_post: Mock, mo assert "Invalid response format: missing field 'token'" in str(exc_info.value) - @patch('linopy.oetc.requests.post') - def test_authentication_invalid_response_missing_expires_in(self, mock_post: Mock, mock_settings: OetcSettings) -> None: + @patch("linopy.oetc.requests.post") + def test_authentication_invalid_response_missing_expires_in( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: """Test authentication failure with missing expires_in in response""" # Setup mock with invalid response mock_response = Mock() @@ -415,8 +448,10 @@ def test_authentication_invalid_response_missing_expires_in(self, mock_post: Moc exc_info.value ) - @patch('linopy.oetc.requests.post') - def test_authentication_timeout_error(self, mock_post: Mock, mock_settings: OetcSettings) -> None: + @patch("linopy.oetc.requests.post") + def test_authentication_timeout_error( + self, mock_post: Mock, mock_settings: OetcSettings + ) -> None: """Test authentication failure with timeout""" # Setup mock to raise timeout error mock_post.side_effect = requests.Timeout("Request timeout") @@ -444,24 +479,30 @@ def test_expires_at_calculation(self, auth_result: AuthenticationResult) -> None expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later assert auth_result.expires_at == expected_expiry - @patch('linopy.oetc.datetime') - def test_is_expired_false_when_not_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: + @patch("linopy.oetc.datetime") + def test_is_expired_false_when_not_expired( + self, mock_datetime: Mock, auth_result: AuthenticationResult + ) -> None: """Test is_expired returns False when token is still valid""" # Set current time to before expiration mock_datetime.now.return_value = datetime(2024, 1, 15, 12, 30, 0) assert auth_result.is_expired is False - @patch('linopy.oetc.datetime') - def test_is_expired_true_when_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: + @patch("linopy.oetc.datetime") + def test_is_expired_true_when_expired( + self, mock_datetime: Mock, auth_result: AuthenticationResult + ) -> None: """Test is_expired returns True when token has expired""" # Set current time to after expiration mock_datetime.now.return_value = datetime(2024, 1, 15, 14, 0, 0) assert auth_result.is_expired is True - @patch('linopy.oetc.datetime') - def test_is_expired_true_when_exactly_expired(self, mock_datetime: Mock, auth_result: AuthenticationResult) -> None: + @patch("linopy.oetc.datetime") + def test_is_expired_true_when_exactly_expired( + self, mock_datetime: Mock, auth_result: AuthenticationResult + ) -> None: """Test is_expired returns True when token expires exactly now""" # Set current time to exact expiration time mock_datetime.now.return_value = datetime(2024, 1, 15, 13, 0, 0) @@ -492,10 +533,16 @@ def handler_with_mocked_auth(self) -> OetcHandler: return handler - @patch('linopy.oetc.gzip.open') - @patch('linopy.oetc.os.path.exists') - @patch('builtins.open') - def test_gzip_compress_success(self, mock_open: Mock, mock_exists: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: + @patch("linopy.oetc.gzip.open") + @patch("linopy.oetc.os.path.exists") + @patch("builtins.open") + def test_gzip_compress_success( + self, + mock_open: Mock, + mock_exists: Mock, + mock_gzip_open: Mock, + handler_with_mocked_auth: OetcHandler, + ) -> None: """Test successful file compression""" # Setup source_path = "/tmp/test_file.nc" @@ -523,8 +570,10 @@ def test_gzip_compress_success(self, mock_open: Mock, mock_exists: Mock, mock_gz mock_gzip_open.assert_called_once_with(expected_output, "wb", compresslevel=9) mock_file_out.write.assert_called_once_with(b"test_data_chunk") - @patch('builtins.open') - def test_gzip_compress_file_read_error(self, mock_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: + @patch("builtins.open") + def test_gzip_compress_file_read_error( + self, mock_open: Mock, handler_with_mocked_auth: OetcHandler + ) -> None: """Test file compression with read error""" # Setup source_path = "/tmp/test_file.nc" @@ -539,7 +588,9 @@ def test_gzip_compress_file_read_error(self, mock_open: Mock, handler_with_mocke class TestGcpUpload: @pytest.fixture - def handler_with_gcp_credentials(self, mock_gcp_credentials_response: dict) -> OetcHandler: + def handler_with_gcp_credentials( + self, mock_gcp_credentials_response: dict + ) -> OetcHandler: """Create handler with GCP credentials for testing upload""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -568,12 +619,18 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response: dict) -> O return handler - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.os.path.basename') - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_upload_file_to_gcp_success(self, mock_creds_from_info: Mock, mock_storage_client: Mock, mock_basename: Mock, mock_remove: Mock, - handler_with_gcp_credentials: OetcHandler) -> None: + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.os.path.basename") + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_upload_file_to_gcp_success( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + mock_basename: Mock, + mock_remove: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: """Test successful file upload to GCP""" # Setup file_path = "/tmp/test_file.nc" @@ -627,8 +684,10 @@ def test_upload_file_to_gcp_success(self, mock_creds_from_info: Mock, mock_stora # Verify cleanup mock_remove.assert_called_once_with(compressed_path) - @patch('linopy.oetc.json.loads') - def test_upload_file_to_gcp_invalid_service_key(self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler) -> None: + @patch("linopy.oetc.json.loads") + def test_upload_file_to_gcp_invalid_service_key( + self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler + ) -> None: """Test upload failure with invalid service key""" # Setup file_path = "/tmp/test_file.nc" @@ -640,10 +699,14 @@ def test_upload_file_to_gcp_invalid_service_key(self, mock_json_loads: Mock, han assert "Failed to upload file to GCP" in str(exc_info.value) - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_upload_file_to_gcp_upload_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, - handler_with_gcp_credentials: OetcHandler) -> None: + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_upload_file_to_gcp_upload_error( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: """Test upload failure during blob upload""" # Setup file_path = "/tmp/test_file.nc" @@ -697,9 +760,14 @@ def handler_with_mocked_auth(self) -> OetcHandler: return handler - @patch('linopy.oetc.gzip.open') - @patch('builtins.open') - def test_gzip_decompress_success(self, mock_open_file: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: + @patch("linopy.oetc.gzip.open") + @patch("builtins.open") + def test_gzip_decompress_success( + self, + mock_open_file: Mock, + mock_gzip_open: Mock, + handler_with_mocked_auth: OetcHandler, + ) -> None: """Test successful file decompression""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -726,8 +794,10 @@ def test_gzip_decompress_success(self, mock_open_file: Mock, mock_gzip_open: Moc mock_open_file.assert_called_once_with(expected_output, "wb") mock_file_out.write.assert_called_once_with(b"decompressed_data_chunk") - @patch('linopy.oetc.gzip.open') - def test_gzip_decompress_gzip_open_error(self, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: + @patch("linopy.oetc.gzip.open") + def test_gzip_decompress_gzip_open_error( + self, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler + ) -> None: """Test file decompression with gzip open error""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -739,9 +809,14 @@ def test_gzip_decompress_gzip_open_error(self, mock_gzip_open: Mock, handler_wit assert "Failed to decompress file" in str(exc_info.value) - @patch('linopy.oetc.gzip.open') - @patch('builtins.open') - def test_gzip_decompress_write_error(self, mock_open_file: Mock, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler) -> None: + @patch("linopy.oetc.gzip.open") + @patch("builtins.open") + def test_gzip_decompress_write_error( + self, + mock_open_file: Mock, + mock_gzip_open: Mock, + handler_with_mocked_auth: OetcHandler, + ) -> None: """Test file decompression with write error""" # Setup input_path = "/tmp/test_file.nc.gz" @@ -760,7 +835,9 @@ def test_gzip_decompress_write_error(self, mock_open_file: Mock, mock_gzip_open: assert "Failed to decompress file" in str(exc_info.value) - def test_gzip_decompress_output_path_generation(self, handler_with_mocked_auth: OetcHandler) -> None: + def test_gzip_decompress_output_path_generation( + self, handler_with_mocked_auth: OetcHandler + ) -> None: """Test correct output path generation for decompression""" # Test first path with patch("linopy.oetc.gzip.open") as mock_gzip_open: @@ -791,7 +868,9 @@ def test_gzip_decompress_output_path_generation(self, handler_with_mocked_auth: class TestGcpDownload: @pytest.fixture - def handler_with_gcp_credentials(self, mock_gcp_credentials_response: dict) -> OetcHandler: + def handler_with_gcp_credentials( + self, mock_gcp_credentials_response: dict + ) -> OetcHandler: """Create handler with GCP credentials for testing download""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -820,12 +899,18 @@ def handler_with_gcp_credentials(self, mock_gcp_credentials_response: dict) -> O return handler - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_download_file_from_gcp_success(self, mock_creds_from_info: Mock, mock_storage_client: Mock, - mock_tempfile: Mock, mock_remove: Mock, handler_with_gcp_credentials: OetcHandler) -> None: + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_success( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + mock_tempfile: Mock, + mock_remove: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: """Test successful file download from GCP""" # Setup file_name = "solution_file.nc.gz" @@ -883,8 +968,10 @@ def test_download_file_from_gcp_success(self, mock_creds_from_info: Mock, mock_s # Verify cleanup mock_remove.assert_called_once_with(compressed_path) - @patch('linopy.oetc.json.loads') - def test_download_file_from_gcp_invalid_service_key(self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler) -> None: + @patch("linopy.oetc.json.loads") + def test_download_file_from_gcp_invalid_service_key( + self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler + ) -> None: """Test download failure with invalid service key""" # Setup file_name = "solution_file.nc.gz" @@ -896,11 +983,16 @@ def test_download_file_from_gcp_invalid_service_key(self, mock_json_loads: Mock, assert "Failed to download file from GCP" in str(exc_info.value) - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_download_file_from_gcp_download_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, - mock_tempfile: Mock, handler_with_gcp_credentials: OetcHandler) -> None: + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_download_error( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + mock_tempfile: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: """Test download failure during blob download""" # Setup file_name = "solution_file.nc.gz" @@ -931,12 +1023,18 @@ def test_download_file_from_gcp_download_error(self, mock_creds_from_info: Mock, assert "Failed to download file from GCP" in str(exc_info.value) - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - @patch('linopy.oetc.storage.Client') - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_download_file_from_gcp_decompression_error(self, mock_creds_from_info: Mock, mock_storage_client: Mock, - mock_tempfile: Mock, mock_remove: Mock, handler_with_gcp_credentials: OetcHandler) -> None: + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.oetc.storage.Client") + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_decompression_error( + self, + mock_creds_from_info: Mock, + mock_storage_client: Mock, + mock_tempfile: Mock, + mock_remove: Mock, + handler_with_gcp_credentials: OetcHandler, + ) -> None: """Test download failure during decompression""" # Setup file_name = "solution_file.nc.gz" @@ -972,8 +1070,10 @@ def test_download_file_from_gcp_decompression_error(self, mock_creds_from_info: assert "Failed to download file from GCP" in str(exc_info.value) - @patch('linopy.oetc.service_account.Credentials.from_service_account_info') - def test_download_file_from_gcp_credentials_error(self, mock_creds_from_info: Mock, handler_with_gcp_credentials: OetcHandler) -> None: + @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + def test_download_file_from_gcp_credentials_error( + self, mock_creds_from_info: Mock, handler_with_gcp_credentials: OetcHandler + ) -> None: """Test download failure during credentials creation""" # Setup file_name = "solution_file.nc.gz" @@ -1019,8 +1119,10 @@ def handler_with_auth_setup(self, sample_jwt_token: str) -> OetcHandler: return handler - @patch('linopy.oetc.requests.post') - def test_submit_job_success(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: + @patch("linopy.oetc.requests.post") + def test_submit_job_success( + self, mock_post: Mock, handler_with_auth_setup: OetcHandler + ) -> None: """Test successful job submission to compute service""" # Setup input_file_name = "test_model.nc.gz" @@ -1060,8 +1162,10 @@ def test_submit_job_success(self, mock_post: Mock, handler_with_auth_setup: Oetc # Verify result assert result == expected_job_uuid - @patch('linopy.oetc.requests.post') - def test_submit_job_http_error(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: + @patch("linopy.oetc.requests.post") + def test_submit_job_http_error( + self, mock_post: Mock, handler_with_auth_setup: OetcHandler + ) -> None: """Test job submission with HTTP error""" # Setup input_file_name = "test_model.nc.gz" @@ -1077,8 +1181,10 @@ def test_submit_job_http_error(self, mock_post: Mock, handler_with_auth_setup: O assert "Failed to submit job to compute service" in str(exc_info.value) - @patch('linopy.oetc.requests.post') - def test_submit_job_missing_uuid_in_response(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: + @patch("linopy.oetc.requests.post") + def test_submit_job_missing_uuid_in_response( + self, mock_post: Mock, handler_with_auth_setup: OetcHandler + ) -> None: """Test job submission with missing UUID in response""" # Setup input_file_name = "test_model.nc.gz" @@ -1095,8 +1201,10 @@ def test_submit_job_missing_uuid_in_response(self, mock_post: Mock, handler_with exc_info.value ) - @patch('linopy.oetc.requests.post') - def test_submit_job_network_error(self, mock_post: Mock, handler_with_auth_setup: OetcHandler) -> None: + @patch("linopy.oetc.requests.post") + def test_submit_job_network_error( + self, mock_post: Mock, handler_with_auth_setup: OetcHandler + ) -> None: """Test job submission with network error""" # Setup input_file_name = "test_model.nc.gz" @@ -1111,7 +1219,9 @@ def test_submit_job_network_error(self, mock_post: Mock, handler_with_auth_setup class TestSolveOnOetc: @pytest.fixture - def handler_with_complete_setup(self, mock_gcp_credentials_response: dict) -> OetcHandler: + def handler_with_complete_setup( + self, mock_gcp_credentials_response: dict + ) -> OetcHandler: """Create handler with complete setup for testing solve functionality""" with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): credentials = OetcCredentials( @@ -1139,10 +1249,16 @@ def handler_with_complete_setup(self, mock_gcp_credentials_response: dict) -> Oe return handler - @patch('linopy.oetc.linopy.read_netcdf') - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_file_upload(self, mock_tempfile: Mock, mock_remove: Mock, mock_read_netcdf: Mock, handler_with_complete_setup: OetcHandler) -> None: + @patch("linopy.oetc.linopy.read_netcdf") + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_file_upload( + self, + mock_tempfile: Mock, + mock_remove: Mock, + mock_read_netcdf: Mock, + handler_with_complete_setup: OetcHandler, + ) -> None: """Test solve_on_oetc method complete workflow""" # Setup mock_model = Mock() @@ -1200,8 +1316,10 @@ def test_solve_on_oetc_file_upload(self, mock_tempfile: Mock, mock_remove: Mock, ) mock_remove.assert_called_once_with("/tmp/downloaded_result.nc") - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_upload_failure(self, mock_tempfile: Mock, handler_with_complete_setup: OetcHandler) -> None: + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_upload_failure( + self, mock_tempfile: Mock, handler_with_complete_setup: OetcHandler + ) -> None: """Test solve_on_oetc method with upload failure""" # Setup mock_model = Mock() @@ -1254,11 +1372,16 @@ def handler_with_full_setup(self) -> OetcHandler: return handler - @patch('linopy.oetc.linopy.read_netcdf') - @patch('linopy.oetc.os.remove') - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_with_job_submission(self, mock_tempfile: Mock, mock_remove: Mock, mock_read_netcdf: Mock, - handler_with_full_setup: OetcHandler) -> None: + @patch("linopy.oetc.linopy.read_netcdf") + @patch("linopy.oetc.os.remove") + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_with_job_submission( + self, + mock_tempfile: Mock, + mock_remove: Mock, + mock_read_netcdf: Mock, + handler_with_full_setup: OetcHandler, + ) -> None: """Test solve_on_oetc method including job submission, waiting, and download""" # Setup mock_model = Mock() @@ -1319,8 +1442,10 @@ def test_solve_on_oetc_with_job_submission(self, mock_tempfile: Mock, mock_remov ) mock_remove.assert_called_once_with("/tmp/solution_file.nc") - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_job_submission_failure(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_job_submission_failure( + self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler + ) -> None: """Test solve_on_oetc method with job submission failure""" # Setup mock_model = Mock() @@ -1347,8 +1472,10 @@ def test_solve_on_oetc_job_submission_failure(self, mock_tempfile: Mock, handler assert "Job submission failed" in str(exc_info.value) - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_job_waiting_failure( + self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler + ) -> None: """Test solve_on_oetc method with job waiting failure""" # Setup mock_model = Mock() @@ -1381,8 +1508,10 @@ def test_solve_on_oetc_job_waiting_failure(self, mock_tempfile: Mock, handler_wi assert "Job failed: solver error" in str(exc_info.value) - @patch('linopy.oetc.tempfile.NamedTemporaryFile') - def test_solve_on_oetc_no_output_files_error(self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler) -> None: + @patch("linopy.oetc.tempfile.NamedTemporaryFile") + def test_solve_on_oetc_no_output_files_error( + self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler + ) -> None: """Test solve_on_oetc method when job completes but has no output files""" # Setup mock_model = Mock() @@ -1422,10 +1551,12 @@ def test_solve_on_oetc_no_output_files_error(self, mock_tempfile: Mock, handler_ # Additional integration-style test class TestOetcHandlerIntegration: - @patch('linopy.oetc.requests.post') - @patch('linopy.oetc.requests.get') - @patch('linopy.oetc.datetime') - def test_complete_authentication_flow(self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock) -> None: + @patch("linopy.oetc.requests.post") + @patch("linopy.oetc.requests.get") + @patch("linopy.oetc.datetime") + def test_complete_authentication_flow( + self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock + ) -> None: """Test complete authentication and credentials flow with realistic data""" # Setup fixed_time = datetime(2024, 1, 15, 12, 0, 0) From 8032e6a4225f481dee59edace6536f53a14598c8 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Mon, 4 Aug 2025 16:39:17 +0200 Subject: [PATCH 15/28] Add OETC integration entry in the release notes --- doc/release_notes.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 61580902..ba0f429c 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -8,6 +8,7 @@ Release Notes * Improved variable/expression arithmetic methods so that they correctly handle types * Gurobi: Pass dictionary as env argument `env={...}` through to gurobi env creation +* Added integration with OETC platform **Breaking Changes** From c96e9f749278f9f460f058e45a815e062d2f503a Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Mon, 4 Aug 2025 16:42:54 +0200 Subject: [PATCH 16/28] Ignore false positive type warning --- linopy/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/model.py b/linopy/model.py index 1a8a35fa..04554e13 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1157,7 +1157,7 @@ def solve( **solver_options, ) else: - solved = OetcHandler(oetc_settings).solve_on_oetc(self) + solved = OetcHandler(oetc_settings).solve_on_oetc(self) # type: ignore self.objective.set_value(solved.objective.value) self.status = solved.status From b76e03409fbc97ce6d1249ccde5d6c4878ae9708 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Aug 2025 14:43:05 +0000 Subject: [PATCH 17/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/model.py b/linopy/model.py index 04554e13..d606ff84 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1157,7 +1157,7 @@ def solve( **solver_options, ) else: - solved = OetcHandler(oetc_settings).solve_on_oetc(self) # type: ignore + solved = OetcHandler(oetc_settings).solve_on_oetc(self) # type: ignore self.objective.set_value(solved.objective.value) self.status = solved.status From bf8dc91debabdc23beb7474861022360209a4d97 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Tue, 26 Aug 2025 14:38:51 +0200 Subject: [PATCH 18/28] Extend logging for oetc processes --- linopy/oetc.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/linopy/oetc.py b/linopy/oetc.py index 11353bbd..24d94325 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -406,6 +406,7 @@ def _gzip_decompress(self, input_path: str) -> str: Exception: If decompression fails """ try: + logger.info(f"OETC - Decompressing file: {input_path}") output_path = input_path[:-3] chunk_size = 1024 * 1024 @@ -417,6 +418,7 @@ def _gzip_decompress(self, input_path: str) -> str: break f_out.write(chunk) + logger.info(f"OETC - File decompressed successfully: {output_path}") return output_path except Exception as e: raise Exception(f"Failed to decompress file: {e}") @@ -435,6 +437,8 @@ def _download_file_from_gcp(self, file_name: str) -> str: Exception: If download or decompression fails """ try: + logger.info(f"OETC - Downloading file from GCP: {file_name}") + # Create GCP credentials from service key service_key_dict = json.loads( self.cloud_provider_credentials.gcp_service_key @@ -459,6 +463,7 @@ def _download_file_from_gcp(self, file_name: str) -> str: compressed_file_path = temp_file.name blob.download_to_filename(compressed_file_path) + logger.info(f"OETC - File downloaded from GCP successfully: {file_name}") # Decompress the downloaded file decompressed_file_path = self._gzip_decompress(compressed_file_path) @@ -543,6 +548,7 @@ def _gzip_compress(self, source_path: str) -> str: Exception: If compression fails """ try: + logger.info(f"OETC - Compressing file: {source_path}") output_path = source_path + ".gz" chunk_size = 1024 * 1024 @@ -554,6 +560,7 @@ def _gzip_compress(self, source_path: str) -> str: break f_out.write(chunk) + logger.info(f"OETC - File compressed successfully: {output_path}") return output_path except Exception as e: raise Exception(f"Failed to compress file: {e}") @@ -575,6 +582,8 @@ def _upload_file_to_gcp(self, file_path: str) -> str: compressed_file_path = self._gzip_compress(file_path) compressed_file_name = os.path.basename(compressed_file_path) + logger.info(f"OETC - Uploading file to GCP: {compressed_file_name}") + # Create GCP credentials from service key service_key_dict = json.loads( self.cloud_provider_credentials.gcp_service_key @@ -594,6 +603,8 @@ def _upload_file_to_gcp(self, file_path: str) -> str: blob.upload_from_filename(compressed_file_path) + logger.info(f"OETC - File uploaded to GCP successfully: {compressed_file_name}") + # Clean up compressed file os.remove(compressed_file_path) From c23e794c66d5d594791a73d1568a3cfbc3166024 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Aug 2025 12:39:02 +0000 Subject: [PATCH 19/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/oetc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/linopy/oetc.py b/linopy/oetc.py index 24d94325..7f1c7b22 100644 --- a/linopy/oetc.py +++ b/linopy/oetc.py @@ -603,7 +603,9 @@ def _upload_file_to_gcp(self, file_path: str) -> str: blob.upload_from_filename(compressed_file_path) - logger.info(f"OETC - File uploaded to GCP successfully: {compressed_file_name}") + logger.info( + f"OETC - File uploaded to GCP successfully: {compressed_file_name}" + ) # Clean up compressed file os.remove(compressed_file_path) From c67f2e610636d7473c3964ea62aa7f89a7b1c229 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Wed, 3 Sep 2025 15:28:18 +0200 Subject: [PATCH 20/28] Use remote argument instead of oetc_settings --- linopy/model.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index d606ff84..364fd52d 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -58,7 +58,7 @@ ) from linopy.matrices import MatrixAccessor from linopy.objective import Objective -from linopy.oetc import OetcHandler, OetcSettings +from linopy.oetc import OetcHandler from linopy.solvers import IO_APIS, available_solvers, quadratic_solvers from linopy.types import ( ConstantLike, @@ -1048,9 +1048,8 @@ def solve( sanitize_zeros: bool = True, sanitize_infinities: bool = True, slice_size: int = 2_000_000, - remote: Any = None, + remote: "RemoteHandler" | OetcHandler = None, progress: bool | None = None, - oetc_settings: OetcSettings | None = None, **solver_options: Any, ) -> tuple[str, str]: """ @@ -1139,10 +1138,10 @@ def solve( f"Keyword argument `io_api` has to be one of {IO_APIS} or None" ) - if remote or oetc_settings: - if remote and oetc_settings: - raise ValueError("Remote and OETC can't be active at the same time") - if remote: + if remote is not None: + if isinstance(remote, OetcHandler): + solved = remote.solve_on_oetc(self) # type: ignore + else: solved = remote.solve_on_remote( self, solver_name=solver_name, @@ -1156,8 +1155,6 @@ def solve( sanitize_zeros=sanitize_zeros, **solver_options, ) - else: - solved = OetcHandler(oetc_settings).solve_on_oetc(self) # type: ignore self.objective.set_value(solved.objective.value) self.status = solved.status From 2addc89a2fae7e0c7893182fe000bb4d8bfa8758 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Sep 2025 13:32:15 +0000 Subject: [PATCH 21/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/model.py b/linopy/model.py index 364fd52d..e9e0fd7b 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1048,7 +1048,7 @@ def solve( sanitize_zeros: bool = True, sanitize_infinities: bool = True, slice_size: int = 2_000_000, - remote: "RemoteHandler" | OetcHandler = None, + remote: RemoteHandler | OetcHandler = None, progress: bool | None = None, **solver_options: Any, ) -> tuple[str, str]: From 5429c8dff8766181d68b10103d8803cf55218578 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Wed, 3 Sep 2025 15:33:33 +0200 Subject: [PATCH 22/28] Fix mypy errors --- linopy/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index e9e0fd7b..f8846a13 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1048,7 +1048,7 @@ def solve( sanitize_zeros: bool = True, sanitize_infinities: bool = True, slice_size: int = 2_000_000, - remote: RemoteHandler | OetcHandler = None, + remote: "RemoteHandler" | OetcHandler = None, # type: ignore progress: bool | None = None, **solver_options: Any, ) -> tuple[str, str]: @@ -1140,7 +1140,7 @@ def solve( if remote is not None: if isinstance(remote, OetcHandler): - solved = remote.solve_on_oetc(self) # type: ignore + solved = remote.solve_on_oetc(self) else: solved = remote.solve_on_remote( self, From ee3629f38fa9d9dda0e595358b114fd0ea690d11 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Sep 2025 13:35:56 +0000 Subject: [PATCH 23/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/model.py b/linopy/model.py index f8846a13..192ae3ad 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1048,7 +1048,7 @@ def solve( sanitize_zeros: bool = True, sanitize_infinities: bool = True, slice_size: int = 2_000_000, - remote: "RemoteHandler" | OetcHandler = None, # type: ignore + remote: RemoteHandler | OetcHandler = None, # type: ignore progress: bool | None = None, **solver_options: Any, ) -> tuple[str, str]: From 3285108546cf3928e875e687ac1e395ca4caa0c2 Mon Sep 17 00:00:00 2001 From: Kristijan Faust Date: Wed, 3 Sep 2025 16:12:09 +0200 Subject: [PATCH 24/28] Update solve method docs --- linopy/model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index 192ae3ad..f1bc5651 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1109,7 +1109,7 @@ def solve( Size of the slice to use for writing the lp file. The slice size is used to split large variables and constraints into smaller chunks to avoid memory issues. The default is 2_000_000. - remote : linopy.remote.RemoteHandler + remote : linopy.remote.RemoteHandler | linopy.oetc.OetcHandler, optional Remote handler to use for solving model on a server. Note that when solving on a rSee linopy.remote.RemoteHandler for more details. @@ -1117,9 +1117,6 @@ def solve( Whether to show a progress bar of writing the lp file. The default is None, which means that the progress bar is shown if the model has more than 10000 variables and constraints. - oetc_settings : dict, optional - Settings for the solving on the OETC platform. If a value is provided - solving will be attempted on OETC, otherwise it will be done locally. **solver_options : kwargs Options passed to the solver. From 302246eda9fc30878dfb23891224d7b4fb63abd2 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 4 Sep 2025 11:34:38 +0200 Subject: [PATCH 25/28] fix: missing import of RemoteHandler --- linopy/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/linopy/model.py b/linopy/model.py index f1bc5651..0e201ffc 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -59,6 +59,7 @@ from linopy.matrices import MatrixAccessor from linopy.objective import Objective from linopy.oetc import OetcHandler +from linopy.remote import RemoteHandler from linopy.solvers import IO_APIS, available_solvers, quadratic_solvers from linopy.types import ( ConstantLike, From a0b75d5a680b9e78d5751766468b88455ec1ba8a Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 4 Sep 2025 12:32:18 +0200 Subject: [PATCH 26/28] refac: use remote dir and move stuff there --- .gitignore | 2 + linopy/__init__.py | 3 +- linopy/model.py | 3 +- linopy/remote/__init__.py | 19 ++++ linopy/{ => remote}/oetc.py | 0 linopy/{remote.py => remote/ssh.py} | 0 test/test_oetc.py | 148 ++++++++++++++++------------ 7 files changed, 107 insertions(+), 68 deletions(-) create mode 100644 linopy/remote/__init__.py rename linopy/{ => remote}/oetc.py (100%) rename linopy/{remote.py => remote/ssh.py} (100%) diff --git a/.gitignore b/.gitignore index db64b200..5c6986ab 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,8 @@ paper/ monkeytype.sqlite3 .github/copilot-instructions.md uv.lock +*.pyc + # Environments .env diff --git a/linopy/__init__.py b/linopy/__init__.py index 88ee5251..3efc297a 100644 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -20,7 +20,7 @@ from linopy.io import read_netcdf from linopy.model import Model, Variable, Variables, available_solvers from linopy.objective import Objective -from linopy.remote import RemoteHandler +from linopy.remote import OetcHandler, RemoteHandler __all__ = ( "Constraint", @@ -31,6 +31,7 @@ "LinearExpression", "Model", "Objective", + "OetcHandler", "QuadraticExpression", "RemoteHandler", "Variable", diff --git a/linopy/model.py b/linopy/model.py index 0e201ffc..7be8d57d 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -58,8 +58,7 @@ ) from linopy.matrices import MatrixAccessor from linopy.objective import Objective -from linopy.oetc import OetcHandler -from linopy.remote import RemoteHandler +from linopy.remote import OetcHandler, RemoteHandler from linopy.solvers import IO_APIS, available_solvers, quadratic_solvers from linopy.types import ( ConstantLike, diff --git a/linopy/remote/__init__.py b/linopy/remote/__init__.py new file mode 100644 index 00000000..0ae1df26 --- /dev/null +++ b/linopy/remote/__init__.py @@ -0,0 +1,19 @@ +""" +Remote execution handlers for linopy models. + +This module provides different handlers for executing optimization models +on remote systems: + +- RemoteHandler: SSH-based remote execution using paramiko +- OetcHandler: Cloud-based execution via OET Cloud service +""" + +from linopy.remote.oetc import OetcCredentials, OetcHandler, OetcSettings +from linopy.remote.ssh import RemoteHandler + +__all__ = [ + "RemoteHandler", + "OetcHandler", + "OetcSettings", + "OetcCredentials", +] diff --git a/linopy/oetc.py b/linopy/remote/oetc.py similarity index 100% rename from linopy/oetc.py rename to linopy/remote/oetc.py diff --git a/linopy/remote.py b/linopy/remote/ssh.py similarity index 100% rename from linopy/remote.py rename to linopy/remote/ssh.py diff --git a/test/test_oetc.py b/test/test_oetc.py index 9d5410ae..2f4372a4 100644 --- a/test/test_oetc.py +++ b/test/test_oetc.py @@ -7,7 +7,7 @@ import requests from requests import RequestException -from linopy.oetc import ( +from linopy.remote.oetc import ( AuthenticationResult, ComputeProvider, GcpCredentials, @@ -79,9 +79,9 @@ def mock_jwt_response(self) -> dict: "expires_in": 3600, } - @patch("linopy.oetc.requests.post") - @patch("linopy.oetc.requests.get") - @patch("linopy.oetc.datetime") + @patch("linopy.remote.oetc.requests.post") + @patch("linopy.remote.oetc.requests.get") + @patch("linopy.remote.oetc.datetime") def test_successful_authentication( self, mock_datetime: Mock, @@ -150,7 +150,7 @@ def test_successful_authentication( handler.cloud_provider_credentials.solution_bucket == "test-solution-bucket" ) - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_authentication_http_error( self, mock_post: Mock, mock_settings: OetcSettings ) -> None: @@ -173,7 +173,10 @@ class TestJwtDecoding: @pytest.fixture def handler_with_mocked_auth(self) -> OetcHandler: """Create handler with mocked authentication for testing JWT decoding""" - with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): credentials = OetcCredentials( email="test@example.com", password="test_password" ) @@ -260,7 +263,7 @@ def handler_with_mocked_auth(self, sample_jwt_token: str) -> OetcHandler: return handler - @patch("linopy.oetc.requests.get") + @patch("linopy.remote.oetc.requests.get") def test_get_gcp_credentials_success( self, mock_get: Mock, @@ -294,7 +297,7 @@ def test_get_gcp_credentials_success( assert result.input_bucket == "test-input-bucket" assert result.solution_bucket == "test-solution-bucket" - @patch("linopy.oetc.requests.get") + @patch("linopy.remote.oetc.requests.get") def test_get_gcp_credentials_http_error( self, mock_get: Mock, handler_with_mocked_auth: OetcHandler ) -> None: @@ -310,7 +313,7 @@ def test_get_gcp_credentials_http_error( assert "Failed to fetch GCP credentials" in str(exc_info.value) - @patch("linopy.oetc.requests.get") + @patch("linopy.remote.oetc.requests.get") def test_get_gcp_credentials_missing_field( self, mock_get: Mock, handler_with_mocked_auth: OetcHandler ) -> None: @@ -390,7 +393,7 @@ def test_compute_provider_enum(self) -> None: assert ComputeProvider.GCP == "GCP" assert ComputeProvider.GCP.value == "GCP" - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_authentication_network_error( self, mock_post: Mock, mock_settings: OetcSettings ) -> None: @@ -404,7 +407,7 @@ def test_authentication_network_error( assert "Authentication request failed" in str(exc_info.value) - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_authentication_invalid_response_missing_token( self, mock_post: Mock, mock_settings: OetcSettings ) -> None: @@ -425,7 +428,7 @@ def test_authentication_invalid_response_missing_token( assert "Invalid response format: missing field 'token'" in str(exc_info.value) - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_authentication_invalid_response_missing_expires_in( self, mock_post: Mock, mock_settings: OetcSettings ) -> None: @@ -448,7 +451,7 @@ def test_authentication_invalid_response_missing_expires_in( exc_info.value ) - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_authentication_timeout_error( self, mock_post: Mock, mock_settings: OetcSettings ) -> None: @@ -479,7 +482,7 @@ def test_expires_at_calculation(self, auth_result: AuthenticationResult) -> None expected_expiry = datetime(2024, 1, 15, 13, 0, 0) # 1 hour later assert auth_result.expires_at == expected_expiry - @patch("linopy.oetc.datetime") + @patch("linopy.remote.oetc.datetime") def test_is_expired_false_when_not_expired( self, mock_datetime: Mock, auth_result: AuthenticationResult ) -> None: @@ -489,7 +492,7 @@ def test_is_expired_false_when_not_expired( assert auth_result.is_expired is False - @patch("linopy.oetc.datetime") + @patch("linopy.remote.oetc.datetime") def test_is_expired_true_when_expired( self, mock_datetime: Mock, auth_result: AuthenticationResult ) -> None: @@ -499,7 +502,7 @@ def test_is_expired_true_when_expired( assert auth_result.is_expired is True - @patch("linopy.oetc.datetime") + @patch("linopy.remote.oetc.datetime") def test_is_expired_true_when_exactly_expired( self, mock_datetime: Mock, auth_result: AuthenticationResult ) -> None: @@ -514,7 +517,10 @@ class TestFileCompression: @pytest.fixture def handler_with_mocked_auth(self) -> OetcHandler: """Create handler with mocked authentication for testing file operations""" - with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): credentials = OetcCredentials( email="test@example.com", password="test_password" ) @@ -533,8 +539,8 @@ def handler_with_mocked_auth(self) -> OetcHandler: return handler - @patch("linopy.oetc.gzip.open") - @patch("linopy.oetc.os.path.exists") + @patch("linopy.remote.oetc.gzip.open") + @patch("linopy.remote.oetc.os.path.exists") @patch("builtins.open") def test_gzip_compress_success( self, @@ -592,7 +598,10 @@ def handler_with_gcp_credentials( self, mock_gcp_credentials_response: dict ) -> OetcHandler: """Create handler with GCP credentials for testing upload""" - with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): credentials = OetcCredentials( email="test@example.com", password="test_password" ) @@ -619,10 +628,10 @@ def handler_with_gcp_credentials( return handler - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.os.path.basename") - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.os.path.basename") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") def test_upload_file_to_gcp_success( self, mock_creds_from_info: Mock, @@ -684,7 +693,7 @@ def test_upload_file_to_gcp_success( # Verify cleanup mock_remove.assert_called_once_with(compressed_path) - @patch("linopy.oetc.json.loads") + @patch("linopy.remote.oetc.json.loads") def test_upload_file_to_gcp_invalid_service_key( self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler ) -> None: @@ -699,8 +708,8 @@ def test_upload_file_to_gcp_invalid_service_key( assert "Failed to upload file to GCP" in str(exc_info.value) - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") def test_upload_file_to_gcp_upload_error( self, mock_creds_from_info: Mock, @@ -741,7 +750,10 @@ class TestFileDecompression: @pytest.fixture def handler_with_mocked_auth(self) -> OetcHandler: """Create handler with mocked authentication for testing file operations""" - with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): credentials = OetcCredentials( email="test@example.com", password="test_password" ) @@ -760,7 +772,7 @@ def handler_with_mocked_auth(self) -> OetcHandler: return handler - @patch("linopy.oetc.gzip.open") + @patch("linopy.remote.oetc.gzip.open") @patch("builtins.open") def test_gzip_decompress_success( self, @@ -794,7 +806,7 @@ def test_gzip_decompress_success( mock_open_file.assert_called_once_with(expected_output, "wb") mock_file_out.write.assert_called_once_with(b"decompressed_data_chunk") - @patch("linopy.oetc.gzip.open") + @patch("linopy.remote.oetc.gzip.open") def test_gzip_decompress_gzip_open_error( self, mock_gzip_open: Mock, handler_with_mocked_auth: OetcHandler ) -> None: @@ -809,7 +821,7 @@ def test_gzip_decompress_gzip_open_error( assert "Failed to decompress file" in str(exc_info.value) - @patch("linopy.oetc.gzip.open") + @patch("linopy.remote.oetc.gzip.open") @patch("builtins.open") def test_gzip_decompress_write_error( self, @@ -840,7 +852,7 @@ def test_gzip_decompress_output_path_generation( ) -> None: """Test correct output path generation for decompression""" # Test first path - with patch("linopy.oetc.gzip.open") as mock_gzip_open: + with patch("linopy.remote.oetc.gzip.open") as mock_gzip_open: with patch("builtins.open") as mock_open_file: mock_file_in = Mock() mock_file_out = Mock() @@ -852,7 +864,7 @@ def test_gzip_decompress_output_path_generation( assert result == "/tmp/file.nc" # Test second path with fresh mocks - with patch("linopy.oetc.gzip.open") as mock_gzip_open: + with patch("linopy.remote.oetc.gzip.open") as mock_gzip_open: with patch("builtins.open") as mock_open_file: mock_file_in = Mock() mock_file_out = Mock() @@ -872,7 +884,10 @@ def handler_with_gcp_credentials( self, mock_gcp_credentials_response: dict ) -> OetcHandler: """Create handler with GCP credentials for testing download""" - with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): credentials = OetcCredentials( email="test@example.com", password="test_password" ) @@ -899,10 +914,10 @@ def handler_with_gcp_credentials( return handler - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") def test_download_file_from_gcp_success( self, mock_creds_from_info: Mock, @@ -968,7 +983,7 @@ def test_download_file_from_gcp_success( # Verify cleanup mock_remove.assert_called_once_with(compressed_path) - @patch("linopy.oetc.json.loads") + @patch("linopy.remote.oetc.json.loads") def test_download_file_from_gcp_invalid_service_key( self, mock_json_loads: Mock, handler_with_gcp_credentials: OetcHandler ) -> None: @@ -983,9 +998,9 @@ def test_download_file_from_gcp_invalid_service_key( assert "Failed to download file from GCP" in str(exc_info.value) - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") def test_download_file_from_gcp_download_error( self, mock_creds_from_info: Mock, @@ -1023,10 +1038,10 @@ def test_download_file_from_gcp_download_error( assert "Failed to download file from GCP" in str(exc_info.value) - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") - @patch("linopy.oetc.storage.Client") - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.storage.Client") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") def test_download_file_from_gcp_decompression_error( self, mock_creds_from_info: Mock, @@ -1070,7 +1085,7 @@ def test_download_file_from_gcp_decompression_error( assert "Failed to download file from GCP" in str(exc_info.value) - @patch("linopy.oetc.service_account.Credentials.from_service_account_info") + @patch("linopy.remote.oetc.service_account.Credentials.from_service_account_info") def test_download_file_from_gcp_credentials_error( self, mock_creds_from_info: Mock, handler_with_gcp_credentials: OetcHandler ) -> None: @@ -1119,7 +1134,7 @@ def handler_with_auth_setup(self, sample_jwt_token: str) -> OetcHandler: return handler - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_submit_job_success( self, mock_post: Mock, handler_with_auth_setup: OetcHandler ) -> None: @@ -1162,7 +1177,7 @@ def test_submit_job_success( # Verify result assert result == expected_job_uuid - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_submit_job_http_error( self, mock_post: Mock, handler_with_auth_setup: OetcHandler ) -> None: @@ -1181,7 +1196,7 @@ def test_submit_job_http_error( assert "Failed to submit job to compute service" in str(exc_info.value) - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_submit_job_missing_uuid_in_response( self, mock_post: Mock, handler_with_auth_setup: OetcHandler ) -> None: @@ -1201,7 +1216,7 @@ def test_submit_job_missing_uuid_in_response( exc_info.value ) - @patch("linopy.oetc.requests.post") + @patch("linopy.remote.oetc.requests.post") def test_submit_job_network_error( self, mock_post: Mock, handler_with_auth_setup: OetcHandler ) -> None: @@ -1223,7 +1238,10 @@ def handler_with_complete_setup( self, mock_gcp_credentials_response: dict ) -> OetcHandler: """Create handler with complete setup for testing solve functionality""" - with patch("linopy.oetc.requests.post"), patch("linopy.oetc.requests.get"): + with ( + patch("linopy.remote.oetc.requests.post"), + patch("linopy.remote.oetc.requests.get"), + ): credentials = OetcCredentials( email="test@example.com", password="test_password" ) @@ -1249,9 +1267,9 @@ def handler_with_complete_setup( return handler - @patch("linopy.oetc.linopy.read_netcdf") - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.linopy.read_netcdf") + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_file_upload( self, mock_tempfile: Mock, @@ -1316,7 +1334,7 @@ def test_solve_on_oetc_file_upload( ) mock_remove.assert_called_once_with("/tmp/downloaded_result.nc") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_upload_failure( self, mock_tempfile: Mock, handler_with_complete_setup: OetcHandler ) -> None: @@ -1372,9 +1390,9 @@ def handler_with_full_setup(self) -> OetcHandler: return handler - @patch("linopy.oetc.linopy.read_netcdf") - @patch("linopy.oetc.os.remove") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.linopy.read_netcdf") + @patch("linopy.remote.oetc.os.remove") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_with_job_submission( self, mock_tempfile: Mock, @@ -1442,7 +1460,7 @@ def test_solve_on_oetc_with_job_submission( ) mock_remove.assert_called_once_with("/tmp/solution_file.nc") - @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_job_submission_failure( self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler ) -> None: @@ -1472,7 +1490,7 @@ def test_solve_on_oetc_job_submission_failure( assert "Job submission failed" in str(exc_info.value) - @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_job_waiting_failure( self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler ) -> None: @@ -1508,7 +1526,7 @@ def test_solve_on_oetc_job_waiting_failure( assert "Job failed: solver error" in str(exc_info.value) - @patch("linopy.oetc.tempfile.NamedTemporaryFile") + @patch("linopy.remote.oetc.tempfile.NamedTemporaryFile") def test_solve_on_oetc_no_output_files_error( self, mock_tempfile: Mock, handler_with_full_setup: OetcHandler ) -> None: @@ -1551,9 +1569,9 @@ def test_solve_on_oetc_no_output_files_error( # Additional integration-style test class TestOetcHandlerIntegration: - @patch("linopy.oetc.requests.post") - @patch("linopy.oetc.requests.get") - @patch("linopy.oetc.datetime") + @patch("linopy.remote.oetc.requests.post") + @patch("linopy.remote.oetc.requests.get") + @patch("linopy.remote.oetc.datetime") def test_complete_authentication_flow( self, mock_datetime: Mock, mock_get: Mock, mock_post: Mock ) -> None: From 8f5b6fbd970236889c4b2084483e8734d5fe020e Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 4 Sep 2025 13:04:03 +0200 Subject: [PATCH 27/28] feat: increase test coverage --- test/{ => remote}/test_oetc.py | 0 test/remote/test_oetc_job_polling.py | 313 +++++++++++++++++++++++++++ 2 files changed, 313 insertions(+) rename test/{ => remote}/test_oetc.py (100%) create mode 100644 test/remote/test_oetc_job_polling.py diff --git a/test/test_oetc.py b/test/remote/test_oetc.py similarity index 100% rename from test/test_oetc.py rename to test/remote/test_oetc.py diff --git a/test/remote/test_oetc_job_polling.py b/test/remote/test_oetc_job_polling.py new file mode 100644 index 00000000..933bad94 --- /dev/null +++ b/test/remote/test_oetc_job_polling.py @@ -0,0 +1,313 @@ +""" +Tests for OetcHandler job polling and status monitoring. + +This module tests the wait_and_get_job_data method which polls for job completion +and handles various job states and error conditions. +""" + +from datetime import datetime +from unittest.mock import Mock, patch + +import pytest +from requests import RequestException + +from linopy.remote.oetc import ( + AuthenticationResult, + ComputeProvider, + OetcCredentials, + OetcHandler, + OetcSettings, +) + + +@pytest.fixture +def mock_settings(): + """Create mock settings for testing.""" + credentials = OetcCredentials(email="test@example.com", password="test_password") + return OetcSettings( + credentials=credentials, + name="Test Job", + authentication_server_url="https://auth.example.com", + orchestrator_server_url="https://orchestrator.example.com", + compute_provider=ComputeProvider.GCP, + ) + + +@pytest.fixture +def mock_auth_result(): + """Create mock authentication result.""" + return AuthenticationResult( + token="mock_token", + token_type="Bearer", + expires_in=3600, + authenticated_at=datetime.now(), + ) + + +@pytest.fixture +def oetc_handler(mock_settings, mock_auth_result): + """Create OetcHandler with mocked authentication.""" + with patch( + "linopy.remote.oetc.OetcHandler._OetcHandler__sign_in", + return_value=mock_auth_result, + ): + with patch( + "linopy.remote.oetc.OetcHandler._OetcHandler__get_cloud_provider_credentials" + ): + handler = OetcHandler(mock_settings) + return handler + + +class TestJobPollingSuccess: + """Test successful job polling scenarios.""" + + def test_job_completes_immediately(self, oetc_handler): + """Test job that completes on first poll.""" + job_data = { + "uuid": "job-123", + "status": "FINISHED", + "name": "test-job", + "owner": "test-user", + "solver": "highs", + "duration_in_seconds": 120, + "solving_duration_in_seconds": 90, + "input_files": ["input.nc"], + "output_files": ["output.nc"], + "created_at": "2024-01-01T00:00:00Z", + } + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + result = oetc_handler.wait_and_get_job_data("job-123") + + assert result.uuid == "job-123" + assert result.status == "FINISHED" + assert result.output_files == ["output.nc"] + mock_get.assert_called_once() + + def test_job_completes_with_no_output_files_warning(self, oetc_handler): + """Test job completion with no output files generates warning.""" + job_data = {"uuid": "job-123", "status": "FINISHED", "output_files": []} + + with patch("requests.get") as mock_get: + with patch("linopy.remote.oetc.logger.warning") as mock_warning: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + result = oetc_handler.wait_and_get_job_data("job-123") + + assert result.status == "FINISHED" + mock_warning.assert_called_once_with( + "OETC - Warning: Job completed but no output files found" + ) + + @patch("time.sleep") # Mock sleep to speed up test + def test_job_polling_progression(self, mock_sleep, oetc_handler): + """Test job progresses through multiple states before completion.""" + responses = [ + {"uuid": "job-123", "status": "PENDING"}, + {"uuid": "job-123", "status": "STARTING"}, + {"uuid": "job-123", "status": "RUNNING", "duration_in_seconds": 30}, + {"uuid": "job-123", "status": "RUNNING", "duration_in_seconds": 60}, + {"uuid": "job-123", "status": "FINISHED", "output_files": ["output.nc"]}, + ] + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.side_effect = responses + mock_get.return_value = mock_response + + result = oetc_handler.wait_and_get_job_data( + "job-123", initial_poll_interval=1 + ) + + assert result.status == "FINISHED" + assert mock_get.call_count == 5 + assert mock_sleep.call_count == 4 # Sleep called 4 times between 5 polls + + @patch("time.sleep") + def test_polling_interval_backoff(self, mock_sleep, oetc_handler): + """Test polling interval increases with exponential backoff.""" + responses = [ + {"uuid": "job-123", "status": "PENDING"}, + {"uuid": "job-123", "status": "RUNNING"}, + {"uuid": "job-123", "status": "FINISHED", "output_files": ["output.nc"]}, + ] + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.side_effect = responses + mock_get.return_value = mock_response + + oetc_handler.wait_and_get_job_data( + "job-123", initial_poll_interval=10, max_poll_interval=100 + ) + + # Verify sleep was called with increasing intervals + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] + assert sleep_calls[0] == 10 # Initial interval + assert sleep_calls[1] == 15 # 10 * 1.5 = 15 + + +class TestJobPollingErrors: + """Test job polling error scenarios.""" + + def test_setup_error_status(self, oetc_handler): + """Test job with SETUP_ERROR status raises exception.""" + job_data = {"uuid": "job-123", "status": "SETUP_ERROR"} + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Job failed during setup phase"): + oetc_handler.wait_and_get_job_data("job-123") + + def test_runtime_error_status(self, oetc_handler): + """Test job with RUNTIME_ERROR status raises exception.""" + job_data = {"uuid": "job-123", "status": "RUNTIME_ERROR"} + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Job failed during execution"): + oetc_handler.wait_and_get_job_data("job-123") + + def test_unknown_status_error(self, oetc_handler): + """Test job with unknown status raises exception.""" + job_data = {"uuid": "job-123", "status": "UNKNOWN_STATUS"} + + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = job_data + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Unknown job status: UNKNOWN_STATUS"): + oetc_handler.wait_and_get_job_data("job-123") + + +class TestJobPollingNetworkErrors: + """Test network error handling during job polling.""" + + @patch("time.sleep") + def test_network_retry_success(self, mock_sleep, oetc_handler): + """Test network errors are retried and eventually succeed.""" + successful_response = { + "uuid": "job-123", + "status": "FINISHED", + "output_files": ["output.nc"], + } + + with patch("requests.get") as mock_get: + # First two calls fail, third succeeds + mock_get.side_effect = [ + RequestException("Network error 1"), + RequestException("Network error 2"), + Mock( + raise_for_status=Mock(), json=Mock(return_value=successful_response) + ), + ] + + result = oetc_handler.wait_and_get_job_data("job-123") + + assert result.status == "FINISHED" + assert mock_get.call_count == 3 + assert mock_sleep.call_count == 2 # Retry delays + + # Verify retry delays increase + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] + assert sleep_calls[0] == 10 # First retry: 1 * 10 = 10 + assert sleep_calls[1] == 20 # Second retry: 2 * 10 = 20 + + @patch("time.sleep") + def test_max_network_retries_exceeded(self, mock_sleep, oetc_handler): + """Test max network retries causes exception.""" + with patch("requests.get") as mock_get: + # All calls fail with RequestException + mock_get.side_effect = RequestException("Network error") + + with pytest.raises( + Exception, match="Failed to get job status after 10 network retries" + ): + oetc_handler.wait_and_get_job_data("job-123") + + # Should retry exactly 10 times before failing + assert mock_get.call_count == 10 + + @patch("time.sleep") + def test_network_retry_delay_cap(self, mock_sleep, oetc_handler): + """Test network retry delay is capped at 60 seconds.""" + with patch("requests.get") as mock_get: + mock_get.side_effect = RequestException("Network error") + + with pytest.raises(Exception): + oetc_handler.wait_and_get_job_data("job-123") + + # Check that delay is capped at 60 seconds + sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] + assert all(delay <= 60 for delay in sleep_calls) + + def test_keyerror_in_response(self, oetc_handler): + """Test KeyError in response parsing raises exception.""" + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = {} # Missing required 'uuid' field + mock_get.return_value = mock_response + + with pytest.raises( + Exception, match="Invalid job status response format: missing field" + ): + oetc_handler.wait_and_get_job_data("job-123") + + def test_generic_exception_handling(self, oetc_handler): + """Test generic exception handling in polling loop.""" + with patch("requests.get") as mock_get: + mock_response = Mock() + mock_response.raise_for_status.side_effect = ValueError("Unexpected error") + mock_get.return_value = mock_response + + with pytest.raises( + Exception, match="Error getting job status: Unexpected error" + ): + oetc_handler.wait_and_get_job_data("job-123") + + def test_status_error_exception_preserved(self, oetc_handler): + """Test that status-related exceptions are preserved.""" + with patch("requests.get") as mock_get: + # Simulate an exception that mentions "status:" - should be re-raised as-is + mock_response = Mock() + mock_response.raise_for_status.side_effect = Exception( + "Custom status: error" + ) + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Custom status: error"): + oetc_handler.wait_and_get_job_data("job-123") + + def test_oetc_logs_exception_preserved(self, oetc_handler): + """Test that OETC logs exceptions are preserved.""" + with patch("requests.get") as mock_get: + # Simulate an exception that mentions "OETC logs" - should be re-raised as-is + mock_response = Mock() + mock_response.raise_for_status.side_effect = Exception( + "Check the OETC logs for details" + ) + mock_get.return_value = mock_response + + with pytest.raises(Exception, match="Check the OETC logs for details"): + oetc_handler.wait_and_get_job_data("job-123") From eed5ffaf74bfcb13017f7335733556bf5cf597d3 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 4 Sep 2025 13:15:14 +0200 Subject: [PATCH 28/28] add type hints --- test/remote/test_oetc_job_polling.py | 48 ++++++++++++++++++---------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/test/remote/test_oetc_job_polling.py b/test/remote/test_oetc_job_polling.py index 933bad94..96ec98b4 100644 --- a/test/remote/test_oetc_job_polling.py +++ b/test/remote/test_oetc_job_polling.py @@ -21,7 +21,7 @@ @pytest.fixture -def mock_settings(): +def mock_settings() -> OetcSettings: """Create mock settings for testing.""" credentials = OetcCredentials(email="test@example.com", password="test_password") return OetcSettings( @@ -34,7 +34,7 @@ def mock_settings(): @pytest.fixture -def mock_auth_result(): +def mock_auth_result() -> AuthenticationResult: """Create mock authentication result.""" return AuthenticationResult( token="mock_token", @@ -45,7 +45,9 @@ def mock_auth_result(): @pytest.fixture -def oetc_handler(mock_settings, mock_auth_result): +def oetc_handler( + mock_settings: OetcSettings, mock_auth_result: AuthenticationResult +) -> OetcHandler: """Create OetcHandler with mocked authentication.""" with patch( "linopy.remote.oetc.OetcHandler._OetcHandler__sign_in", @@ -61,7 +63,7 @@ def oetc_handler(mock_settings, mock_auth_result): class TestJobPollingSuccess: """Test successful job polling scenarios.""" - def test_job_completes_immediately(self, oetc_handler): + def test_job_completes_immediately(self, oetc_handler: OetcHandler) -> None: """Test job that completes on first poll.""" job_data = { "uuid": "job-123", @@ -89,7 +91,9 @@ def test_job_completes_immediately(self, oetc_handler): assert result.output_files == ["output.nc"] mock_get.assert_called_once() - def test_job_completes_with_no_output_files_warning(self, oetc_handler): + def test_job_completes_with_no_output_files_warning( + self, oetc_handler: OetcHandler + ) -> None: """Test job completion with no output files generates warning.""" job_data = {"uuid": "job-123", "status": "FINISHED", "output_files": []} @@ -108,7 +112,9 @@ def test_job_completes_with_no_output_files_warning(self, oetc_handler): ) @patch("time.sleep") # Mock sleep to speed up test - def test_job_polling_progression(self, mock_sleep, oetc_handler): + def test_job_polling_progression( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: """Test job progresses through multiple states before completion.""" responses = [ {"uuid": "job-123", "status": "PENDING"}, @@ -133,7 +139,9 @@ def test_job_polling_progression(self, mock_sleep, oetc_handler): assert mock_sleep.call_count == 4 # Sleep called 4 times between 5 polls @patch("time.sleep") - def test_polling_interval_backoff(self, mock_sleep, oetc_handler): + def test_polling_interval_backoff( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: """Test polling interval increases with exponential backoff.""" responses = [ {"uuid": "job-123", "status": "PENDING"}, @@ -160,7 +168,7 @@ def test_polling_interval_backoff(self, mock_sleep, oetc_handler): class TestJobPollingErrors: """Test job polling error scenarios.""" - def test_setup_error_status(self, oetc_handler): + def test_setup_error_status(self, oetc_handler: OetcHandler) -> None: """Test job with SETUP_ERROR status raises exception.""" job_data = {"uuid": "job-123", "status": "SETUP_ERROR"} @@ -173,7 +181,7 @@ def test_setup_error_status(self, oetc_handler): with pytest.raises(Exception, match="Job failed during setup phase"): oetc_handler.wait_and_get_job_data("job-123") - def test_runtime_error_status(self, oetc_handler): + def test_runtime_error_status(self, oetc_handler: OetcHandler) -> None: """Test job with RUNTIME_ERROR status raises exception.""" job_data = {"uuid": "job-123", "status": "RUNTIME_ERROR"} @@ -186,7 +194,7 @@ def test_runtime_error_status(self, oetc_handler): with pytest.raises(Exception, match="Job failed during execution"): oetc_handler.wait_and_get_job_data("job-123") - def test_unknown_status_error(self, oetc_handler): + def test_unknown_status_error(self, oetc_handler: OetcHandler) -> None: """Test job with unknown status raises exception.""" job_data = {"uuid": "job-123", "status": "UNKNOWN_STATUS"} @@ -204,7 +212,9 @@ class TestJobPollingNetworkErrors: """Test network error handling during job polling.""" @patch("time.sleep") - def test_network_retry_success(self, mock_sleep, oetc_handler): + def test_network_retry_success( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: """Test network errors are retried and eventually succeed.""" successful_response = { "uuid": "job-123", @@ -234,7 +244,9 @@ def test_network_retry_success(self, mock_sleep, oetc_handler): assert sleep_calls[1] == 20 # Second retry: 2 * 10 = 20 @patch("time.sleep") - def test_max_network_retries_exceeded(self, mock_sleep, oetc_handler): + def test_max_network_retries_exceeded( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: """Test max network retries causes exception.""" with patch("requests.get") as mock_get: # All calls fail with RequestException @@ -249,7 +261,9 @@ def test_max_network_retries_exceeded(self, mock_sleep, oetc_handler): assert mock_get.call_count == 10 @patch("time.sleep") - def test_network_retry_delay_cap(self, mock_sleep, oetc_handler): + def test_network_retry_delay_cap( + self, mock_sleep: Mock, oetc_handler: OetcHandler + ) -> None: """Test network retry delay is capped at 60 seconds.""" with patch("requests.get") as mock_get: mock_get.side_effect = RequestException("Network error") @@ -261,7 +275,7 @@ def test_network_retry_delay_cap(self, mock_sleep, oetc_handler): sleep_calls = [call[0][0] for call in mock_sleep.call_args_list] assert all(delay <= 60 for delay in sleep_calls) - def test_keyerror_in_response(self, oetc_handler): + def test_keyerror_in_response(self, oetc_handler: OetcHandler) -> None: """Test KeyError in response parsing raises exception.""" with patch("requests.get") as mock_get: mock_response = Mock() @@ -274,7 +288,7 @@ def test_keyerror_in_response(self, oetc_handler): ): oetc_handler.wait_and_get_job_data("job-123") - def test_generic_exception_handling(self, oetc_handler): + def test_generic_exception_handling(self, oetc_handler: OetcHandler) -> None: """Test generic exception handling in polling loop.""" with patch("requests.get") as mock_get: mock_response = Mock() @@ -286,7 +300,7 @@ def test_generic_exception_handling(self, oetc_handler): ): oetc_handler.wait_and_get_job_data("job-123") - def test_status_error_exception_preserved(self, oetc_handler): + def test_status_error_exception_preserved(self, oetc_handler: OetcHandler) -> None: """Test that status-related exceptions are preserved.""" with patch("requests.get") as mock_get: # Simulate an exception that mentions "status:" - should be re-raised as-is @@ -299,7 +313,7 @@ def test_status_error_exception_preserved(self, oetc_handler): with pytest.raises(Exception, match="Custom status: error"): oetc_handler.wait_and_get_job_data("job-123") - def test_oetc_logs_exception_preserved(self, oetc_handler): + def test_oetc_logs_exception_preserved(self, oetc_handler: OetcHandler) -> None: """Test that OETC logs exceptions are preserved.""" with patch("requests.get") as mock_get: # Simulate an exception that mentions "OETC logs" - should be re-raised as-is