Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -122,34 +122,20 @@ def get_connection(connection: HiveConnection) -> Engine:
"kerberos_service_name"
] = connection.kerberosServiceName

# Handle SSL using SSL manager (following established patterns)
# SSL cert paths (ssl_ca_certs, ssl_certfile, ssl_keyfile) are set by ssl_manager.setup_ssl()
# via SSLManager.create_temp_file(). Do not assign sslConfig fields here directly —
# SecretStr values are not file paths and will cause a driver-level file-not-found error.
ssl_manager = check_ssl_and_init(connection)
if ssl_manager:
connection = ssl_manager.setup_ssl(connection)
# Store SSL manager for cleanup
connection._ssl_manager = ssl_manager

# Add SSL configuration to connection arguments if SSL is enabled
# use_ssl=True is a Hive-specific driver flag not set by ssl_manager, so it is handled here.
if hasattr(connection, "useSSL") and connection.useSSL:
if not connection.connectionArguments:
connection.connectionArguments = init_empty_connection_arguments()
connection.connectionArguments.root["use_ssl"] = True

# Add SSL certificate configuration if available
if hasattr(connection, "sslConfig") and connection.sslConfig:
if connection.sslConfig.root.sslCertificate:
connection.connectionArguments.root[
"ssl_certfile"
] = connection.sslConfig.root.sslCertificate
if connection.sslConfig.root.sslKey:
connection.connectionArguments.root[
"ssl_keyfile"
] = connection.sslConfig.root.sslKey
if connection.sslConfig.root.caCertificate:
connection.connectionArguments.root[
"ssl_ca_certs"
] = connection.sslConfig.root.caCertificate

return create_generic_db_connection(
connection=connection,
get_connection_url_fn=get_connection_url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _get_kafka_connection(broker: KafkaBrokerConfig) -> KafkaConsumer:
"ssl.key.location": broker.sslConfig.root.sslKey,
}
)

if broker.securityProtocol.value in (
KafkaSecProtocol.SASL_PLAINTEXT.value,
KafkaSecProtocol.SASL_SSL.value,
Expand Down
40 changes: 32 additions & 8 deletions ingestion/src/metadata/utils/ssl_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ def _(self, connection):
raise ValueError(
"CA certificate is required for SSL mode verify-ca or verify-full"
)
# sslcert and sslkey enable mutual TLS (client certificate authentication).
# Previously these fields were extracted by check_ssl_and_init but never
# forwarded to psycopg2, causing FATAL: connection requires a valid client
# certificate when pg_hba.conf uses cert auth.
if self.cert_file_path:
connection.connectionArguments.root["sslcert"] = self.cert_file_path
if self.key_file_path:
connection.connectionArguments.root["sslkey"] = self.key_file_path
return connection

@setup_ssl.register(SalesforceConnection)
Expand Down Expand Up @@ -278,15 +286,19 @@ def _(self, connection):
if not connection.connectionArguments:
connection.connectionArguments = init_empty_connection_arguments()

# Add certificate paths if available (following MySQL pattern)
ssl_args = connection.connectionArguments.root.get("ssl", {})
# CustomHiveConnection consumes these as explicit top-level kwargs that are
# forwarded to puretransport.transport_factory via socket_kwargs (see
# custom_hive_connection.py:104-109). The nested MySQL-style ssl dict is
# not accepted by CustomHiveConnection and will raise TypeError if present.
# Pop it defensively to maintain backward compatibility with stored configs
# that may have been written by a previous version of this handler.
connection.connectionArguments.root.pop("ssl", None)
if self.ca_file_path:
ssl_args["ssl_ca"] = self.ca_file_path
connection.connectionArguments.root["ssl_ca_certs"] = self.ca_file_path
if self.cert_file_path:
ssl_args["ssl_cert"] = self.cert_file_path
connection.connectionArguments.root["ssl_certfile"] = self.cert_file_path
if self.key_file_path:
ssl_args["ssl_key"] = self.key_file_path
connection.connectionArguments.root["ssl"] = ssl_args
connection.connectionArguments.root["ssl_keyfile"] = self.key_file_path

Comment thread
SumanMaharana marked this conversation as resolved.
return connection

Expand All @@ -307,9 +319,15 @@ def _(self, connection):
connection.connectionArguments.root["TrustServerCertificate"] = "yes"

elif connection.scheme.value == "mssql+pytds":
# pytds driver SSL parameters
# pytds supports cafile, certfile, and keyfile as native connection params.
# certfile and keyfile were previously extracted by check_ssl_and_init but
# never applied here, making mutual TLS silently non-functional for pytds.
if self.ca_file_path:
connection.connectionArguments.root["cafile"] = self.ca_file_path
if self.cert_file_path:
connection.connectionArguments.root["certfile"] = self.cert_file_path
if self.key_file_path:
connection.connectionArguments.root["keyfile"] = self.key_file_path

return connection

Expand Down Expand Up @@ -468,9 +486,15 @@ def _(connection):
Union[PostgresConnection, RedshiftConnection, GreenplumConnection],
connection,
)
# Previously only caCertificate was extracted, causing sslCertificate and sslKey
# to be silently dropped. All three are now passed so setup_ssl can forward
# sslcert and sslkey to psycopg2 for mutual TLS authentication.
ssl = connection.sslConfig
if connection.sslMode:
return SSLManager(
ca=connection.sslConfig.root.caCertificate if connection.sslConfig else None
ca=ssl.root.caCertificate if ssl else None,
cert=ssl.root.sslCertificate if ssl else None,
key=ssl.root.sslKey if ssl else None,
)
return None

