diff --git a/sqlit/domains/connections/providers/snowflake/adapter.py b/sqlit/domains/connections/providers/snowflake/adapter.py index 13e9882d..2055eeb3 100644 --- a/sqlit/domains/connections/providers/snowflake/adapter.py +++ b/sqlit/domains/connections/providers/snowflake/adapter.py @@ -97,6 +97,8 @@ def connect(self, config: ConnectionConfig) -> Any: connect_args["private_key_file_pwd"] = extras["private_key_file_pwd"] if "oauth_token" in extras: connect_args["token"] = extras["oauth_token"] + if "pat_token" in extras: + connect_args["token"] = extras["pat_token"] # Pass through any extra_options to the driver connect_args.update(config.extra_options) diff --git a/sqlit/domains/connections/providers/snowflake/schema.py b/sqlit/domains/connections/providers/snowflake/schema.py index 8a3dba78..4608c285 100644 --- a/sqlit/domains/connections/providers/snowflake/schema.py +++ b/sqlit/domains/connections/providers/snowflake/schema.py @@ -17,6 +17,7 @@ def _get_snowflake_auth_options() -> tuple[SelectOption, ...]: SelectOption("externalbrowser", "SSO (Browser)"), SelectOption("snowflake_jwt", "Key Pair (JWT)"), SelectOption("oauth", "OAuth Token"), + SelectOption("PROGRAMMATIC_ACCESS_TOKEN", "Programmatic Access Token"), ) @@ -26,6 +27,8 @@ def _get_snowflake_auth_options() -> tuple[SelectOption, ...]: _AUTH_NEEDS_PRIVATE_KEY = {"snowflake_jwt"} # Auth types that need OAuth token _AUTH_NEEDS_OAUTH = {"oauth"} +# Auth types that need Programmatic Access Token +_AUTH_NEEDS_PAT = {"PROGRAMMATIC_ACCESS_TOKEN"} SCHEMA = ConnectionSchema( @@ -81,6 +84,14 @@ def _get_snowflake_auth_options() -> tuple[SelectOption, ...]: required=False, visible_when=lambda v: v.get("authenticator") in _AUTH_NEEDS_OAUTH, ), + SchemaField( + name="pat_token", + label="PAT", + field_type=FieldType.PASSWORD, + placeholder="PROGRAMMATIC_ACCESS_TOKEN", + required=False, + visible_when=lambda v: v.get("authenticator") in _AUTH_NEEDS_PAT, + ), _database_field(), SchemaField( name="warehouse", diff --git a/tests/unit/test_extra_options_passthrough.py b/tests/unit/test_extra_options_passthrough.py index 1ad78eff..a3de95c9 100644 --- a/tests/unit/test_extra_options_passthrough.py +++ b/tests/unit/test_extra_options_passthrough.py @@ -80,6 +80,37 @@ def test_snowflake_jwt_auth_options(self): assert call_kwargs.get("private_key_file") == "/path/to/key.p8" assert call_kwargs.get("private_key_file_pwd") == "secret" + def test_snowflake_pat_auth_options(self): + """Test Snowflake PAT authentication options are passed.""" + from sqlit.domains.connections.domain.config import ConnectionConfig, TcpEndpoint + from sqlit.domains.connections.providers.snowflake.adapter import SnowflakeAdapter + + mock_sf = MagicMock() + mock_conn = MagicMock() + mock_sf.connect.return_value = mock_conn + + with patch.dict("sys.modules", {"snowflake.connector": mock_sf}): + adapter = SnowflakeAdapter() + config = ConnectionConfig( + name="test_sf_pat", + db_type="snowflake", + endpoint=TcpEndpoint( + host="account.snowflakecomputing.com", + username="user", + database="db", + ), + options={ + "authenticator": "PROGRAMMATIC_ACCESS_TOKEN", + "pat_token": "test_pat_token", + }, + ) + + adapter.connect(config) + + call_kwargs = mock_sf.connect.call_args[1] + assert call_kwargs.get("authenticator") == "PROGRAMMATIC_ACCESS_TOKEN" + assert call_kwargs.get("token") == "test_pat_token" + def test_postgresql_passes_extra_options(self): """Test PostgreSQL adapter passes extra_options to driver.""" from sqlit.domains.connections.domain.config import ConnectionConfig, TcpEndpoint @@ -162,12 +193,13 @@ def test_snowflake_schema_has_auth_dropdown(self): break assert auth_field is not None, "Snowflake schema should have authenticator field" - assert len(auth_field.options) == 4 + assert len(auth_field.options) == 5 auth_values = [opt.value for opt in auth_field.options] assert "default" in auth_values assert "externalbrowser" in auth_values assert "snowflake_jwt" in auth_values assert "oauth" in auth_values + assert "PROGRAMMATIC_ACCESS_TOKEN" in auth_values def test_snowflake_schema_has_private_key_fields(self): """Test Snowflake schema includes private key fields for JWT auth."""