Skip to content

Commit 37d4ff7

Browse files
committed
fix(ingestion): robust SSL handling and test fixes
1 parent cdd0b3a commit 37d4ff7

12 files changed

Lines changed: 284 additions & 35 deletions

File tree

ingestion/src/metadata/ingestion/source/database/sas/client.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717
from metadata.generated.schema.entity.services.connections.database.sasConnection import (
1818
SASConnection,
1919
)
20+
from metadata.generated.schema.security.ssl.verifySSLConfig import VerifySSL
2021
from metadata.ingestion.connections.source_api_client import TrackedREST
2122
from metadata.ingestion.ometa.client import APIError, ClientConfig
2223
from metadata.utils.helpers import clean_uri
2324
from metadata.utils.logger import ingestion_logger
25+
from metadata.utils.ssl_registry import get_verify_ssl_fn
2426

2527
logger = ingestion_logger()
2628

29+
SAS_CLI_AUTH_HEADER = "Basic c2FzLmNsaTo="
30+
2731

2832
class SASClient:
2933
"""
@@ -33,13 +37,14 @@ class SASClient:
3337
def __init__(self, config: SASConnection):
3438
self.config: SASConnection = config
3539
self.auth_token = self.get_token(config.serverHost, config.username, config.password.get_secret_value())
40+
3641
client_config: ClientConfig = ClientConfig(
3742
base_url=clean_uri(config.serverHost),
3843
auth_header="Authorization",
3944
auth_token=self.get_auth_token,
4045
api_version="",
4146
allow_redirects=True,
42-
verify=False,
47+
verify=self._get_verify(),
4348
)
4449
self.client = TrackedREST(client_config, source_name="sas")
4550
# custom setting
@@ -50,6 +55,22 @@ def __init__(self, config: SASConnection):
5055
self.enable_dataflows = config.dataflows
5156
self.custom_filter_dataflows = config.dataflowsCustomFilter
5257

58+
def _get_verify(self):
59+
"""
60+
Helper to determine the SSL verification strategy
61+
"""
62+
verify = True
63+
if self.config.verifySSL == VerifySSL.ignore:
64+
verify = False
65+
elif self.config.verifySSL == VerifySSL.no_ssl:
66+
verify = False
67+
elif self.config.verifySSL == VerifySSL.validate and self.config.sslConfig:
68+
try:
69+
verify = get_verify_ssl_fn(self.config.verifySSL)(self.config.sslConfig)
70+
except Exception: # pylint: disable=broad-except
71+
verify = True
72+
return verify
73+
5374
def check_connection(self):
5475
"""
5576
Check metadata connection to SAS
@@ -71,8 +92,8 @@ def get_instance(self, instance_id):
7192
"Accept": "application/vnd.sas.metadata.instance.entity.detail+json",
7293
}
7394
response = self.client.get(path=endpoint, headers=headers)
74-
if "error" in response.keys(): # noqa: SIM118
75-
raise APIError(response["error"])
95+
if response and isinstance(response, dict) and "error" in response:
96+
raise APIError({"message": response["error"]})
7697
return response
7798

7899
def get_information_catalog_link(self, instance_id):
@@ -98,8 +119,8 @@ def list_assets(self, assets):
98119
endpoint = f"catalog/search?indices={assets}&q={asset_filter if str(asset_filter) != 'None' else '*'}"
99120
headers = {"Accept-Item": "application/vnd.sas.metadata.instance.entity+json"}
100121
response = self.client.get(path=endpoint, headers=headers)
101-
if "error" in response.keys(): # noqa: SIM118
102-
raise APIError(response["error"])
122+
if response and isinstance(response, dict) and "error" in response:
123+
raise APIError({"message": response["error"]})
103124
return response["items"]
104125

105126
def get_views(self, query):
@@ -111,7 +132,7 @@ def get_views(self, query):
111132
logger.info(f"{query}")
112133
response = self.client.post(path=endpoint, data=query, headers=headers)
113134
if "error" in response.keys(): # noqa: SIM118
114-
raise APIError(f"{response}")
135+
raise APIError({"message": "Error fetching views from SAS"})
115136
return response
116137

117138
def get_data_source(self, endpoint):
@@ -120,8 +141,8 @@ def get_data_source(self, endpoint):
120141
}
121142
response = self.client.get(path=endpoint, headers=headers)
122143
logger.info(f"{response}")
123-
if "error" in response.keys(): # noqa: SIM118
124-
raise APIError(response["error"])
144+
if response and isinstance(response, dict) and "error" in response:
145+
raise APIError({"message": response["error"]})
125146
return response
126147