Expand Down
239 changes: 239 additions & 0 deletions ingestion/tests/unit/test_ssl_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,37 @@ def test_setup_ssl_pytds_driver(self):

ssl_manager.cleanup_temp_files()

def test_setup_ssl_pytds_client_cert(self):
"""Test SSL setup for pytds driver with mutual TLS (all three certs)"""
from metadata.generated.schema.entity.services.connections.database.mssqlConnection import (
MssqlConnection,
MssqlScheme,
)
from metadata.utils.ssl_manager import check_ssl_and_init

connection = MssqlConnection(
hostPort="localhost:1433",
database="testdb",
username="sa",
password="password",
scheme=MssqlScheme.mssql_pytds,
sslConfig={
"caCertificate": "caCertificateData",
"sslCertificate": "sslCertificateData",
"sslKey": "sslKeyData",
},
)

ssl_manager = check_ssl_and_init(connection)
updated_connection = ssl_manager.setup_ssl(connection)

args = updated_connection.connectionArguments.root
self.assertIsNotNone(args.get("cafile"))
self.assertIsNotNone(args.get("certfile"))
self.assertIsNotNone(args.get("keyfile"))

ssl_manager.cleanup_temp_files()

def test_setup_ssl_pymssql_driver(self):
"""Test SSL setup for pymssql driver"""
from metadata.generated.schema.entity.services.connections.database.mssqlConnection import (
Expand Down Expand Up @@ -619,3 +650,211 @@ def test_setup_ssl_verify_ca_mode(self):
)

ssl_manager.cleanup_temp_files()


class PostgresSSLManagerTest(TestCase):
"""
Tests for PostgreSQL SSL Manager functionality — including mutual TLS.
"""

def test_check_ssl_and_init_all_three_fields(self):
"""All three SSL fields are extracted into SSLManager"""
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
from metadata.generated.schema.security.ssl.verifySSLConfig import SslMode
from metadata.utils.ssl_manager import check_ssl_and_init

connection = PostgresConnection(
hostPort="localhost:5432",
database="testdb",
username="postgres",
sslMode=SslMode.verify_ca,
sslConfig={
"caCertificate": "caCertificateData",
"sslCertificate": "sslCertificateData",
"sslKey": "sslKeyData",
},
)

ssl_manager = check_ssl_and_init(connection)

self.assertIsNotNone(ssl_manager)
self.assertIsNotNone(ssl_manager.ca_file_path)
self.assertIsNotNone(ssl_manager.cert_file_path)
self.assertIsNotNone(ssl_manager.key_file_path)

ssl_manager.cleanup_temp_files()

def test_setup_ssl_mutual_tls_sets_all_psycopg2_params(self):
"""setup_ssl sets sslrootcert, sslcert, and sslkey in connectionArguments"""
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
from metadata.generated.schema.security.ssl.verifySSLConfig import SslMode
from metadata.utils.ssl_manager import check_ssl_and_init

connection = PostgresConnection(
hostPort="localhost:5432",
database="testdb",
username="postgres",
sslMode=SslMode.verify_ca,
sslConfig={
"caCertificate": "caCertificateData",
"sslCertificate": "sslCertificateData",
"sslKey": "sslKeyData",
},
)

ssl_manager = check_ssl_and_init(connection)
updated_connection = ssl_manager.setup_ssl(connection)

args = updated_connection.connectionArguments.root
self.assertEqual(args.get("sslmode"), "verify-ca")
self.assertIsNotNone(args.get("sslrootcert"))
self.assertIsNotNone(args.get("sslcert"))
self.assertIsNotNone(args.get("sslkey"))

ssl_manager.cleanup_temp_files()

def test_setup_ssl_ca_only_verify_ca(self):
"""Existing behaviour: CA-only verify-ca sets sslrootcert but not sslcert/sslkey"""
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
from metadata.generated.schema.security.ssl.verifySSLConfig import SslMode
from metadata.utils.ssl_manager import check_ssl_and_init