127148
def get_report_link(self, resource, uri):
@@ -135,8 +156,8 @@ def load_table(self, endpoint):
135156
def get_report_relationship(self, report_id):
136157
endpoint = f"reports/commons/relationships/reports/{report_id}"
137158
response = self.client.get(endpoint)
138-
if "error" in response.keys(): # noqa: SIM118
139-
raise APIError(response["error"])
159+
if response and isinstance(response, dict) and "error" in response:
160+
raise APIError({"message": response["error"]})
140161
dependencies = []
141162
for item in response["items"]:
142163
if item["type"] == "Dependent":
@@ -145,15 +166,15 @@ def get_report_relationship(self, report_id):
145166

146167
def get_resource(self, endpoint):
147168
response = self.client.get(endpoint)
148-
if "error" in response.keys(): # noqa: SIM118
149-
raise APIError(response["error"])
169+
if response and isinstance(response, dict) and "error" in response:
170+
raise APIError({"message": response["error"]})
150171
return response
151172

152173
def get_instances_with_param(self, data):
153174
endpoint = f"catalog/instances?{data}"
154175
response = self.client.get(endpoint)
155-
if "error" in response.keys(): # noqa: SIM118
156-
raise APIError(response["error"])
176+
if response and isinstance(response, dict) and "error" in response:
177+
raise APIError({"message": response["error"]})
157178
return response["items"]
158179

159180
def get_auth_token(self):
@@ -164,8 +185,31 @@ def get_token(self, base_url, user, password):
164185
payload = {"grant_type": "password", "username": user, "password": password}
165186
headers = {
166187
"Content-type": "application/x-www-form-urlencoded",
167-
"Authorization": "Basic c2FzLmNsaTo=",
188+
"Authorization": SAS_CLI_AUTH_HEADER,
168189
}
169190
url = base_url + endpoint
170-
response = requests.request("POST", url, headers=headers, data=payload, verify=False, timeout=10)
171-
return response.json()["access_token"]
191+
192+
response = requests.request(
193+
"POST",
194+
url,
195+
headers=headers,
196+
data=payload,
197+
verify=self._get_verify(),
198+
timeout=10,
199+
)
200+
logger.debug(
201+
"Token request for user: %s completed with status: %s",
202+
user,
203+
response.status_code,
204+
)
205+
try:
206+
body = response.json()
207+
except ValueError as exc:
208+
response.raise_for_status()
209+
raise RuntimeError(f"SAS token endpoint returned non-JSON response (HTTP {response.status_code})") from exc
210+
211+
response.raise_for_status()
212+
token = body.get("access_token")
213+
if not token:
214+
raise RuntimeError(f"Failed to retrieve access_token from SAS (HTTP {response.status_code})")
215+
return token

ingestion/src/metadata/ingestion/source/search/elasticsearch/connection.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
4545
TestConnectionResult,
4646
)
47+
from metadata.generated.schema.security.ssl.verifySSLConfig import VerifySSL
4748
from metadata.ingestion.connections.builders import init_empty_connection_arguments
4849
from metadata.ingestion.connections.test_connections import test_connection_steps
4950
from metadata.ingestion.ometa.ometa_api import OpenMetadata
@@ -102,22 +103,28 @@ def _handle_ssl_context_by_path(ssl_config: SslConfig):
102103
return ca_cert, client_cert, private_key
103104

104105

105-
def get_ssl_context(ssl_config: SslConfig) -> ssl.SSLContext:
106+
def get_ssl_context(
107+
ssl_config: Optional[SslConfig], verify_ssl: Optional[VerifySSL] = VerifySSL.validate
108+
) -> Optional[ssl.SSLContext]:
106109
"""
107110
Method to get SSL Context
108111
"""
112+
if verify_ssl == VerifySSL.ignore:
113+
return ssl._create_unverified_context() # pylint: disable=protected-access
114+
115+
if verify_ssl == VerifySSL.no_ssl:
116+
return None
117+
109118
ca_cert = False
110119
client_cert = None
111120
private_key = None
112121
cert_chain = None
113122

114-
if not ssl_config.certificates:
115-
return None
116-
117-
if isinstance(ssl_config.certificates, SslCertificatesByValues):
118-
ca_cert, client_cert, private_key = _handle_ssl_context_by_value(ssl_config=ssl_config)
119-
elif isinstance(ssl_config.certificates, SslCertificatesByPath):
120-
ca_cert, client_cert, private_key = _handle_ssl_context_by_path(ssl_config=ssl_config)
123+
if ssl_config and ssl_config.certificates:
124+
if isinstance(ssl_config.certificates, SslCertificatesByValues):
125+
ca_cert, client_cert, private_key = _handle_ssl_context_by_value(ssl_config=ssl_config)
126+
elif isinstance(ssl_config.certificates, SslCertificatesByPath):
127+
ca_cert, client_cert, private_key = _handle_ssl_context_by_path(ssl_config=ssl_config)
121128

122129
if client_cert and private_key:
123130
cert_chain = (client_cert, private_key)
@@ -133,7 +140,7 @@ def get_ssl_context(ssl_config: SslConfig) -> ssl.SSLContext:
133140
)
134141
return ssl_context # noqa: RET504
135142

136-
return ssl._create_unverified_context() # pylint: disable=protected-access
143+
return ssl.create_default_context()
137144

138145

139146
def get_connection(connection: ElasticsearchConnection) -> Elasticsearch:
@@ -146,7 +153,7 @@ def get_connection(connection: ElasticsearchConnection) -> Elasticsearch:
146153
if isinstance(connection.authType, BasicAuthentication) and connection.authType.username:
147154
basic_auth = (
148155
connection.authType.username,
149-
connection.authType.password.get_secret_value() if connection.authType.password else None,
156+
(connection.authType.password.get_secret_value() if connection.authType.password else None),
150157
)
151158

152159
if isinstance(connection.authType, ApiKeyAuthentication):
@@ -161,8 +168,7 @@ def get_connection(connection: ElasticsearchConnection) -> Elasticsearch:
161168
if not connection.connectionArguments:
162169
connection.connectionArguments = init_empty_connection_arguments()
163170

164-
if connection.sslConfig:
165-
ssl_context = get_ssl_context(connection.sslConfig)
171+
ssl_context = get_ssl_context(connection.sslConfig, connection.verifySSL)
166172