connection = PostgresConnection(
hostPort="localhost:5432",
database="testdb",
username="postgres",
sslMode=SslMode.verify_ca,
sslConfig={"caCertificate": "caCertificateData"},
)

ssl_manager = check_ssl_and_init(connection)
updated_connection = ssl_manager.setup_ssl(connection)

args = updated_connection.connectionArguments.root
self.assertEqual(args.get("sslmode"), "verify-ca")
self.assertIsNotNone(args.get("sslrootcert"))
self.assertIsNone(args.get("sslcert"))
self.assertIsNone(args.get("sslkey"))

ssl_manager.cleanup_temp_files()

def test_setup_ssl_require_mode_no_ca(self):
"""sslmode=require without CA does not set sslrootcert"""
from metadata.generated.schema.entity.services.connections.database.postgresConnection import (
PostgresConnection,
)
from metadata.generated.schema.security.ssl.verifySSLConfig import SslMode
from metadata.utils.ssl_manager import check_ssl_and_init

connection = PostgresConnection(
hostPort="localhost:5432",
database="testdb",
username="postgres",
sslMode=SslMode.require,
)

ssl_manager = check_ssl_and_init(connection)
updated_connection = ssl_manager.setup_ssl(connection)

args = updated_connection.connectionArguments.root
self.assertEqual(args.get("sslmode"), "require")
self.assertIsNone(args.get("sslrootcert"))
self.assertIsNone(args.get("sslcert"))
self.assertIsNone(args.get("sslkey"))

ssl_manager.cleanup_temp_files()

def test_redshift_mutual_tls_sets_all_psycopg2_params(self):
"""RedshiftConnection shares the handler — mutual TLS params are set"""
from metadata.generated.schema.entity.services.connections.database.redshiftConnection import (
RedshiftConnection,
)
from metadata.generated.schema.security.ssl.verifySSLConfig import SslMode
from metadata.utils.ssl_manager import check_ssl_and_init

connection = RedshiftConnection(
hostPort="localhost:5439",
database="testdb",
username="redshift",
sslMode=SslMode.verify_ca,
sslConfig={
"caCertificate": "caCertificateData",
"sslCertificate": "sslCertificateData",
"sslKey": "sslKeyData",
},
)

ssl_manager = check_ssl_and_init(connection)
updated_connection = ssl_manager.setup_ssl(connection)

args = updated_connection.connectionArguments.root
self.assertEqual(args.get("sslmode"), "verify-ca")
self.assertIsNotNone(args.get("sslrootcert"))
self.assertIsNotNone(args.get("sslcert"))
self.assertIsNotNone(args.get("sslkey"))

ssl_manager.cleanup_temp_files()


class HiveSSLManagerTest(TestCase):
"""
Tests that setup_ssl for HiveConnection produces the kwarg names
that CustomHiveConnection expects (ssl_certfile, ssl_keyfile, ssl_ca_certs).
"""

def test_setup_ssl_sets_custom_hive_connection_kwargs(self):
"""ssl_ca_certs / ssl_certfile / ssl_keyfile are set at the top level"""
from metadata.generated.schema.entity.services.connections.database.hiveConnection import (
HiveConnection,
)
from metadata.utils.ssl_manager import check_ssl_and_init

connection = HiveConnection(
hostPort="localhost:10000",
useSSL=True,
sslConfig={
"caCertificate": "caCertificateData",
"sslCertificate": "sslCertificateData",
"sslKey": "sslKeyData",
},
)

ssl_manager = check_ssl_and_init(connection)
updated_connection = ssl_manager.setup_ssl(connection)

args = updated_connection.connectionArguments.root
self.assertIsNotNone(args.get("ssl_ca_certs"))
self.assertIsNotNone(args.get("ssl_certfile"))
self.assertIsNotNone(args.get("ssl_keyfile"))
# Must not fall back to the old MySQL-style nested dict
self.assertNotIn("ssl", args)

ssl_manager.cleanup_temp_files()

def test_setup_ssl_ca_only(self):
"""CA-only config sets ssl_ca_certs but not ssl_certfile or ssl_keyfile"""
from metadata.generated.schema.entity.services.connections.database.hiveConnection import (
HiveConnection,
)
from metadata.utils.ssl_manager import check_ssl_and_init

connection = HiveConnection(
hostPort="localhost:10000",
useSSL=True,
sslConfig={"caCertificate": "caCertificateData"},
)

ssl_manager = check_ssl_and_init(connection)
updated_connection = ssl_manager.setup_ssl(connection)

args = updated_connection.connectionArguments.root
self.assertIsNotNone(args.get("ssl_ca_certs"))
self.assertIsNone(args.get("ssl_certfile"))
self.assertIsNone(args.get("ssl_keyfile"))
ssl_manager.cleanup_temp_files()
Loading