167173
return Elasticsearch(
168174
str(connection.hostPort),
@@ -188,15 +194,16 @@ def test_connection(
188194
def test_get_search_indexes():
189195
try:
190196
result = client.indices.get_alias(expand_wildcards="open")
191-
if result is None:
192-
raise ConnectionError("Failed to retrieve search indexes from Elasticsearch") # noqa: TRY301
193-
return result # noqa: TRY300
194197
except Exception as exc:
195198
raise ConnectionError(
196199
f"Unable to connect to Elasticsearch or retrieve indexes: {exc}. "
197200
"Please check your Elasticsearch connection configuration and cluster health."
198201
) from exc
199202

203+
if result is None:
204+
raise ConnectionError("Failed to retrieve search indexes from Elasticsearch")
205+
return result
206+
200207
test_fn = {
201208
"CheckAccess": client.info,
202209
"GetSearchIndexes": test_get_search_indexes,

ingestion/src/metadata/utils/secrets/aws_secrets_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_string_value(self, secret_id: str) -> Optional[str]: # noqa: UP045
5252
try:
5353
kwargs = {"SecretId": secret_id}
5454
response = self.client.get_secret_value(**kwargs)
55-
logger.debug("Got value for secret %s.", secret_id)
55+
logger.debug("Successfully retrieved value from secrets manager.")
5656
except ClientError as err:
5757
logger.debug(traceback.format_exc())
5858
logger.error(f"Couldn't get value for secret [{secret_id}]: {err}")
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2025 Collate
2+
# Licensed under the Collate Community License, Version 1.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
"""
12+
Unit tests for SASClient
13+
"""
14+
15+
from unittest.mock import MagicMock, patch
16+
17+
import pytest
18+
from pydantic import SecretStr
19+
20+
from metadata.generated.schema.entity.services.connections.database.sasConnection import (
21+
SASConnection,
22+
)
23+
from metadata.generated.schema.security.ssl.verifySSLConfig import VerifySSL
24+
from metadata.ingestion.ometa.client import APIError
25+
from metadata.ingestion.source.database.sas.client import SASClient
26+
27+
MOCK_CONFIG = SASConnection(
28+
username="user",
29+
password=SecretStr("pass"),
30+
serverHost="http://sas.com",
31+
verifySSL=VerifySSL.validate,
32+
)
33+
34+
35+
class TestSASClient:
36+
@patch("metadata.ingestion.source.database.sas.client.requests.request")
37+
@patch("metadata.ingestion.source.database.sas.client.TrackedREST")
38+
def test_init_success(self, mock_rest, mock_request):
39+
mock_response = MagicMock()
40+
mock_response.status_code = 200
41+
mock_response.json.return_value = {"access_token": "token123"}
42+
mock_request.return_value = mock_response
43+
44+
client = SASClient(MOCK_CONFIG)
45+
46+
assert client.auth_token == "token123"
47+
mock_rest.assert_called_once()
48+
# Verify default verify is True (VerifySSL.validate)
49+
args, _ = mock_rest.call_args
50+
assert args[0].verify is True
51+
52+
@patch("metadata.ingestion.source.database.sas.client.requests.request")
53+
@patch("metadata.ingestion.source.database.sas.client.TrackedREST")
54+
def test_init_verify_ignore(self, mock_rest, mock_request):
55+
mock_response = MagicMock()
56+
mock_response.status_code = 200
57+
mock_response.json.return_value = {"access_token": "token123"}
58+
mock_request.return_value = mock_response
59+
60+
config = MOCK_CONFIG.model_copy()
61+
config.verifySSL = VerifySSL.ignore
62+
63+
SASClient(config)
64+
65+
args, _ = mock_rest.call_args
66+
assert args[0].verify is False
67+
68+
@patch("metadata.ingestion.source.database.sas.client.requests.request")
69+
@patch("metadata.ingestion.source.database.sas.client.TrackedREST")
70+
def test_get_token_error_handling(self, mock_rest, mock_request):
71+
# Case 1: Non-JSON response (with successful HTTP status)
72+
mock_response = MagicMock()
73+
mock_response.status_code = 200
74+
mock_response.raise_for_status.return_value = None
75+
mock_response.json.side_effect = ValueError("Not JSON")
76+
mock_request.return_value = mock_response
77+
78+
with pytest.raises(RuntimeError) as exc:
79+
SASClient(MOCK_CONFIG)
80+
assert "non-JSON response" in str(exc.value)
81+
82+
# Case 2: Missing access_token
83+
mock_response.status_code = 200
84+
mock_response.json.side_effect = None
85+
mock_response.json.return_value = {"something": "else"}
86+
mock_response.raise_for_status.return_value = None
87+
88+
with pytest.raises(RuntimeError) as exc:
89+
SASClient(MOCK_CONFIG)
90+
assert "Failed to retrieve access_token" in str(exc.value)
91+
92+
@patch("metadata.ingestion.source.database.sas.client.requests.request")
93+
@patch("metadata.ingestion.source.database.sas.client.TrackedREST")
94+
def test_api_error_wrapping(self, mock_rest, mock_request):
95+
mock_response = MagicMock()
96+
mock_response.status_code = 200
97+
mock_response.json.return_value = {"access_token": "token123"}
98+
mock_request.return_value = mock_response
99+
100+
client = SASClient(MOCK_CONFIG)
101+
102+
# Mocking client.get to return an error key
103+
client.client.get.return_value = {"error": "Unauthorized access"}
104+
105+
with pytest.raises(APIError) as exc:
106+
client.get_instance("123")
107+
assert exc.value.args[0] == "Unauthorized access"
108+
109+
@patch("metadata.ingestion.source.database.sas.client.requests.request")
110+
@patch("metadata.ingestion.source.database.sas.client.TrackedREST")
111+
def test_get_views_error_sanitization(self, mock_rest, mock_request):
112+
mock_response = MagicMock()
113+
mock_response.status_code = 200
114+
mock_response.json.return_value = {"access_token": "token123"}
115+
mock_request.return_value = mock_response
116+
117+
client = SASClient(MOCK_CONFIG)
118+
119+
# Mocking client.post to return an error key
120+
client.client.post.return_value = {"error": "Sensitive internal details"}
121+
122+
with pytest.raises(APIError) as exc:
123+
client.get_views("query")
124+
# Should show the sanitized message instead of the raw error
125+
assert exc.value.args[0] == "Error fetching views from SAS"

0 commit comments

Comments
 (0)