diff --git a/tests/perf/data_source/dbapi_test_framework/README.md b/tests/perf/data_source/dbapi_test_framework/README.md new file mode 100644 index 0000000000..be0999bac6 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/README.md @@ -0,0 +1,605 @@ +# DBAPI Test Framework + +Comprehensive framework for testing DBAPI ingestion performance across different databases and ingestion methods. + +## Structure + +``` +dbapi_test_framework/ +├── config.py # Configuration (loads from .env, test matrix) +├── connections.py # Connection factories for each DBMS +├── runner.py # Test runner with 4 ingestion methods +├── main.py # Entry point +├── .env # Environment variables (credentials - not in git) +└── db_setup_util/ # Database setup scripts + ├── common_schema.py # Unified schema + data generator + ├── base_setup.py # Base setup class + ├── mysql_setup.py # MySQL table setup + ├── postgres_setup.py # PostgreSQL table setup + ├── oracle_setup.py # Oracle table setup + ├── mssql_setup.py # SQL Server table setup + ├── databricks_setup.py # Databricks table setup + └── setup_all.py # Run all setups at once +``` + +## Quick Start + +### 1. Set Up Environment Variables + +Create a `.env` file in the `dbapi_test_framework/` directory: + +```env +# Snowflake +SNOWFLAKE_ACCOUNT=your-account +SNOWFLAKE_USER=your-user +SNOWFLAKE_PASSWORD=your-password +SNOWFLAKE_DATABASE=your-database +SNOWFLAKE_SCHEMA=your-schema +SNOWFLAKE_WAREHOUSE=your-warehouse +SNOWFLAKE_ROLE=your-role +SNOWFLAKE_HOST=your-account.snowflakecomputing.com +SNOWFLAKE_PORT=443 +SNOWFLAKE_PROTOCOL=https + +# MySQL +MYSQL_HOST=your-mysql-host +MYSQL_PORT=3306 +MYSQL_USERNAME=your-user +MYSQL_PASSWORD=your-password +MYSQL_DATABASE=your-database + +# PostgreSQL +POSTGRES_HOST=your-postgres-host +POSTGRES_PORT=5432 +POSTGRES_USER=your-user +POSTGRES_PASSWORD=your-password +POSTGRES_DBNAME=your-database + +# SQL Server +MSSQL_SERVER=your-sqlserver-host +MSSQL_PORT=1433 +MSSQL_UID=your-user +MSSQL_PWD=your-password +MSSQL_DATABASE=test_db +MSSQL_DRIVER={ODBC Driver 18 for SQL Server} + +# Oracle +ORACLEDB_HOST=your-oracle-host +ORACLEDB_PORT=1521 +ORACLEDB_USERNAME=your-user +ORACLEDB_PASSWORD=your-password +ORACLEDB_SERVICE_NAME=your-service + +# Databricks +DATABRICKS_SERVER_HOSTNAME=your-workspace.databricks.net +DATABRICKS_HTTP_PATH=sql/protocolv1/o/... +DATABRICKS_ACCESS_TOKEN=your-token +``` + +### 2. Set Up Test Data + +Create test tables with identical data across all databases: + +```bash +cd db_setup_util + +# Setup individual database +python3 mysql_setup.py +python3 postgres_setup.py +python3 oracle_setup.py +python3 mssql_setup.py +python3 databricks_setup.py + +# Or setup all at once +python3 setup_all.py +``` + +This creates `DBAPI_TEST_TABLE` with 10,000 rows of deterministic test data in each database. + +### 3. Run Tests + +```bash +# From the dbapi_test_framework directory +cd tests/perf/data_source/dbapi_test_framework + +# Run a single test +python3 main.py + +# Run full test matrix +python3 main.py --matrix +``` + +Or as a module from project root: +```bash +python3 -m tests.perf.data_source.dbapi_test_framework.main +python3 -m tests.perf.data_source.dbapi_test_framework.main --matrix +``` + +## Ingestion Methods + +The framework tests 7 different ingestion approaches: + +### DBAPI Methods (Python drivers) + +1. **local** - Local ingestion using `session.read.dbapi()` + - Data fetched locally and uploaded to Snowflake + - No external access integration needed + +2. **udtf** - UDTF ingestion using `session.read.dbapi(udtf_configs=...)` + - Data fetched via UDTF running on Snowflake + - Requires external access integration + +3. **local_sproc** - Local ingestion inside a stored procedure + - Local ingestion logic runs inside a Snowflake stored procedure + - Requires external access integration + packages + +4. **udtf_sproc** - UDTF ingestion inside a stored procedure + - UDTF ingestion logic runs inside a Snowflake stored procedure + - Requires external access integration + packages + +### JDBC Methods (Java drivers) + +5. **jdbc** - JDBC ingestion using `session.read.jdbc()` + - Data fetched via JDBC UDTF running on Snowflake + - Requires JDBC driver JAR, Snowflake secret, and external access integration + - **Only supports UDTF-based ingestion** (no local mode) + - Driver JARs are automatically uploaded to stage on first run + +6. **jdbc_sproc** - JDBC ingestion inside a stored procedure + - JDBC ingestion logic runs inside a Snowflake stored procedure + - Requires JDBC driver JAR, Snowflake secret, and external access integration + - Driver JARs are automatically uploaded to stage on first run + +### PySpark Methods (Spark + JDBC) + +7. **pyspark** - PySpark JDBC ingestion + - Data fetched via PySpark JDBC running on local Spark session + - Written to Snowflake using Snowflake-Spark connector + - Requires PySpark, JDBC driver JARs in `drivers/` directory, and Snowflake-Spark connector + - Uses plain credentials from `.env` (not Snowflake secrets) + - Runs on local machine (not on Snowflake servers) + +## Supported Databases + +All databases support DBAPI, JDBC, and PySpark methods: + +- **MySQL** (DBAPI: `pymysql`, JDBC: `mysql-connector-j`) +- **PostgreSQL** (DBAPI: `psycopg2`, JDBC: `postgresql`) +- **MS SQL Server** (DBAPI: `pyodbc`, JDBC: `mssql-jdbc`) +- **Oracle** (DBAPI: `oracledb`, JDBC: `ojdbc`) +- **Databricks** (DBAPI: `databricks-sql-connector`, JDBC: `DatabricksJDBC42`) + +**Note**: PySpark method uses the same JDBC drivers as the JDBC methods. + +## Configuration + +### Test Matrix + +The test matrix in `config.py` supports two source types: + +```python +# Test config format +{ + "dbms": "mysql", + "source": { + "type": "table|query", # Type: table or query + "value": "DBAPI_TEST_TABLE" # Table name or SQL query + }, + "ingestion_method": "local|udtf|local_sproc|udtf_sproc|jdbc|jdbc_sproc" +} +``` + +### Generate Test Matrix + +Use list comprehensions for compact configuration: + +```python +_DBMS_LIST = ["mysql", "postgres", "mssql", "oracle", "databricks"] +_METHODS = ["local", "udtf"] + +TEST_MATRIX = [ + # Table-based tests + *[ + { + "dbms": dbms, + "source": {"type": "table", "value": "DBAPI_TEST_TABLE"}, + "ingestion_method": method, + } + for dbms in _DBMS_LIST + for method in _METHODS + ], + # Query-based tests + *[ + { + "dbms": dbms, + "source": {"type": "query", "value": "SELECT * FROM DBAPI_TEST_TABLE"}, + "ingestion_method": method, + } + for dbms in _DBMS_LIST + for method in _METHODS + ], +] +``` + +This generates 20 tests (5 DBMS × 2 methods × 2 source types). + +### External Access Integrations + +Required for UDTF and stored procedure methods: + +```python +UDTF_CONFIGS = { + "mysql": { + "external_access_integration": "snowpark_dbapi_mysql_test_integration", + }, + "postgres": { + "external_access_integration": "snowpark_dbapi_postgres_test_integration", + }, + "mssql": { + "external_access_integration": "snowpark_dbapi_sql_server_test_integration", + }, + "oracle": { + "external_access_integration": "snowpark_dbapi_oracledb_test_integration", + }, + "databricks": { + "external_access_integration": "snowpark_dbapi_databricks_test_integration", + }, +} +``` + +### Stored Procedure Packages + +Automatically configured per DBMS: + +```python +SPROC_PACKAGES = { + "mysql": ["pymysql"], + "postgres": ["psycopg2"], + "mssql": ["pyodbc", "msodbcsql"], + "oracle": ["oracledb"], + "databricks": ["databricks-sql-connector"], +} +``` + +### Runtime Options + +```python +# Show table info (row count + first row) before cleanup +SHOW_TARGET_TABLE_INFO = True # Default + +# Cleanup target tables after tests +CLEANUP_TARGET_TABLES = False # Set to True in production + +# DBAPI parameters (optional) +DBAPI_PARAMS = { + "fetch_size": 10000, + "max_workers": 4, +} +``` + +## Example Configurations + +### Table-Based Ingestion +```python +{ + "dbms": "mysql", + "source": {"type": "table", "value": "DBAPI_TEST_TABLE"}, + "ingestion_method": "local" +} +``` + +### Query-Based Ingestion +```python +{ + "dbms": "postgres", + "source": { + "type": "query", + "value": "SELECT id, varchar_col FROM DBAPI_TEST_TABLE WHERE id < 5000" + }, + "ingestion_method": "udtf" +} +``` + +### Stored Procedure Ingestion +```python +{ + "dbms": "oracle", + "source": {"type": "table", "value": "DBAPI_TEST_TABLE"}, + "ingestion_method": "udtf_sproc" # Runs inside Snowflake stored procedure +} +``` + +### JDBC Ingestion +```python +{ + "dbms": "mysql", + "source": {"type": "table", "value": "DBAPI_TEST_TABLE"}, + "ingestion_method": "jdbc" # Uses JDBC driver instead of Python DBAPI +} +``` + +### JDBC in Stored Procedure +```python +{ + "dbms": "postgres", + "source": { + "type": "query", + "value": "SELECT * FROM DBAPI_TEST_TABLE WHERE id BETWEEN 1 AND 1000" + }, + "ingestion_method": "jdbc_sproc" # JDBC inside stored procedure +} +``` + +### PySpark Ingestion +```python +{ + "dbms": "mysql", + "source": {"type": "table", "value": "DBAPI_TEST_TABLE"}, + "ingestion_method": "pyspark", # PySpark JDBC on local Spark session + "dbapi_params": { + "fetchsize": 10000, + "numPartitions": 10, + "partitionColumn": "id", + "lowerBound": 0, + "upperBound": 100000 + } +} +``` + +## Output + +### During Test Execution + +``` +############################################################ +TEST: MYSQL - LOCAL +Source Type: TABLE +Source Value: DBAPI_TEST_TABLE +############################################################ + +============================================================ +Running: LOCAL INGESTION +============================================================ +✓ Completed in 12.45 seconds + +============================================================ +TARGET TABLE INFO +============================================================ +Row count: 10000 + +First row: +--------------------------------------------------------- +|"ID" |"INT_COL" |"BIGINT_COL" |"VARCHAR_COL" |... | +--------------------------------------------------------- +|1 |83811 |478163328 |varchar_0_... |... | +--------------------------------------------------------- + +✓ Cleaned up target table: TEST_MYSQL_LOCAL_1234567890 + +Test completed: success +``` + +### Test Summary (Matrix Mode) + +``` +================================================================================ +TEST SUMMARY +================================================================================ + Status DBMS Method Source Value Time +-------------------------------------------------------------------------------- + ✓ mysql local table DBAPI_TEST_TABLE 12.34s + ✓ mysql udtf table DBAPI_TEST_TABLE 15.67s + ✓ postgres local table DBAPI_TEST_TABLE 11.23s + ✓ postgres udtf table DBAPI_TEST_TABLE 14.56s + ✓ mysql local query SELECT * FROM DBAPI... 10.89s + ✓ mysql udtf query SELECT * FROM DBAPI... 13.45s +-------------------------------------------------------------------------------- +Total: 6 | Success: 6 | Failed: 0 +``` + +## Prerequisites + +### Python Packages + +```bash +pip install python-dotenv # For .env support +pip install pymysql # MySQL +pip install psycopg2-binary # PostgreSQL +pip install pyodbc # SQL Server +pip install oracledb # Oracle +pip install databricks-sql-connector # Databricks +pip install pyspark # For PySpark ingestion method +``` + +**Note for PySpark**: You also need the Snowflake-Spark connector and Snowflake JDBC driver: +- **Snowflake Spark Connector**: `spark-snowflake_2.13-3.1.0.jar` (Scala 2.13, recommended) + - Download: https://mvnrepository.com/artifact/net.snowflake/spark-snowflake_2.13/3.1.0 + - **Important**: Use version 3.1.0 - version 3.1.1+ has issues with Oracle BLOB types +- **Snowflake JDBC**: `snowflake-jdbc-3.19.0.jar` or later + - Download: https://mvnrepository.com/artifact/net.snowflake/snowflake-jdbc +- Place both JARs in the `drivers/` directory + +See `drivers/jdbc_drivers_readme.md` for more details and known issues. + +### Snowflake Setup + +For UDTF and stored procedure methods, create external access integrations: + +```sql +-- Example for MySQL +CREATE OR REPLACE NETWORK RULE snowpark_dbapi_mysql_network_rule + MODE = EGRESS + TYPE = HOST_PORT + VALUE_LIST = ('your-mysql-host:3306'); + +CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION snowpark_dbapi_mysql_test_integration + ALLOWED_NETWORK_RULES = (snowpark_dbapi_mysql_network_rule) + ENABLED = TRUE; +``` + +Repeat for each DBMS you want to test with UDTF/stored procedure methods. + +### JDBC Setup (for jdbc and jdbc_sproc methods) + +#### 1. Download JDBC Drivers + +Download the required JDBC driver JARs and place them in the `drivers/` directory: + +```bash +cd tests/perf/data_source/dbapi_test_framework +mkdir -p drivers +# Download drivers from official sources (see drivers/README.md) +``` + +See `drivers/README.md` for download links and specific versions. + +#### 2. Create Snowflake Secrets + +JDBC methods require Snowflake secrets to store database credentials: + +```sql +-- Example for MySQL +CREATE OR REPLACE SECRET snowpark_dbapi_mysql_test_cred + TYPE = USERNAME_PASSWORD + USERNAME = 'your_mysql_user' + PASSWORD = 'your_mysql_password'; + +-- Grant usage to the role +GRANT USAGE ON SECRET snowpark_dbapi_mysql_test_cred TO ROLE your_role; +``` + +Repeat for each database you want to test with JDBC. + +**Secret naming convention:** +- MySQL: `ADMIN.PUBLIC.SNOWPARK_DBAPI_MYSQL_TEST_CRED` +- PostgreSQL: `ADMIN.PUBLIC.SNOWPARK_DBAPI_POSTGRES_TEST_CRED` +- SQL Server: `ADMIN.PUBLIC.SNOWPARK_DBAPI_SQL_SERVER_TEST_CRED` +- Oracle: `ADMIN.PUBLIC.SNOWPARK_DBAPI_ORACLEDB_TEST_CRED` +- Databricks: `ADMIN.PUBLIC.SNOWPARK_DBAPI_DATABRICKS_TEST_CRED` + +You can customize secret names via environment variables (e.g., `MYSQL_SECRET`). + +#### 3. Update External Access Integration (if needed) + +The JDBC methods use the same external access integrations as DBAPI UDTF methods. No additional setup required if you already have UDTF working. + +#### 4. Automatic Driver Upload + +When you run JDBC tests: +- The framework automatically checks if the driver is on the stage +- If not found, it uploads from your local `drivers/` directory +- Subsequent runs skip the upload (fast) +- Each test only uploads the specific driver it needs + +### PySpark Setup (for pyspark method) + +#### 1. JDBC Drivers + +PySpark uses the same JDBC drivers as the JDBC methods. Ensure drivers are in the `drivers/` directory (see JDBC Setup above). + +#### 2. Snowflake-Spark Connector + +Download the Snowflake-Spark connector JAR: +- Maven Central: `net.snowflake:spark-snowflake_2.12` or `spark-snowflake_2.13` +- Place in `drivers/` directory alongside JDBC drivers +- Or configure `spark.jars` in `config.PYSPARK_SESSION_CONFIG` + +#### 3. Configuration + +Customize PySpark session settings in `config.py`: + +```python +PYSPARK_SESSION_CONFIG = { + "spark.master": "local[*]", + "spark.driver.extraClassPath": "./drivers/*", + # Optional: Tune for your machine + "spark.sql.shuffle.partitions": 16, + "spark.default.parallelism": 16, + "spark.executor.cores": 8, + "spark.executor.memory": "16g", +} +``` + +#### 4. Credentials + +PySpark uses plain credentials from `.env` (not Snowflake secrets): +- Reads directly from `MYSQL_USERNAME`, `MYSQL_PASSWORD`, etc. +- No external access integration required +- Runs entirely on local machine + +## Database Setup Utilities + +The `db_setup_util/` directory contains scripts to create identical test tables across all databases. + +### Features +- Same table name: `DBAPI_TEST_TABLE` +- Same 15 columns with compatible types +- Deterministic data (row #N is identical across all DBMS) +- 10,000 rows by default (configurable) + +### Usage + +See `db_setup_util/README.md` for detailed documentation. + +## Troubleshooting + +### Connection Issues +- Verify `.env` file exists and has correct credentials +- Check that `python-dotenv` is installed +- Test connection to source databases independently + +### Permission Issues +- Ensure Snowflake user has CREATE TABLE privileges +- Ensure source database users have SELECT privileges +- For stored procedures, ensure external access integrations are granted + +### Package Issues +- Install all required Python packages +- For SQL Server, ensure ODBC driver is installed (`msodbcsql` or `ODBC Driver 18 for SQL Server`) +- Check package versions are compatible + +### Table Not Found +- Run the setup scripts in `db_setup_util/` first +- Verify table name matches config (`DBAPI_TEST_TABLE` by default) + +## Advanced Usage + +### Custom DBAPI Parameters + +```python +DBAPI_PARAMS = { + "fetch_size": 5000, + "max_workers": 8, +} +``` + +### Mixed Test Matrix + +```python +TEST_MATRIX = [ + # Table with local method + {"dbms": "mysql", "source": {"type": "table", "value": "DBAPI_TEST_TABLE"}, "ingestion_method": "local"}, + + # Query with UDTF + {"dbms": "postgres", "source": {"type": "query", "value": "SELECT * FROM DBAPI_TEST_TABLE WHERE id < 5000"}, "ingestion_method": "udtf"}, + + # Stored procedure + {"dbms": "oracle", "source": {"type": "table", "value": "DBAPI_TEST_TABLE"}, "ingestion_method": "local_sproc"}, +] +``` + +### Keep Tables for Debugging + +```python +# In config.py +CLEANUP_TARGET_TABLES = False # Tables persist after tests +SHOW_TARGET_TABLE_INFO = True # Show row count + first row +``` + +## Architecture Notes + +- **Lazy evaluation**: Connection parameters are loaded from `.env` via `config.py` +- **Clean imports**: Uses try/except pattern for both direct and module execution +- **Extensible**: Easy to add new DBMS or ingestion methods +- **DRY principle**: Common logic in base classes, DBMS-specific logic isolated +- **Configurable**: Runtime behavior controlled via config variables diff --git a/tests/perf/data_source/dbapi_test_framework/__init__.py b/tests/perf/data_source/dbapi_test_framework/__init__.py new file mode 100644 index 0000000000..c1a753cccc --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# diff --git a/tests/perf/data_source/dbapi_test_framework/config.py b/tests/perf/data_source/dbapi_test_framework/config.py new file mode 100644 index 0000000000..4edd9e5ae6 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/config.py @@ -0,0 +1,263 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Simple configuration for DBAPI ingestion tests. + +Define your test matrix here by specifying: +- dbms: which database to test +- table: existing table name to read from +- ingestion_method: 'local', 'udtf', 'local_sproc', 'udtf_sproc' +- fetch_size, max_workers: DBAPI parameters +""" + +import os + +# Load environment variables from .env file in the same directory if it exists and dotenv is installed +try: + from dotenv import load_dotenv + + load_dotenv() +except ImportError: + pass + +from db_setup_util.large_query_generation import get_large_query + +# Snowflake connection parameters +SNOWFLAKE_PARAMS = { + "account": os.getenv("SNOWFLAKE_ACCOUNT"), + "user": os.getenv("SNOWFLAKE_USER"), + "password": os.getenv("SNOWFLAKE_PASSWORD"), + "database": os.getenv("SNOWFLAKE_DATABASE"), + "schema": os.getenv("SNOWFLAKE_SCHEMA"), + "warehouse": os.getenv("SNOWFLAKE_WAREHOUSE"), + "role": os.getenv("SNOWFLAKE_ROLE"), + "host": os.getenv("SNOWFLAKE_HOST"), + "port": int(os.getenv("SNOWFLAKE_PORT", 443)), + "protocol": os.getenv("SNOWFLAKE_PROTOCOL", "https"), +} + +# Source database connection parameters +MYSQL_PARAMS = { + "host": os.getenv("MYSQL_HOST"), + "port": int(os.getenv("MYSQL_PORT", 3306)), + "user": os.getenv("MYSQL_USERNAME"), # Connection function expects 'user' + "password": os.getenv("MYSQL_PASSWORD"), + "database": os.getenv("MYSQL_DATABASE"), +} + +POSTGRES_PARAMS = { + "host": os.getenv("POSTGRES_HOST"), + "port": int(os.getenv("POSTGRES_PORT", 5432)), + "user": os.getenv("POSTGRES_USER"), + "password": os.getenv("POSTGRES_PASSWORD"), + "database": os.getenv("POSTGRES_DBNAME"), # Connection function expects 'database' +} + +MSSQL_PARAMS = { + "host": os.getenv("MSSQL_SERVER"), + "port": int(os.getenv("MSSQL_PORT", 1433)), + "user": os.getenv("MSSQL_UID"), + "password": os.getenv("MSSQL_PWD"), + "database": os.getenv("MSSQL_DATABASE", "test_db"), # Default to test_db + "driver": os.getenv("MSSQL_DRIVER", "{ODBC Driver 18 for SQL Server}"), +} + +ORACLE_PARAMS = { + "host": os.getenv("ORACLEDB_HOST"), + "port": int(os.getenv("ORACLEDB_PORT", 1521)), + "user": os.getenv("ORACLEDB_USERNAME"), # Connection function expects 'user' + "password": os.getenv("ORACLEDB_PASSWORD"), + "service_name": os.getenv("ORACLEDB_SERVICE_NAME"), +} + +DATABRICKS_PARAMS = { + "server_hostname": os.getenv("DATABRICKS_SERVER_HOSTNAME"), + "http_path": os.getenv("DATABRICKS_HTTP_PATH"), + "access_token": os.getenv("DATABRICKS_ACCESS_TOKEN"), +} + +# DBAPI ingestion parameters +DBAPI_PARAMS = {} # CHANGE ME TO RUN THE DBAPI PARAMETERS YOU WANT +DBAPI_PARAMS_WITH_PARTITION = ( + { # CHANGE ME TO RUN THE DBAPI PARAMETERS WITH PARTITION YOU WANT + "column": "id", + "lower_bound": 1000, + "upper_bound": 9000, + "num_partitions": 10, + } +) + +# Cleanup configuration +# Set to False to keep target tables for debugging +CLEANUP_TARGET_TABLES = True + +# Show target table info before cleanup (first row + count) +# Set to False to skip showing table info +SHOW_TARGET_TABLE_INFO = True + +# Export test results to CSV file +# Set to False to skip CSV export +EXPORT_RESULTS_TO_CSV = True + +# Package requirements for stored procedures by DBMS type +SPROC_PACKAGES = { + "mysql": ["pymysql"], + "postgres": ["psycopg2"], + "mssql": ["pyodbc", "msodbcsql"], + "oracle": ["oracledb"], + "databricks": ["databricks-sql-connector"], +} + +# JDBC driver JAR filenames (place these in drivers/ directory) +JDBC_DRIVER_JARS = {} # CHANGE ME TO RUN THE JDBC DRIVER JARS YOU WANT +# *** EXAMPLE *** +# JDBC_DRIVER_JARS = { +# "mysql": "mysql-connector-j-9.1.0.jar", +# "postgres": "postgresql-42.7.7.jar", +# "mssql": "mssql-jdbc-12.8.1.jre11.jar", +# "oracle": "ojdbc17-23.9.0.25.07.jar", +# "databricks": "DatabricksJDBC42-2.6.40.jar", +# } + + +# Snowflake secrets containing DB credentials (format: schema.secret_name) +# Users should create these secrets with USERNAME_PASSWORD_SECRET type +JDBC_SECRETS = { + "mysql": os.getenv("MYSQL_SECRET", "ADMIN.PUBLIC.SNOWPARK_DBAPI_MYSQL_TEST_CRED"), + "postgres": os.getenv( + "POSTGRES_SECRET", "ADMIN.PUBLIC.SNOWPARK_DBAPI_POSTGRES_TEST_CRED" + ), + "mssql": os.getenv( + "MSSQL_SECRET", "ADMIN.PUBLIC.SNOWPARK_DBAPI_SQL_SERVER_TEST_CRED" + ), + "oracle": os.getenv( + "ORACLE_SECRET", "ADMIN.PUBLIC.SNOWPARK_DBAPI_ORACLEDB_TEST_CRED" + ), + "databricks": os.getenv( + "DATABRICKS_SECRET", "ADMIN.PUBLIC.SNOWPARK_DBAPI_DATABRICKS_TEST_CRED" + ), +} + +# JDBC connection properties (optional, per-database settings) +JDBC_PROPERTIES = { + "mysql": {"useSSL": "false"}, + "postgres": {"ssl": "false"}, + "mssql": {"trustServerCertificate": "true"}, + "oracle": {}, + "databricks": {}, +} + +# JDBC driver class names (for PySpark and other JDBC-based methods) +JDBC_DRIVER_CLASSES = { + "mysql": "com.mysql.cj.jdbc.Driver", + "postgres": "org.postgresql.Driver", + "mssql": "com.microsoft.sqlserver.jdbc.SQLServerDriver", + "oracle": "oracle.jdbc.driver.OracleDriver", + "databricks": "com.databricks.client.jdbc.Driver", +} + +# PySpark session configuration +# Note: PySpark uses plain credentials from .env (not Snowflake secrets) +# and runs on a local Spark session +PYSPARK_SESSION_CONFIG = { + "spark.master": "local[*]", # Use all available cores + "spark.driver.extraClassPath": str( + os.path.join(os.path.dirname(__file__), "drivers", "*") + ), # Path to JDBC JARs + # Optional: Uncomment and adjust for parallelism optimization + # "spark.sql.shuffle.partitions": 16, + # "spark.default.parallelism": 16, + # "spark.executor.cores": 8, + # "spark.executor.memory": "16g", +} + +# UDTF configuration (for udtf and udtf_sproc methods) +# Each DBMS needs its own external access integration +# Names match the existing test integrations in tests/resources/test_data_source_dir/ +UDTF_CONFIGS = { + "mysql": { + "external_access_integration": "snowpark_dbapi_mysql_test_integration", + }, + "postgres": { + "external_access_integration": "snowpark_dbapi_postgres_test_integration", + }, + "mssql": { + "external_access_integration": "snowpark_dbapi_sql_server_test_integration", + }, + "oracle": { + "external_access_integration": "snowpark_dbapi_oracledb_test_integration", + }, + "databricks": { + "external_access_integration": "snowpark_dbapi_databricks_test_integration", + }, +} + +# Test matrix - define which tests to run +# Each test config format: +# { +# "dbms": "mysql", +# "source": {"type": "table|query", "value": "..."}, +# "ingestion_method": "local|udtf|local_sproc|udtf_sproc|jdbc|jdbc_sproc|pyspark", +# "dbapi_params": {...}, +# } +DBMS_LIST = [ + "mysql", + "postgres", + "mssql", + "oracle", + "databricks", +] # CHANGE ME TO RUN THE DBMS YOU WANT, full list: mysql, postgres, mssql, oracle, databricks +METHODS = [ + "local", + "udtf", +] # CHANGE ME TO RUN THE METHODS YOU WANT, full list: local, udtf, local_sproc, udtf_sproc, jdbc, jdbc_sproc, pyspark + +# Generate test matrix: table-based and query-based tests +TEST_MATRIX = [ + # Table-based tests + *[ + { + "dbms": dbms, + "source": {"type": "table", "value": "DBAPI_TEST_TABLE"}, + "ingestion_method": method, + } + for dbms in DBMS_LIST + for method in METHODS + ], + # Query-based tests + *[ + { + "dbms": dbms, + "source": {"type": "query", "value": "SELECT * FROM DBAPI_TEST_TABLE"}, + "ingestion_method": method, + } + for dbms in DBMS_LIST + for method in METHODS + ], +] + +# large query test matrix +TEST_MATRIX_LARGE_QUERY = [ + { + "dbms": dbms, + "source": { + "type": "query", + "value": get_large_query(dbms, "100k"), + }, # other options: "1m", "10m", "100m", "1b", "10b" + "ingestion_method": method, + "dbapi_params": DBAPI_PARAMS_WITH_PARTITION, # CHANGE ME TO RUN THE DBAPI PARAMETERS WITH PARTITION YOU WANT + } + for dbms in DBMS_LIST + for method in METHODS +] + + +# Simple single test config (used by main.py if TEST_MATRIX is not used) +SINGLE_TEST_CONFIG = { + "dbms": "databricks", + "source": {"type": "query", "value": "SELECT * FROM DBAPI_TEST_TABLE"}, + "ingestion_method": "udtf", + "dbapi_params": DBAPI_PARAMS, # CHANGE ME TO SETTINGS YOU WANT +} diff --git a/tests/perf/data_source/dbapi_test_framework/connections.py b/tests/perf/data_source/dbapi_test_framework/connections.py new file mode 100644 index 0000000000..1951986fc1 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/connections.py @@ -0,0 +1,190 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Connection factory functions for different DBMS types. + +Each function returns a callable that creates a DBAPI 2.0 connection. +This callable is passed to session.read.dbapi(). +""" + + +def create_mysql_connection(host, port, user, password, database, **kwargs): + """Create MySQL connection factory.""" + import pymysql + + def _connect(): + return pymysql.connect( + host=host, + port=port, + user=user, + password=password, + database=database, + ) + + return _connect + + +def create_postgres_connection(host, port, user, password, database, **kwargs): + """Create PostgreSQL connection factory.""" + import psycopg2 + + def _connect(): + return psycopg2.connect( + host=host, + port=port, + user=user, + password=password, + database=database, + ) + + return _connect + + +def create_mssql_connection( + host, port, user, password, database, driver=None, **kwargs +): + """Create MS SQL Server connection factory.""" + import pyodbc + + driver = driver or "{ODBC Driver 18 for SQL Server}" + + def _connect(): + connection_string = ( + f"DRIVER={driver};" + f"SERVER={host},{port};" + f"DATABASE={database};" + f"UID={user};" + f"PWD={password};" + f"TrustServerCertificate=yes;" + ) + return pyodbc.connect(connection_string) + + return _connect + + +def create_oracle_connection(host, port, user, password, service_name, **kwargs): + """Create Oracle connection factory.""" + import oracledb + + def _connect(): + dsn = f"{host}:{port}/{service_name}" + return oracledb.connect( + user=user, + password=password, + dsn=dsn, + ) + + return _connect + + +def create_databricks_connection(server_hostname, http_path, access_token, **kwargs): + """Create Databricks connection factory.""" + from databricks import sql + + def _connect(): + return sql.connect( + server_hostname=server_hostname, + http_path=http_path, + access_token=access_token, + ) + + return _connect + + +# Registry mapping DBMS names to connection factory functions +CONNECTION_FACTORIES = { + "mysql": create_mysql_connection, + "postgres": create_postgres_connection, + "postgresql": create_postgres_connection, + "mssql": create_mssql_connection, + "sqlserver": create_mssql_connection, + "oracle": create_oracle_connection, + "databricks": create_databricks_connection, + "dbx": create_databricks_connection, +} + + +def get_connection_factory(dbms_type, params): + """ + Get connection factory for a given DBMS type. + + Args: + dbms_type: Type of DBMS (mysql, postgres, mssql, oracle, databricks) + params: Dict of connection parameters + + Returns: + Callable that creates a DBAPI connection + """ + dbms_type = dbms_type.lower() + + if dbms_type not in CONNECTION_FACTORIES: + raise ValueError( + f"Unknown DBMS type: {dbms_type}. " + f"Supported: {', '.join(CONNECTION_FACTORIES.keys())}" + ) + + factory_func = CONNECTION_FACTORIES[dbms_type] + return factory_func(**params) + + +def get_jdbc_url(dbms_type, params): + """ + Generate JDBC connection URL from database parameters. + + Args: + dbms_type: Type of DBMS (mysql, postgres, mssql, oracle, databricks) + params: Dict of connection parameters + + Returns: + JDBC connection URL string + """ + dbms_type = dbms_type.lower() + + if dbms_type == "mysql": + return f"jdbc:mysql://{params['host']}:{params['port']}/{params['database']}" + + elif dbms_type in ("postgres", "postgresql"): + return ( + f"jdbc:postgresql://{params['host']}:{params['port']}/{params['database']}" + ) + + elif dbms_type in ("mssql", "sqlserver"): + return f"jdbc:sqlserver://{params['host']}:{params['port']};databaseName={params['database']}" + + elif dbms_type == "oracle": + return f"jdbc:oracle:thin:@//{params['host']}:{params['port']}/{params['service_name']}" + + elif dbms_type in ("databricks", "dbx"): + return f"jdbc:databricks://{params['server_hostname']};httpPath={params['http_path']}" + + else: + raise ValueError(f"Unknown DBMS type for JDBC: {dbms_type}") + + +def get_jdbc_driver_class(dbms_type): + """ + Get JDBC driver class name for a given DBMS type. + + Args: + dbms_type: Type of DBMS (mysql, postgres, mssql, oracle, databricks) + + Returns: + JDBC driver class name string + """ + # Support both direct execution and module import + try: + from . import config + except ImportError: + import config + + dbms_type = dbms_type.lower() + + if dbms_type not in config.JDBC_DRIVER_CLASSES: + raise ValueError( + f"Unknown DBMS type for JDBC driver class: {dbms_type}. " + f"Supported: {', '.join(config.JDBC_DRIVER_CLASSES.keys())}" + ) + + return config.JDBC_DRIVER_CLASSES[dbms_type] diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/README.md b/tests/perf/data_source/dbapi_test_framework/db_setup_util/README.md new file mode 100644 index 0000000000..91b73077c1 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/README.md @@ -0,0 +1,200 @@ +# Database Setup Utilities + +This directory contains scripts to set up identical test tables across multiple database systems for DBAPI ingestion testing. + +## Overview + +All scripts create a table named `DBAPI_TEST_TABLE` with 15 identical columns and insert 10,000 rows of deterministic test data. + +### Table Schema + +| Column Name | Type | Description | +|----------------|-------------------|--------------------------------------| +| id | AUTO_INCREMENT PK | Primary key (auto-generated) | +| int_col | INTEGER | Integer values | +| bigint_col | BIGINT | Large integer values | +| smallint_col | SMALLINT | Small integer values | +| float_col | FLOAT/DOUBLE | Floating point values | +| decimal_col | DECIMAL(10,2) | Fixed precision decimal | +| boolean_col | BOOLEAN/BIT | True/False values | +| varchar_col | VARCHAR(100) | Variable length strings | +| char_col | CHAR(10) | Fixed length strings | +| text_col | TEXT/CLOB | Large text content | +| date_col | DATE | Date values | +| timestamp_col | TIMESTAMP | Timestamp values | +| binary_col | BINARY(16) | Binary data (16 bytes) | +| json_col | JSON/STRING | JSON formatted strings | +| uuid_col | VARCHAR(36) | UUID strings | + +## Prerequisites + +### Python Packages + +Install required database drivers: + +```bash +pip install pymysql # MySQL +pip install psycopg2-binary # PostgreSQL +pip install pyodbc # SQL Server +pip install oracledb # Oracle +pip install databricks-sql-connector # Databricks +pip install python-dotenv # For .env support +``` + +### Environment Setup + +Create a `.env` file in the `dbapi_test_framework/` directory with your database credentials: + +```env +# MySQL +MYSQL_HOST=your-mysql-host +MYSQL_PORT=3306 +MYSQL_USERNAME=your-username +MYSQL_PASSWORD=your-password +MYSQL_DATABASE=your-database + +# PostgreSQL +POSTGRES_HOST=your-postgres-host +POSTGRES_PORT=5432 +POSTGRES_USER=your-username +POSTGRES_PASSWORD=your-password +POSTGRES_DBNAME=your-database + +# SQL Server +MSSQL_SERVER=your-sqlserver-host +MSSQL_PORT=1433 +MSSQL_DATABASE=test_db +MSSQL_UID=your-username +MSSQL_PWD=your-password +MSSQL_DRIVER={ODBC Driver 18 for SQL Server} + +# Oracle +ORACLEDB_HOST=your-oracle-host +ORACLEDB_PORT=1521 +ORACLEDB_SERVICE_NAME=your-service-name +ORACLEDB_USERNAME=your-username +ORACLEDB_PASSWORD=your-password + +# Databricks +DATABRICKS_SERVER_HOSTNAME=your-workspace.databricks.net +DATABRICKS_HTTP_PATH=sql/protocolv1/o/... +DATABRICKS_ACCESS_TOKEN=your-access-token +``` + +## Architecture + +The setup utilities follow a clean, modular design: + +- **`common_schema.py`**: Defines the unified schema, type mappings for all DBMS, and deterministic data generator +- **`base_setup.py`**: Base `DatabaseSetup` class with common logic (create table, insert data, etc.) +- **Individual setup scripts** (mysql_setup.py, etc.): Thin wrappers that handle connection and delegate to base class +- **`setup_all.py`**: Orchestrator to run all setups + +This design eliminates code duplication and uses proper Python imports (no `sys.path` manipulation). + +## Usage + +### Setup Individual Database + +Run as Python modules from the project root: + +```bash +# From project root +python3 -m tests.perf.data_source.dbapi_test_framework.db_setup_util.mysql_setup +python3 -m tests.perf.data_source.dbapi_test_framework.db_setup_util.postgres_setup +python3 -m tests.perf.data_source.dbapi_test_framework.db_setup_util.oracle_setup +python3 -m tests.perf.data_source.dbapi_test_framework.db_setup_util.mssql_setup +python3 -m tests.perf.data_source.dbapi_test_framework.db_setup_util.databricks_setup +``` + +Or from the db_setup_util directory: + +```bash +cd tests/perf/data_source/dbapi_test_framework/db_setup_util +python3 mysql_setup.py +# etc. +``` + +### Setup All Databases + +Run all setups at once: + +```bash +python3 -m tests.perf.data_source.dbapi_test_framework.db_setup_util.setup_all +``` + +This will attempt to set up all databases and provide a summary report. + +## Data Determinism + +All scripts use the same seed (42) to generate deterministic data. This means: +- Row #0 in MySQL has **exactly** the same values as row #0 in PostgreSQL, Oracle, etc. +- Row #1 in all databases has the same values +- And so on... + +This is crucial for testing data consistency across different DBAPI implementations. + +## Large Query Generation + +For performance testing with larger datasets, see the `large_query_generation/` subdirectory. + +These utilities generate SQL queries that multiply the 10k base table by factor k without storing additional data: +- k=100 → 1M rows +- k=1000 → 10M rows +- k=100000 → 1B rows + +Pre-built templates available for: 100k, 1M, 10M, 100M, 1B, 10B rows. + +See `large_query_generation/README.md` for details. + +## Customization + +### Change Number of Rows + +Edit the script or modify `DEFAULT_ROWS` in `common_schema.py`: + +```python +from db_setup_util.common_schema import DEFAULT_ROWS + +# In individual script +insert_data(conn, num_rows=50000) # Insert 50k rows instead +``` + +### Change Table Name + +Modify `TABLE_NAME` in `common_schema.py`: + +```python +TABLE_NAME = "MY_CUSTOM_TABLE" +``` + +## Verification + +After running setup, verify the table: + +```sql +-- Check row count +SELECT COUNT(*) FROM DBAPI_TEST_TABLE; +-- Should return: 10000 + +-- Sample data +SELECT * FROM DBAPI_TEST_TABLE LIMIT 5; +``` + +## Troubleshooting + +### Connection Issues + +- Verify `.env` file exists and has correct credentials +- Check network connectivity to database servers +- Ensure database drivers are installed + +### Permission Issues + +- Ensure user has CREATE TABLE and INSERT privileges +- For Oracle, may need additional tablespace permissions + +### Data Type Issues + +- Some databases (Oracle) use alternative types (NUMBER for boolean) +- Scripts handle these conversions automatically diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/__init__.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/__init__.py new file mode 100644 index 0000000000..22d9f26003 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/__init__.py @@ -0,0 +1,23 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""Database setup utilities for DBAPI testing.""" + +from .common_schema import ( + TABLE_NAME, + COLUMN_NAMES, + generate_row_data, + DEFAULT_ROWS, + TYPE_MAPPINGS, +) +from .base_setup import DatabaseSetup + +__all__ = [ + "TABLE_NAME", + "COLUMN_NAMES", + "generate_row_data", + "DEFAULT_ROWS", + "TYPE_MAPPINGS", + "DatabaseSetup", +] diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/base_setup.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/base_setup.py new file mode 100644 index 0000000000..c1ad503e52 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/base_setup.py @@ -0,0 +1,127 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""Base class for database setup with common logic.""" + +# Support both direct execution and module import +try: + from .common_schema import ( + TABLE_NAME, + COLUMN_NAMES, + TYPE_MAPPINGS, + generate_row_data, + DEFAULT_ROWS, + ) +except ImportError: + from common_schema import ( + TABLE_NAME, + COLUMN_NAMES, + TYPE_MAPPINGS, + generate_row_data, + DEFAULT_ROWS, + ) + + +class DatabaseSetup: + """Base class handling common setup logic for all DBMS.""" + + def __init__(self, connection, dbms_type, placeholder_style="?") -> None: + """ + Initialize database setup. + + Args: + connection: DBAPI 2.0 connection object + dbms_type: Type of DBMS (mysql, postgres, oracle, mssql, databricks) + placeholder_style: Placeholder style ('?', '%s', or ':1' for positional) + """ + self.connection = connection + self.dbms_type = dbms_type.lower() + self.placeholder_style = placeholder_style + self.type_mapping = TYPE_MAPPINGS[self.dbms_type] + + def generate_create_table_sql(self, table_name=TABLE_NAME): + """Generate CREATE TABLE SQL for this DBMS.""" + columns = ["id " + self.type_mapping["id"]] + for col_name in COLUMN_NAMES: + columns.append(f"{col_name} {self.type_mapping[col_name]}") + + columns_str = ",\n ".join(columns) + return f"CREATE TABLE {table_name} (\n {columns_str}\n)" + + def create_table(self, table_name=TABLE_NAME, drop_if_exists=True): + """Create test table.""" + cursor = self.connection.cursor() + + if drop_if_exists: + self._drop_table(cursor, table_name) + + create_sql = self.generate_create_table_sql(table_name) + cursor.execute(create_sql) + self.connection.commit() + print(f"Created table {table_name}") + cursor.close() + + def _drop_table(self, cursor, table_name): + """Drop table if exists (DBMS-specific syntax).""" + try: + if self.dbms_type == "oracle": + cursor.execute(f"DROP TABLE {table_name}") + else: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + print(f"Dropped existing table {table_name}") + except Exception: + pass # Table doesn't exist + + def prepare_row_data(self, row): + """Prepare row data for insertion (handle DBMS-specific conversions).""" + if self.dbms_type == "oracle": + # Convert boolean to 1/0 + row = list(row) + row[5] = 1 if row[5] else 0 # boolean_col is at index 5 + return tuple(row) + return row + + def generate_placeholders(self): + """Generate placeholder string based on DBMS style.""" + if self.placeholder_style == ":": + # Oracle style :1, :2, :3 + return ", ".join([f":{i+1}" for i in range(len(COLUMN_NAMES))]) + elif self.placeholder_style == "%s": + # PostgreSQL/MySQL style + return ", ".join(["%s"] * len(COLUMN_NAMES)) + else: + # Standard ? style (pyodbc) + return ", ".join(["?"] * len(COLUMN_NAMES)) + + def insert_data( + self, num_rows=DEFAULT_ROWS, table_name=TABLE_NAME, batch_size=1000 + ): + """Insert deterministic test data.""" + cursor = self.connection.cursor() + + columns = ", ".join(COLUMN_NAMES) + placeholders = self.generate_placeholders() + insert_sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})" + + # Insert in batches + for batch_start in range(0, num_rows, batch_size): + batch_end = min(batch_start + batch_size, num_rows) + batch_data = [ + self.prepare_row_data(generate_row_data(i)) + for i in range(batch_start, batch_end) + ] + cursor.executemany(insert_sql, batch_data) + self.connection.commit() + print(f"Inserted rows {batch_start} to {batch_end-1}") + + # Verify count + cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + count = cursor.fetchone()[0] + print(f"Total rows in {table_name}: {count}") + cursor.close() + + def run_setup(self, num_rows=DEFAULT_ROWS, table_name=TABLE_NAME): + """Run complete setup (create table + insert data).""" + self.create_table(table_name) + self.insert_data(num_rows, table_name) diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/common_schema.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/common_schema.py new file mode 100644 index 0000000000..4932a31217 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/common_schema.py @@ -0,0 +1,204 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Common schema and deterministic data generator for DBAPI test tables. +All DBMS will have the same table with identical data. +""" + +import random +import datetime +import uuid + +# Table configuration +TABLE_NAME = "DBAPI_TEST_TABLE" +DEFAULT_ROWS = 10_000 +SEED = 42 + +# 15 columns that all DBMS support (excluding auto-increment id) +COLUMN_NAMES = [ + "int_col", + "bigint_col", + "smallint_col", + "float_col", + "decimal_col", + "boolean_col", + "varchar_col", + "char_col", + "text_col", + "date_col", + "timestamp_col", + "binary_col", + "json_col", + "uuid_col", +] + +# Type mappings for each DBMS +TYPE_MAPPINGS = { + "mysql": { + "id": "INT AUTO_INCREMENT PRIMARY KEY", + "int_col": "INT", + "bigint_col": "BIGINT", + "smallint_col": "SMALLINT", + "float_col": "FLOAT", + "decimal_col": "DECIMAL(10, 2)", + "boolean_col": "BOOLEAN", + "varchar_col": "VARCHAR(100)", + "char_col": "CHAR(10)", + "text_col": "TEXT", + "date_col": "DATE", + "timestamp_col": "TIMESTAMP", + "binary_col": "BINARY(16)", + "json_col": "JSON", + "uuid_col": "VARCHAR(36)", + }, + "postgres": { + "id": "SERIAL PRIMARY KEY", + "int_col": "INTEGER", + "bigint_col": "BIGINT", + "smallint_col": "SMALLINT", + "float_col": "DOUBLE PRECISION", + "decimal_col": "DECIMAL(10, 2)", + "boolean_col": "BOOLEAN", + "varchar_col": "VARCHAR(100)", + "char_col": "CHAR(10)", + "text_col": "TEXT", + "date_col": "DATE", + "timestamp_col": "TIMESTAMP", + "binary_col": "BYTEA", + "json_col": "JSONB", + "uuid_col": "VARCHAR(36)", + }, + "oracle": { + "id": "NUMBER GENERATED AS IDENTITY PRIMARY KEY", + "int_col": "NUMBER(10)", + "bigint_col": "NUMBER(19)", + "smallint_col": "NUMBER(5)", + "float_col": "BINARY_DOUBLE", + "decimal_col": "NUMBER(10, 2)", + "boolean_col": "NUMBER(1)", + "varchar_col": "VARCHAR2(100)", + "char_col": "CHAR(10)", + "text_col": "CLOB", + "date_col": "DATE", + "timestamp_col": "TIMESTAMP", + "binary_col": "RAW(16)", + "json_col": "CLOB", + "uuid_col": "VARCHAR2(36)", + }, + "mssql": { + "id": "INT IDENTITY(1,1) PRIMARY KEY", + "int_col": "INT", + "bigint_col": "BIGINT", + "smallint_col": "SMALLINT", + "float_col": "FLOAT", + "decimal_col": "DECIMAL(10, 2)", + "boolean_col": "BIT", + "varchar_col": "VARCHAR(100)", + "char_col": "CHAR(10)", + "text_col": "TEXT", + "date_col": "DATE", + "timestamp_col": "DATETIME2", + "binary_col": "BINARY(16)", + "json_col": "NVARCHAR(MAX)", + "uuid_col": "VARCHAR(36)", + }, + "databricks": { + "id": "BIGINT GENERATED ALWAYS AS IDENTITY", # No PRIMARY KEY (requires Unity Catalog) + "int_col": "INT", + "bigint_col": "BIGINT", + "smallint_col": "SMALLINT", + "float_col": "DOUBLE", + "decimal_col": "DECIMAL(10, 2)", + "boolean_col": "BOOLEAN", + "varchar_col": "STRING", + "char_col": "STRING", + "text_col": "STRING", + "date_col": "DATE", + "timestamp_col": "TIMESTAMP", + "binary_col": "BINARY", + "json_col": "STRING", + "uuid_col": "STRING", + }, +} + + +def generate_row_data(row_num): + """ + Generate deterministic test data for a given row number. + Same row_num always produces same values across all runs. + + Args: + row_num: Row number (0-based) + + Returns: + Tuple of 14 values (excluding auto-increment id) + """ + # Set seed based on row number for deterministic data + rng = random.Random(SEED + row_num) + + # Generate exact same data for this row_num every time + int_col = rng.randint(1, 100000) + bigint_col = rng.randint(1, 9999999999) + smallint_col = rng.randint(1, 30000) + float_col = round(rng.uniform(1.0, 10000.0), 4) + decimal_col = round(rng.uniform(1.0, 10000.0), 2) + boolean_col = rng.choice([True, False]) + + # String data + varchar_col = f"varchar_{row_num}_{rng.randint(1000, 9999)}" + char_col = f"char{row_num % 1000}".ljust(10)[:10] + text_col = f"Text content for row {row_num}. " + "Lorem ipsum " * rng.randint(5, 15) + + # Date/time data + date_col = datetime.date(2024, (row_num % 12) + 1, (row_num % 28) + 1) + timestamp_col = datetime.datetime( + 2024, + (row_num % 12) + 1, + (row_num % 28) + 1, + row_num % 24, + row_num % 60, + row_num % 60, + ) + + # Binary data - deterministic 16 bytes + binary_col = bytes(((row_num + i) * 17) % 256 for i in range(16)) + + # JSON data as string (some DBMS need string, others support JSON) + json_col = f'{{"row": {row_num}, "value": {rng.randint(1, 1000)}, "key": "test_{row_num}"}}' + + # UUID - deterministic based on row_num + uuid_bytes = bytes(((row_num + i) * 13) % 256 for i in range(16)) + uuid_col = str(uuid.UUID(bytes=uuid_bytes)) + + return ( + int_col, + bigint_col, + smallint_col, + float_col, + decimal_col, + boolean_col, + varchar_col, + char_col, + text_col, + date_col, + timestamp_col, + binary_col, + json_col, + uuid_col, + ) + + +def print_sample_data(): + """Print sample data for verification.""" + print(f"Sample data for {TABLE_NAME}:") + print(f"Columns: {', '.join(COLUMN_NAMES)}") + print("\nFirst 3 rows:") + for i in range(3): + data = generate_row_data(i) + print(f"Row {i}: {data[:5]}...") # Print first 5 values + + +if __name__ == "__main__": + print_sample_data() diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/databricks_setup.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/databricks_setup.py new file mode 100644 index 0000000000..686faba16b --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/databricks_setup.py @@ -0,0 +1,103 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""Databricks test table setup.""" + +from databricks import sql + +# Support both direct execution and module import +try: + from .base_setup import DatabaseSetup + from .common_schema import TABLE_NAME, COLUMN_NAMES, generate_row_data, DEFAULT_ROWS +except ImportError: + from base_setup import DatabaseSetup + from common_schema import TABLE_NAME, COLUMN_NAMES, generate_row_data, DEFAULT_ROWS + + +class DatabricksSetup(DatabaseSetup): + """Databricks-specific setup with custom insert logic.""" + + def insert_data( + self, num_rows=DEFAULT_ROWS, table_name=TABLE_NAME, batch_size=1000 + ): + """Insert data using multi-row INSERT (Databricks doesn't support executemany well).""" + cursor = self.connection.cursor() + columns = ", ".join(COLUMN_NAMES) + + for batch_start in range(0, num_rows, batch_size): + batch_end = min(batch_start + batch_size, num_rows) + + # Build multi-row INSERT + values_list = [] + for i in range(batch_start, batch_end): + row = generate_row_data(i) + values = [] + for val in row: + if val is None: + values.append("NULL") + elif isinstance(val, bool): + values.append("TRUE" if val else "FALSE") + elif isinstance(val, (int, float)): + values.append(str(val)) + elif isinstance(val, bytes): + values.append(f"X'{val.hex()}'") + elif isinstance(val, str): + escaped = val.replace("'", "''") + values.append(f"'{escaped}'") + else: + values.append(f"'{str(val)}'") + values_list.append(f"({', '.join(values)})") + + insert_sql = ( + f"INSERT INTO {table_name} ({columns}) VALUES {', '.join(values_list)}" + ) + cursor.execute(insert_sql) + print(f"Inserted rows {batch_start} to {batch_end-1}") + + # Verify count + cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + count = cursor.fetchone()[0] + print(f"Total rows in {table_name}: {count}") + cursor.close() + + +def get_connection(params): + """Create Databricks connection.""" + return sql.connect( + server_hostname=params["server_hostname"], + http_path=params["http_path"], + access_token=params["access_token"], + ) + + +def main(params=None): + """Main setup function.""" + if params is None: + try: + from ..config import DATABRICKS_PARAMS as params + except ImportError: + import sys + from pathlib import Path + + sys.path.insert(0, str(Path(__file__).parent.parent)) + from config import DATABRICKS_PARAMS as params + + print("=" * 60) + print("Databricks Database Setup") + print("=" * 60) + + print(f"Connecting to Databricks at {params['server_hostname']}...") + conn = get_connection(params) + print("Connected!") + + # Use Databricks-specific setup + setup = DatabricksSetup(conn, dbms_type="databricks") + setup.run_setup(num_rows=DEFAULT_ROWS) + + conn.close() + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/README.md b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/README.md new file mode 100644 index 0000000000..3bfb305eef --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/README.md @@ -0,0 +1,152 @@ +# Large Query Generation + +Utilities for generating and validating large result sets from small tables for performance testing. + +## Overview + +These utilities help you generate SQL queries that multiply a 10k row table by factor `k` to create large result sets for testing, without actually storing large amounts of data. + +## Files + +- **`generate_large_queries.py`** - Core query generator with DBMS-specific syntax +- **`query_templates.py`** - Pre-built templates for common sizes (100k, 1M, 10M, 100M, 1B, 10B) +- **`validate_queries.py`** - Validation tool to verify queries work correctly +- **`__init__.py`** - Package exports + +## Quick Start + +### Generate a Query + +```python +from large_query_generation import generate_large_query + +# Generate 1M row query for PostgreSQL +query = generate_large_query("postgres", "DBAPI_TEST_TABLE", k=100) + +# Use in test config +test_config = { + "dbms": "postgres", + "source": {"type": "query", "value": query}, + "ingestion_method": "local" +} +``` + +### Use Pre-built Templates + +```python +from large_query_generation import get_large_query + +# Get 1M row template +query_1m = get_large_query("postgres", "1m") + +# Available sizes: "100k", "1m", "10m", "100m", "1b", "10b" +``` + +### Validate Queries + +```bash +# Validate all DBMS with k=10 (100k rows each) +python3 validate_queries.py -k 10 + +# Validate specific DBMS +python3 validate_queries.py --dbms postgres -k 100 +``` + +## How It Works + +Each DBMS uses optimized syntax to generate a number sequence and CROSS JOIN with the base table: + +| DBMS | Method | Query Pattern | +|------|--------|---------------| +| MySQL | Recursive CTE | `WITH RECURSIVE numbers AS (...) SELECT t.* FROM table t CROSS JOIN numbers` | +| PostgreSQL | generate_series() | `SELECT t.* FROM table t CROSS JOIN generate_series(1, k)` | +| SQL Server | VALUES + Factorization | `SELECT t.* FROM table t CROSS JOIN (VALUES ...) × CROSS JOIN (VALUES ...)` | +| Oracle | CONNECT BY | `SELECT t.* FROM table t CROSS JOIN (SELECT LEVEL FROM DUAL CONNECT BY LEVEL <= k)` | +| Databricks | explode(sequence) | `SELECT t.* FROM table t CROSS JOIN (SELECT explode(sequence(1, k)))` | + +**Note:** SQL Server uses smart factorization (e.g., k=1000 = 100×10) to avoid CTEs, making queries subquery-compatible. + +## Multiplication Factors + +| k Value | Result Size | Template | Use Case | +|---------|-------------|----------|----------| +| 10 | 100,000 | "100k" | Quick test | +| 100 | 1,000,000 | "1m" | Small test | +| 1,000 | 10,000,000 | "10m" | Medium test | +| 10,000 | 100,000,000 | "100m" | Large test | +| 100,000 | 1,000,000,000 | "1b" | Stress test | +| 1,000,000 | 10,000,000,000 | "10b" | Extreme test | + +## Examples + +### Example 1: Dynamic Generation + +```python +from db_setup_util.large_query_generation import generate_large_query + +# Different sizes for different DBMS +TEST_MATRIX = [ + { + "dbms": "mysql", + "source": { + "type": "query", + "value": generate_large_query("mysql", k=100) # 1M rows + }, + "ingestion_method": "local" + }, + { + "dbms": "postgres", + "source": { + "type": "query", + "value": generate_large_query("postgres", k=1000) # 10M rows + }, + "ingestion_method": "udtf" + }, +] +``` + +### Example 2: Using Templates + +```python +from db_setup_util.large_query_generation import get_large_query + +# Simple template usage +TEST_MATRIX = [ + {"dbms": "postgres", "source": {"type": "query", "value": get_large_query("postgres", "1m")}, "ingestion_method": "local"}, + {"dbms": "mysql", "source": {"type": "query", "value": get_large_query("mysql", "10m")}, "ingestion_method": "udtf"}, +] +``` + +## Validation + +The validation script checks: +- ✅ Query executes without errors +- ✅ Returns expected schema (same columns as base table) +- ✅ Can be wrapped in subquery (SELECT * FROM (query) AS subquery) +- ✅ Expected row count is calculated (10k × k) + +```bash +python3 validate_queries.py -k 100 +``` + +Output: +``` +VALIDATION SUMMARY +====================================================================== +DBMS Expected Syntax Columns Subquery Status +---------------------------------------------------------------------- +mysql 1,000,000 ✓ ✓ ✓ ✓ PASS +postgres 1,000,000 ✓ ✓ ✓ ✓ PASS +mssql 1,000,000 ✓ ✓ ✓ ✓ PASS +oracle 1,000,000 ✓ ✓ ✓ ✓ PASS +databricks 1,000,000 ✓ ✓ ✓ ✓ PASS +---------------------------------------------------------------------- +Overall: ✓ ALL PASSED +``` + +## Notes + +- **Efficient**: Database generates rows on-the-fly, no storage needed +- **Deterministic**: Same base table = same multiplied results +- **Scalable**: Can generate up to 100M+ rows from 10k base +- **Flexible**: Customize k value for any test size diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/__init__.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/__init__.py new file mode 100644 index 0000000000..783f022379 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""Large query generation utilities for performance testing.""" + +from .generate_large_queries import generate_large_query, print_all_queries +from .query_templates import get_large_query, LARGE_QUERY_TEMPLATES + +__all__ = [ + "generate_large_query", + "print_all_queries", + "get_large_query", + "LARGE_QUERY_TEMPLATES", +] diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/generate_large_queries.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/generate_large_queries.py new file mode 100644 index 0000000000..e936fe29c1 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/generate_large_queries.py @@ -0,0 +1,169 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Generate SQL queries that multiply 10k rows by k factor for performance testing. +Each query returns the same 10k rows repeated k times. +""" + + +def generate_large_query(dbms, table_name="DBAPI_TEST_TABLE", k=100): + """ + Generate a SQL query that returns 10k rows repeated k times. + + Args: + dbms: Database type (mysql, postgres, mssql, oracle, databricks) + table_name: Base table name + k: Multiplication factor (1 = 10k rows, 100 = 1M rows, 1000 = 10M rows) + + Returns: + SQL query string that returns 10k * k rows + """ + dbms = dbms.lower() + + if dbms == "mysql": + # MySQL: Use recursive CTE to generate numbers, then CROSS JOIN + query = f""" +WITH RECURSIVE numbers AS ( + SELECT 1 AS n + UNION ALL + SELECT n + 1 FROM numbers WHERE n < {k} +) +SELECT t.* +FROM {table_name} t +CROSS JOIN numbers +""" + + elif dbms == "postgres" or dbms == "postgresql": + # PostgreSQL: Use generate_series() + query = f""" +SELECT t.* +FROM {table_name} t +CROSS JOIN generate_series(1, {k}) AS multiplier(n) +""" + + elif dbms == "mssql" or dbms == "sqlserver": + # SQL Server: Use VALUES with CROSS JOIN for multiplication (works in subqueries, no CTE) + # Decompose k into smaller factors to avoid huge VALUES clauses + # e.g., k=100 = 10 × 10, k=1000 = 10 × 10 × 10, k=10000 = 100 × 100 + # Decompose k into factors (prefer 10s and 100s for readability) + factors = [] + remaining = k + + while remaining > 1: + if remaining % 100 == 0 and remaining >= 100: + factors.append(100) + remaining //= 100 + elif remaining % 10 == 0 and remaining >= 10: + factors.append(10) + remaining //= 10 + else: + # For non-round numbers, just use the remainder + factors.append(remaining) + remaining = 1 + + # If no factors or just 1, use simple VALUES + if not factors or (len(factors) == 1 and factors[0] <= 100): + factor = factors[0] if factors else k + values_list = ", ".join([f"({i})" for i in range(1, factor + 1)]) + query = f""" +SELECT t.* +FROM {table_name} t +CROSS JOIN (VALUES {values_list}) AS multiplier(n) +""" + else: + # Generate CROSS JOINs + crosses = [] + for i, factor in enumerate(factors): + values_list = ", ".join([f"({j})" for j in range(1, factor + 1)]) + crosses.append(f"CROSS JOIN (VALUES {values_list}) AS n{i}(v{i})") + + query = f""" +SELECT t.* +FROM {table_name} t +{' '.join(crosses)} +""" + + elif dbms == "oracle": + # Oracle: Use CONNECT BY LEVEL (no alias for inline view in CROSS JOIN) + query = f""" +SELECT t.* +FROM {table_name} t +CROSS JOIN (SELECT LEVEL AS n FROM DUAL CONNECT BY LEVEL <= {k}) multiplier +""" + + elif dbms == "databricks" or dbms == "dbx": + # Databricks: Use explode with sequence + query = f""" +SELECT t.* +FROM {table_name} t +CROSS JOIN (SELECT explode(sequence(1, {k})) AS n) +""" + + else: + raise ValueError(f"Unsupported DBMS: {dbms}") + + return query.strip() + + +def print_all_queries(k=100): + """Print queries for all DBMS types.""" + dbms_list = ["mysql", "postgres", "mssql", "oracle", "databricks"] + + print("=" * 80) + print(f"LARGE QUERY GENERATION (k={k}, expected rows={10_000 * k:,})") + print("=" * 80) + + for dbms in dbms_list: + query = generate_large_query(dbms, k=k) + print(f"\n{dbms.upper()}:") + print("-" * 80) + print(query) + print("-" * 80) + + print(f"\n\n{'='*80}") + print("VALIDATION QUERIES") + print("=" * 80) + print("\nUse these to validate the multiplied query returns correct results:") + print("\n1. Row count should be: 10,000 * k") + print("2. Column names should match the base table") + print("3. Data should be duplicated k times") + print("\nBaseline query for all DBMS:") + print(" SELECT COUNT(*) FROM DBAPI_TEST_TABLE -- Should return 10,000") + + +if __name__ == "__main__": + # Example usage + print("\n" + "=" * 80) + print("EXAMPLE: Small test (k=10, 100k rows)") + print("=" * 80) + print_all_queries(k=10) + + print("\n\n" + "=" * 80) + print("EXAMPLE: Medium test (k=100, 1M rows)") + print("=" * 80) + print_all_queries(k=100) + + print("\n\n" + "=" * 80) + print("USAGE IN CONFIG") + print("=" * 80) + print( + """ +# Example config for large query test: +{ + "dbms": "postgres", + "source": { + "type": "query", + "value": ''' + SELECT t.* + FROM DBAPI_TEST_TABLE t + CROSS JOIN generate_series(1, 100) AS multiplier(n) + ''' + }, + "ingestion_method": "local" +} + +# This will ingest 1M rows (10k * 100) from PostgreSQL +""" + ) diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/query_templates.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/query_templates.py new file mode 100644 index 0000000000..d8da861681 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/query_templates.py @@ -0,0 +1,134 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Query templates for generating large result sets for performance testing. + +Each query multiplies the base 10k rows by factor k using DBMS-specific syntax. +""" + +try: + from .generate_large_queries import generate_large_query +except ImportError: + from generate_large_queries import generate_large_query + + +# Pre-generated queries for common test sizes +LARGE_QUERY_TEMPLATES = { + # 100k rows (k=10) + "100k": { + "mysql": generate_large_query("mysql", k=10), + "postgres": generate_large_query("postgres", k=10), + "mssql": generate_large_query("mssql", k=10), + "oracle": generate_large_query("oracle", k=10), + "databricks": generate_large_query("databricks", k=10), + }, + # 1M rows (k=100) + "1m": { + "mysql": generate_large_query("mysql", k=100), + "postgres": generate_large_query("postgres", k=100), + "mssql": generate_large_query("mssql", k=100), + "oracle": generate_large_query("oracle", k=100), + "databricks": generate_large_query("databricks", k=100), + }, + # 10M rows (k=1000) + "10m": { + "mysql": generate_large_query("mysql", k=1000), + "postgres": generate_large_query("postgres", k=1000), + "mssql": generate_large_query("mssql", k=1000), + "oracle": generate_large_query("oracle", k=1000), + "databricks": generate_large_query("databricks", k=1000), + }, + # 100M rows (k=10000) + "100m": { + "mysql": generate_large_query("mysql", k=10000), + "postgres": generate_large_query("postgres", k=10000), + "mssql": generate_large_query("mssql", k=10000), + "oracle": generate_large_query("oracle", k=10000), + "databricks": generate_large_query("databricks", k=10000), + }, + # 1B rows (k=100000) + "1b": { + "mysql": generate_large_query("mysql", k=100000), + "postgres": generate_large_query("postgres", k=100000), + "mssql": generate_large_query("mssql", k=100000), + "oracle": generate_large_query("oracle", k=100000), + "databricks": generate_large_query("databricks", k=100000), + }, + # 10B rows (k=1000000) + "10b": { + "mysql": generate_large_query("mysql", k=1000000), + "postgres": generate_large_query("postgres", k=1000000), + "mssql": generate_large_query("mssql", k=1000000), + "oracle": generate_large_query("oracle", k=1000000), + "databricks": generate_large_query("databricks", k=1000000), + }, +} + + +def get_large_query(dbms, size="1m"): + """ + Get a pre-generated large query template. + + Args: + dbms: Database type (mysql, postgres, mssql, oracle, databricks) + size: Query size ('100k', '1m', '10m') + + Returns: + SQL query string + + Example: + query = get_large_query("postgres", "1m") + # Returns query that produces 1M rows from 10k base table + """ + size = size.lower() + dbms = dbms.lower() + + if size not in LARGE_QUERY_TEMPLATES: + raise ValueError( + f"Unknown size: {size}. Use '100k', '1m', '10m', '100m', '1b', or '10b'" + ) + + if dbms not in LARGE_QUERY_TEMPLATES[size]: + raise ValueError(f"Unknown DBMS: {dbms}") + + return LARGE_QUERY_TEMPLATES[size][dbms] + + +if __name__ == "__main__": + # Example usage + print("Available query templates:") + print("=" * 70) + row_counts = { + "100k": "100,000", + "1m": "1,000,000", + "10m": "10,000,000", + "100m": "100,000,000", + "1b": "1,000,000,000", + "10b": "10,000,000,000", + } + for size, queries in LARGE_QUERY_TEMPLATES.items(): + row_count = row_counts[size] + print(f"\n{size.upper()} ({row_count} rows):") + for dbms in queries.keys(): + print(f" - {dbms}") + + print("\n\nExample usage in config.py:") + print("=" * 70) + print( + """ +from query_templates import get_large_query + +# Get 1M row query for PostgreSQL +postgres_1m_query = get_large_query("postgres", "1m") + +TEST_MATRIX = [ + { + "dbms": "postgres", + "source": {"type": "query", "value": postgres_1m_query}, + "ingestion_method": "local" + } +] +""" + ) diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/validate_queries.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/validate_queries.py new file mode 100644 index 0000000000..79703eb555 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/large_query_generation/validate_queries.py @@ -0,0 +1,272 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Validation utility to verify large queries return expected results. +""" + +import sys +from pathlib import Path + +# Support both direct execution and module import +try: + from .generate_large_queries import generate_large_query + from ...config import ( + MYSQL_PARAMS, + POSTGRES_PARAMS, + MSSQL_PARAMS, + ORACLE_PARAMS, + DATABRICKS_PARAMS, + ) + from .. import ( + mysql_setup, + postgres_setup, + oracle_setup, + mssql_setup, + databricks_setup, + ) +except ImportError: + # Fallback for direct execution + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + from generate_large_queries import generate_large_query + from config import ( + MYSQL_PARAMS, + POSTGRES_PARAMS, + MSSQL_PARAMS, + ORACLE_PARAMS, + DATABRICKS_PARAMS, + ) + + sys.path.insert(0, str(Path(__file__).parent.parent)) + import mysql_setup + import postgres_setup + import oracle_setup + import mssql_setup + import databricks_setup + + +def validate_query(dbms, k=10, table_name="DBAPI_TEST_TABLE"): + """ + Validate that the generated query returns expected results. + + Args: + dbms: Database type + k: Multiplication factor + table_name: Table name + + Returns: + Dict with validation results + """ + print(f"\n{'='*70}") + print(f"VALIDATING {dbms.upper()} (k={k})") + print("=" * 70) + + # Get connection + params_map = { + "mysql": MYSQL_PARAMS, + "postgres": POSTGRES_PARAMS, + "mssql": MSSQL_PARAMS, + "oracle": ORACLE_PARAMS, + "databricks": DATABRICKS_PARAMS, + } + + setup_map = { + "mysql": mysql_setup, + "postgres": postgres_setup, + "mssql": mssql_setup, + "oracle": oracle_setup, + "databricks": databricks_setup, + } + + params = params_map[dbms.lower()] + setup = setup_map[dbms.lower()] + + try: + # Connect + conn = setup.get_connection(params) + cursor = conn.cursor() + + # 1. Get base table row count + cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + base_count = cursor.fetchone()[0] + print(f"✓ Base table row count: {base_count:,}") + + # 2. Get base table columns (DBMS-specific syntax) + if dbms.lower() == "oracle": + cursor.execute(f"SELECT * FROM {table_name} WHERE ROWNUM <= 1") + elif dbms.lower() in ("mssql", "sqlserver"): + cursor.execute(f"SELECT TOP 1 * FROM {table_name}") + else: + cursor.execute(f"SELECT * FROM {table_name} LIMIT 1") + base_columns = [desc[0] for desc in cursor.description] + print( + f"✓ Base table columns ({len(base_columns)}): {', '.join(base_columns[:5])}..." + ) + + # 3. Test the multiplied query + query = generate_large_query(dbms, table_name, k) + print(f"\nValidating multiplied query (k={k})...") + + # For validation, fetch small sample to get schema, then count mathematically + # This avoids CTE-in-subquery issues with SQL Server + if dbms.lower() == "oracle": + sample_query = f"SELECT * FROM ({query}) WHERE ROWNUM <= 10" + elif dbms.lower() in ("mssql", "sqlserver"): + # SQL Server: Avoid wrapping CTE in subquery, just run it with TOP + # Modify query to add TOP if it's a CTE + if "WITH" in query.upper(): + # Run original query, fetch 10 rows for schema validation + import re + + # Insert TOP 10 into the final SELECT + parts = re.split(r"\bSELECT\s+", query, flags=re.IGNORECASE) + if len(parts) > 1: + # Last SELECT is the main query + parts[-1] = f"TOP 10 {parts[-1]}" + sample_query = "SELECT ".join(parts) + else: + sample_query = query + else: + sample_query = f"SELECT TOP 10 * FROM ({query}) AS subquery" + else: + sample_query = f"SELECT * FROM ({query}) AS subquery LIMIT 10" + + cursor.execute(sample_query) + sample_rows = cursor.fetchall() + query_columns = [desc[0] for desc in cursor.description] + print("✓ Query executed successfully") + print(f"✓ Sample rows fetched: {len(sample_rows)}") + print( + f"✓ Query columns ({len(query_columns)}): {', '.join(query_columns[:5])}..." + ) + + # Additional validation: Verify query can be wrapped in subquery + print("\nValidating query can be wrapped in subquery...") + try: + if dbms.lower() == "oracle": + wrapper_query = f"SELECT * FROM ({query}) WHERE ROWNUM <= 1" + elif dbms.lower() in ("mssql", "sqlserver"): + wrapper_query = f"SELECT TOP 1 * FROM ({query}) AS subquery" + else: + wrapper_query = f"SELECT * FROM ({query}) AS subquery LIMIT 1" + + cursor.execute(wrapper_query) + cursor.fetchone() + print("✓ Query is subquery-compatible") + subquery_compatible = True + except Exception as wrap_error: + print(f"✗ Query cannot be wrapped in subquery: {wrap_error}") + subquery_compatible = False + + # Calculate expected count (since we can't always COUNT(*) CTEs in subqueries) + expected_count = base_count * k + print( + f"✓ Expected row count: {expected_count:,} (calculated: {base_count:,} × {k})" + ) + + # 4. Validate results + columns_match = query_columns == base_columns + count_matches = True # We validated the query runs correctly + all_checks_pass = count_matches and columns_match and subquery_compatible + + print(f"\n{'='*70}") + print("VALIDATION RESULTS") + print("=" * 70) + print(f"Expected rows: {expected_count:,} ({base_count:,} × {k})") + print(f"Query syntax: {'✓ PASS' if count_matches else '✗ FAIL'}") + print(f"Columns match: {'✓ PASS' if columns_match else '✗ FAIL'}") + print(f"Subquery compatible: {'✓ PASS' if subquery_compatible else '✗ FAIL'}") + + if not columns_match: + print("\nColumn mismatch:") + print(f" Base: {base_columns}") + print(f" Query: {query_columns}") + + conn.close() + + return { + "dbms": dbms, + "base_count": base_count, + "expected_count": expected_count, + "query_syntax_valid": count_matches, + "columns_match": columns_match, + "subquery_compatible": subquery_compatible, + "success": all_checks_pass, + } + + except Exception as e: + print(f"✗ ERROR: {e}") + return { + "dbms": dbms, + "success": False, + "error": str(e), + } + + +def validate_all(k=10): + """Validate queries for all DBMS.""" + dbms_list = ["mysql", "postgres", "mssql", "oracle", "databricks"] + results = [] + + print("\n" + "=" * 70) + print(f"VALIDATING ALL DBMS (k={k}, expecting {10_000 * k:,} rows)") + print("=" * 70) + + for dbms in dbms_list: + result = validate_query(dbms, k=k) + results.append(result) + + # Print summary + print(f"\n\n{'='*70}") + print("VALIDATION SUMMARY") + print("=" * 70) + print( + f"{'DBMS':12s} {'Expected':>12s} {'Syntax':>8s} {'Columns':>8s} {'Status':>10s}" + ) + print("-" * 70) + + for result in results: + if result.get("success"): + status = "✓ PASS" + expected = f"{result['expected_count']:,}" + syntax_ok = "✓" if result.get("query_syntax_valid", True) else "✗" + cols_ok = "✓" if result.get("columns_match", True) else "✗" + else: + status = "✗ FAIL" + expected = "N/A" + syntax_ok = "✗" + cols_ok = "✗" + + print( + f"{result['dbms']:12s} {expected:>12s} {syntax_ok:>8s} {cols_ok:>8s} {status:>10s}" + ) + + print("-" * 70) + all_pass = all(r.get("success", False) for r in results) + print(f"\nOverall: {'✓ ALL PASSED' if all_pass else '✗ SOME FAILED'}") + + return results + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Validate large query generation") + parser.add_argument( + "-k", + "--multiplier", + type=int, + default=10, + help="Multiplication factor (default: 10)", + ) + parser.add_argument( + "--dbms", type=str, help="Validate specific DBMS only (mysql, postgres, etc.)" + ) + + args = parser.parse_args() + + if args.dbms: + validate_query(args.dbms, k=args.multiplier) + else: + validate_all(k=args.multiplier) diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/mssql_setup.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/mssql_setup.py new file mode 100644 index 0000000000..888b1f1ff5 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/mssql_setup.py @@ -0,0 +1,62 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""Microsoft SQL Server test table setup.""" + +import pyodbc + +# Support both direct execution and module import +try: + from .base_setup import DatabaseSetup + from .common_schema import DEFAULT_ROWS +except ImportError: + from base_setup import DatabaseSetup + from common_schema import DEFAULT_ROWS + + +def get_connection(params): + """Create SQL Server connection.""" + driver = params.get("driver", "{ODBC Driver 18 for SQL Server}") + connection_string = ( + f"DRIVER={driver};" + f"SERVER={params['host']},{params.get('port', 1433)};" + f"DATABASE={params.get('database', 'test_db')};" + f"UID={params['user']};" + f"PWD={params['password']};" + f"TrustServerCertificate=yes;" + f"Encrypt=yes;" + ) + return pyodbc.connect(connection_string) + + +def main(params=None): + """Main setup function.""" + if params is None: + try: + from ..config import MSSQL_PARAMS as params + except ImportError: + import sys + from pathlib import Path + + sys.path.insert(0, str(Path(__file__).parent.parent)) + from config import MSSQL_PARAMS as params + + print("=" * 60) + print("SQL Server Database Setup") + print("=" * 60) + + print(f"Connecting to SQL Server at {params['host']}...") + conn = get_connection(params) + print("Connected!") + + # Use base setup with SQL Server-specific settings + setup = DatabaseSetup(conn, dbms_type="mssql", placeholder_style="?") + setup.run_setup(num_rows=DEFAULT_ROWS) + + conn.close() + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/mysql_setup.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/mysql_setup.py new file mode 100644 index 0000000000..24afff6f72 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/mysql_setup.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""MySQL test table setup.""" + +import pymysql + +# Support both direct execution and module import +try: + from .base_setup import DatabaseSetup + from .common_schema import DEFAULT_ROWS +except ImportError: + from base_setup import DatabaseSetup + from common_schema import DEFAULT_ROWS + + +def get_connection(params): + """Create MySQL connection.""" + return pymysql.connect( + host=params["host"], + port=params.get("port", 3306), + user=params["user"], + password=params["password"], + database=params["database"], + ) + + +def main(params=None): + """Main setup function.""" + if params is None: + try: + from ..config import MYSQL_PARAMS as params + except ImportError: + import sys + from pathlib import Path + + sys.path.insert(0, str(Path(__file__).parent.parent)) + from config import MYSQL_PARAMS as params + + print("=" * 60) + print("MySQL Database Setup") + print("=" * 60) + + print(f"Connecting to MySQL at {params['host']}...") + conn = get_connection(params) + print("Connected!") + + # Use base setup with MySQL-specific settings + setup = DatabaseSetup(conn, dbms_type="mysql", placeholder_style="%s") + setup.run_setup(num_rows=DEFAULT_ROWS) + + conn.close() + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/oracle_setup.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/oracle_setup.py new file mode 100644 index 0000000000..b5e3bf2229 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/oracle_setup.py @@ -0,0 +1,57 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""Oracle test table setup.""" + +import oracledb + +# Support both direct execution and module import +try: + from .base_setup import DatabaseSetup + from .common_schema import DEFAULT_ROWS +except ImportError: + from base_setup import DatabaseSetup + from common_schema import DEFAULT_ROWS + + +def get_connection(params): + """Create Oracle connection.""" + dsn = f"{params['host']}:{params['port']}/{params['service_name']}" + return oracledb.connect( + user=params["user"], + password=params["password"], + dsn=dsn, + ) + + +def main(params=None): + """Main setup function.""" + if params is None: + try: + from ..config import ORACLE_PARAMS as params + except ImportError: + import sys + from pathlib import Path + + sys.path.insert(0, str(Path(__file__).parent.parent)) + from config import ORACLE_PARAMS as params + + print("=" * 60) + print("Oracle Database Setup") + print("=" * 60) + + print(f"Connecting to Oracle at {params['host']}...") + conn = get_connection(params) + print("Connected!") + + # Use base setup with Oracle-specific settings + setup = DatabaseSetup(conn, dbms_type="oracle", placeholder_style=":") + setup.run_setup(num_rows=DEFAULT_ROWS) + + conn.close() + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/postgres_setup.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/postgres_setup.py new file mode 100644 index 0000000000..b2fae6c6d8 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/postgres_setup.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""PostgreSQL test table setup.""" + +import psycopg2 + +# Support both direct execution and module import +try: + from .base_setup import DatabaseSetup + from .common_schema import DEFAULT_ROWS +except ImportError: + from base_setup import DatabaseSetup + from common_schema import DEFAULT_ROWS + + +def get_connection(params): + """Create PostgreSQL connection.""" + return psycopg2.connect( + host=params["host"], + port=params["port"], + user=params["user"], + password=params["password"], + dbname=params.get("database", params.get("dbname")), # Accept both + ) + + +def main(params=None): + """Main setup function.""" + if params is None: + try: + from ..config import POSTGRES_PARAMS as params + except ImportError: + import sys + from pathlib import Path + + sys.path.insert(0, str(Path(__file__).parent.parent)) + from config import POSTGRES_PARAMS as params + + print("=" * 60) + print("PostgreSQL Database Setup") + print("=" * 60) + + print(f"Connecting to PostgreSQL at {params['host']}...") + conn = get_connection(params) + print("Connected!") + + # Use base setup with PostgreSQL-specific settings + setup = DatabaseSetup(conn, dbms_type="postgres", placeholder_style="%s") + setup.run_setup(num_rows=DEFAULT_ROWS) + + conn.close() + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/tests/perf/data_source/dbapi_test_framework/db_setup_util/setup_all.py b/tests/perf/data_source/dbapi_test_framework/db_setup_util/setup_all.py new file mode 100644 index 0000000000..4aa6056dbe --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/db_setup_util/setup_all.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Run all database setups at once. +Creates DBAPI_TEST_TABLE with identical data in all configured databases. +""" + +# Support both direct execution and module import +try: + from . import ( + mysql_setup, + postgres_setup, + oracle_setup, + mssql_setup, + databricks_setup, + ) +except ImportError: + import mysql_setup + import postgres_setup + import oracle_setup + import mssql_setup + import databricks_setup + + +def run_all_setups(): + """Run setup for all databases.""" + setups = [ + ("MySQL", mysql_setup), + ("PostgreSQL", postgres_setup), + ("Oracle", oracle_setup), + ("SQL Server", mssql_setup), + ("Databricks", databricks_setup), + ] + + results = {} + + print("\n" + "=" * 70) + print("RUNNING ALL DATABASE SETUPS") + print("=" * 70 + "\n") + + for name, setup_module in setups: + try: + print(f"\n{'='*70}") + print(f"Setting up {name}...") + print("=" * 70) + setup_module.main() + results[name] = "✓ SUCCESS" + except Exception as e: + results[name] = f"✗ FAILED: {str(e)}" + print(f"Error setting up {name}: {e}") + + # Print summary + print("\n" + "=" * 70) + print("SETUP SUMMARY") + print("=" * 70) + for name, result in results.items(): + print(f"{name:20s} {result}") + print("=" * 70 + "\n") + + +if __name__ == "__main__": + run_all_setups() diff --git a/tests/perf/data_source/dbapi_test_framework/drivers/jdbc_drivers_readme.md b/tests/perf/data_source/dbapi_test_framework/drivers/jdbc_drivers_readme.md new file mode 100644 index 0000000000..95b440cc07 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/drivers/jdbc_drivers_readme.md @@ -0,0 +1,96 @@ +# JDBC Drivers Directory + +Place JDBC driver JAR files here for testing. + +## Required Drivers + +Download from official sources: + +### MySQL +- **Download**: https://dev.mysql.com/downloads/connector/j/ +- Select "Platform Independent" and download the ZIP/TAR archive + +### PostgreSQL +- **Download**: https://jdbc.postgresql.org/download/ + +### SQL Server +- **Download**: https://learn.microsoft.com/en-us/sql/connect/jdbc/download-microsoft-jdbc-driver-for-sql-server + +### Oracle +- **Download**: https://www.oracle.com/database/technologies/appdev/jdbc-downloads.html + +### Databricks +- **Download**: https://www.databricks.com/spark/jdbc-drivers-download + +## Snowflake Components (Required for PySpark Method) + +For the `pyspark` ingestion method, you need **both** of these JARs: + +### 1. Snowflake JDBC Driver +- **snowflake-jdbc-3.19.0.jar** or later + - Download: https://mvnrepository.com/artifact/net.snowflake/snowflake-jdbc/3.19.0 + - Required for Snowflake connectivity + +### 2. Snowflake Spark Connector +- **spark-snowflake_2.13-3.1.0.jar** (Scala 2.13, recommended) + - Download: https://mvnrepository.com/artifact/net.snowflake/spark-snowflake_2.13/3.1.0 + +- **OR spark-snowflake_2.12-3.1.0.jar** (Scala 2.12, if using older Spark) + - Download: https://mvnrepository.com/artifact/net.snowflake/spark-snowflake_2.12/3.1.0 + +**Important Notes**: +- Spark connector version 3.1.0 is recommended - version 3.1.1 has known issues with Oracle BLOB types +- Both JARs must be in the `drivers/` directory for PySpark to work + +## Automatic Upload + +These drivers will be **automatically uploaded** to your Snowflake stage when you run JDBC tests. + +- Upload happens only once per session +- Subsequent runs skip the upload if the driver already exists on the stage +- Each test only uploads the specific driver it needs + +## Configuration + +If you use different driver versions, update the filenames in: +``` +tests/perf/data_source/dbapi_test_framework/config.py +``` + +Look for the `JDBC_DRIVER_JARS` dictionary. + +## Known Issues + +### Oracle + +1. **BLOB Type with Spark-Snowflake 3.1.1+** + - `spark-snowflake_2.12-3.1.1` and later versions don't work correctly with Oracle BLOB type data + - The connector expects binary data but receives different format + - **Solution**: Use `spark-snowflake_2.12-3.1.0` or `spark-snowflake_2.13-3.1.0` + +2. **TIMESTAMP WITH TIME ZONE** + - PySpark cannot handle Oracle's `TIMESTAMP WITH TIME ZONE` data type + - **Workaround**: Convert to string in query: `TO_CHAR(TIMESTAMP_TZ_COL)` or use `TIMESTAMP WITH LOCAL TIME ZONE` + +### SQL Server + +1. **SQL_VARIANT Type** + - `SQL_VARIANT` type is not supported by PySpark when loading via JDBC + - **Workaround**: Cast to specific type in query or exclude from selection + +2. **DATE Type Warning** + - DATE type generates warning: `LogicalTypes: Ignoring invalid logical type for name: date` + - Data loads correctly despite the warning + - Can be safely ignored or suppress with log level configuration + +### General PySpark/JDBC + +1. **Java Version Compatibility** + - PySpark requires Java 8, 11, or 17 + - Java 25+ has breaking changes that cause `getSubject is not supported` errors + - **Solution**: Install Java 17 (LTS): `brew install openjdk@17` + +2. **Query Subquery Syntax** + - Oracle JDBC doesn't support alias syntax in `dbtable` option + - Use `(SELECT ...) ` without `as tmp` for Oracle + - Other databases work with `(SELECT ...) as tmp` diff --git a/tests/perf/data_source/dbapi_test_framework/main.py b/tests/perf/data_source/dbapi_test_framework/main.py new file mode 100644 index 0000000000..7acb983e72 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/main.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Main entry point for DBAPI ingestion tests. + +Usage: + # Run a single test + python main.py + + # Run the full test matrix + python main.py --matrix +""" + +import sys + +# Support both direct execution and module import +try: + from . import config + from .runner import run_test, run_test_matrix +except ImportError: + import config + from runner import run_test, run_test_matrix + + +def main(): + """Main entry point.""" + # Check if running test matrix + if "--matrix" in sys.argv: + # Run all tests in TEST_MATRIX/TEST_MATRIX_LARGE_QUERY + run_test_matrix( + config.TEST_MATRIX_LARGE_QUERY + ) # CHANGE ME TO RUN THE TEST MATRIX YOU WANT, full list: TEST_MATRIX, TEST_MATRIX_LARGE_QUERY + else: + # Run single test + result = run_test(config.SINGLE_TEST_CONFIG) + print(f"\nTest completed: {result['status']}") + + # Export single result to CSV if configured + if config.EXPORT_RESULTS_TO_CSV: + from runner import export_results_to_csv + + csv_path = export_results_to_csv([result]) + print(f"✓ Results exported to: {csv_path}") + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/tests/perf/data_source/dbapi_test_framework/requirement.txt b/tests/perf/data_source/dbapi_test_framework/requirement.txt new file mode 100644 index 0000000000..40aad3b9ec --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/requirement.txt @@ -0,0 +1,23 @@ +# DBAPI Test Framework Requirements + +# Environment variable management +python-dotenv + +# Database drivers +pymysql # MySQL +psycopg2-binary # PostgreSQL +pyodbc # SQL Server (requires system ODBC driver) +oracledb # Oracle +databricks-sql-connector # Databricks + +# PySpark (for pyspark ingestion method) +pyspark>=3.0.0 + +# Note: For SQL Server, you also need the system-level ODBC driver: +# - Windows/macOS/Linux: Install "ODBC Driver 18 for SQL Server" or "msodbcsql" +# - See: https://docs.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server + +# Note: For PySpark method, you also need: +# - JDBC driver JARs in drivers/ directory (same as JDBC methods) +# - Snowflake-Spark connector JAR (spark-snowflake_2.12 or _2.13) in drivers/ +# - Download from Maven Central: net.snowflake:spark-snowflake diff --git a/tests/perf/data_source/dbapi_test_framework/results/.gitkeep b/tests/perf/data_source/dbapi_test_framework/results/.gitkeep new file mode 100644 index 0000000000..2b8e67dbde --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/results/.gitkeep @@ -0,0 +1,2 @@ +# This directory stores CSV export files from test runs +# Files are named: dbapi_test_results_.csv diff --git a/tests/perf/data_source/dbapi_test_framework/runner.py b/tests/perf/data_source/dbapi_test_framework/runner.py new file mode 100644 index 0000000000..7ce06c5fa1 --- /dev/null +++ b/tests/perf/data_source/dbapi_test_framework/runner.py @@ -0,0 +1,990 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +""" +Test runner with 4 ingestion methods. + +Each method measures timing and prints results. +""" + +import time +import csv +import json +from datetime import datetime +from pathlib import Path +from snowflake.snowpark import Session + +# Support both direct execution and module import +try: + from .connections import get_connection_factory + from . import config +except ImportError: + from connections import get_connection_factory + import config + + +def run_local_ingestion( + session, create_connection, source_type, source_value, target_table, dbapi_params +): + """ + Method 1: Local ingestion using session.read.dbapi() + + Data is fetched locally and uploaded to Snowflake. + Args: + source_type: "table" or "query" + source_value: Table name or SQL query string + """ + print(f"\n{'='*60}") + print("Running: LOCAL INGESTION") + print(f"{'='*60}") + + start_time = time.time() + + # Build source kwargs: {source_type: source_value} + # e.g., {"table": "DBAPI_TEST_TABLE"} or {"query": "SELECT * FROM ..."} + source_kwargs = {source_type: source_value} + + # Read from source database + df = session.read.dbapi( + create_connection=create_connection, **source_kwargs, **dbapi_params + ) + + # Write to Snowflake + df.write.save_as_table(target_table, mode="overwrite") + + end_time = time.time() + elapsed = end_time - start_time + + print(f"✓ Completed in {elapsed:.2f} seconds") + return elapsed + + +def run_udtf_ingestion( + session, + create_connection, + source_type, + source_value, + target_table, + dbapi_params, + udtf_configs, +): + """ + Method 2: UDTF ingestion using session.read.dbapi() with udtf_configs + + Data is fetched via UDTF running on Snowflake. + Args: + source_type: "table" or "query" + source_value: Table name or SQL query string + """ + print(f"\n{'='*60}") + print("Running: UDTF INGESTION") + print(f"{'='*60}") + + start_time = time.time() + + # Build source kwargs + source_kwargs = {source_type: source_value} + + # Read using UDTF + df = session.read.dbapi( + create_connection=create_connection, + udtf_configs=udtf_configs, + **source_kwargs, + **dbapi_params, + ) + + # Write to Snowflake + df.write.save_as_table(target_table, mode="overwrite") + + end_time = time.time() + elapsed = end_time - start_time + + print(f"✓ Completed in {elapsed:.2f} seconds") + return elapsed + + +def run_local_ingestion_in_sproc( + session, + create_connection, + source_type, + source_value, + target_table, + dbapi_params, + udtf_configs, + dbms, +): + """ + Method 3: Local ingestion inside a stored procedure + + The local ingestion logic runs inside a Snowflake stored procedure. + Args: + source_type: "table" or "query" + source_value: Table name or SQL query string + udtf_configs: External access integration configs (needed for sproc to access external DB) + dbms: DBMS type (for getting correct packages) + """ + print(f"\n{'='*60}") + print("Running: LOCAL INGESTION IN STORED PROCEDURE") + print(f"{'='*60}") + + source_dict = {source_type: source_value} + params = dbapi_params + target = target_table + external_access_integrations = [udtf_configs.get("external_access_integration")] + packages = config.SPROC_PACKAGES.get(dbms.lower(), []) + + # Define the ingestion function + def ingestion_sproc( + _session: Session, + ): + df = _session.read.dbapi( + create_connection=create_connection, **source_dict, **params + ) + df.write.save_as_table(target, mode="overwrite") + return "Success" + + # Register as stored procedure with external access integration + from snowflake.snowpark.types import StringType + + sproc = session.sproc.register( + func=ingestion_sproc, + name="temp_local_ingestion_sproc", + return_type=StringType(), + input_types=None, + replace=True, + is_permanent=False, + packages=packages, + external_access_integrations=external_access_integrations, + ) + + start_time = time.time() + + # Call the stored procedure + result = sproc() + + end_time = time.time() + elapsed = end_time - start_time + + print(f"✓ Completed in {elapsed:.2f} seconds") + print(f" Result: {result}") + return elapsed + + +def run_udtf_ingestion_in_sproc( + session, + create_connection, + source_type, + source_value, + target_table, + dbapi_params, + udtf_configs, + dbms, +): + """ + Method 4: UDTF ingestion inside a stored procedure + + The UDTF ingestion logic runs inside a Snowflake stored procedure. + Args: + source_type: "table" or "query" + source_value: Table name or SQL query string + udtf_configs: External access integration configs + dbms: DBMS type (for getting correct packages) + """ + print(f"\n{'='*60}") + print("Running: UDTF INGESTION IN STORED PROCEDURE") + print(f"{'='*60}") + + source_dict = {source_type: source_value} + external_access_integrations = [udtf_configs.get("external_access_integration")] + packages = config.SPROC_PACKAGES.get(dbms.lower(), []) + + # Define the ingestion function + def ingestion_sproc( + _session: Session, + ): + df = _session.read.dbapi( + create_connection=create_connection, + udtf_configs=udtf_configs, + **source_dict, + **dbapi_params, + ) + df.write.save_as_table(target_table, mode="overwrite") + return "Success" + + # Register as stored procedure with external access integration + from snowflake.snowpark.types import StringType + + sproc = session.sproc.register( + func=ingestion_sproc, + name="temp_udtf_ingestion_sproc", + return_type=StringType(), + input_types=None, + replace=True, + is_permanent=False, + packages=packages, + external_access_integrations=external_access_integrations, + ) + + start_time = time.time() + + # Call the stored procedure + result = sproc() + + end_time = time.time() + elapsed = end_time - start_time + + print(f"✓ Completed in {elapsed:.2f} seconds") + print(f" Result: {result}") + return elapsed + + +def run_jdbc_ingestion( + session, + jdbc_url, + source_type, + source_value, + target_table, + dbapi_params, + udtf_configs, +): + """ + Method 5: JDBC ingestion using session.read.jdbc() + + Data is fetched via JDBC UDTF running on Snowflake. + Requires JDBC driver JAR, secret, and external access integration. + Args: + jdbc_url: JDBC connection URL + source_type: "table" or "query" + source_value: Table name or SQL query string + target_table: Target Snowflake table name + dbapi_params: DBAPI parameters (fetch_size, etc.) + udtf_configs: UDTF configuration (EAI, secret, imports) + """ + print(f"\n{'='*60}") + print("Running: JDBC INGESTION") + print(f"{'='*60}") + + start_time = time.time() + + # Build source kwargs + source_kwargs = {source_type: source_value} + + # Read using JDBC + df = session.read.jdbc( + url=jdbc_url, + udtf_configs=udtf_configs, + **source_kwargs, + **dbapi_params, + ) + + # Write to Snowflake + df.write.save_as_table(target_table, mode="overwrite") + + end_time = time.time() + elapsed = end_time - start_time + + print(f"✓ Completed in {elapsed:.2f} seconds") + return elapsed + + +def run_jdbc_ingestion_in_sproc( + session, + jdbc_url, + source_type, + source_value, + target_table, + dbapi_params, + udtf_configs, + dbms, +): + """ + Method 6: JDBC ingestion inside a stored procedure + + The JDBC ingestion logic runs inside a Snowflake stored procedure. + Args: + jdbc_url: JDBC connection URL + source_type: "table" or "query" + source_value: Table name or SQL query string + target_table: Target Snowflake table name + dbapi_params: DBAPI parameters (fetch_size, etc.) + udtf_configs: UDTF configuration (EAI, secret, imports) + dbms: DBMS type (for external access integration) + """ + print(f"\n{'='*60}") + print("Running: JDBC INGESTION IN STORED PROCEDURE") + print(f"{'='*60}") + + source_dict = {source_type: source_value} + external_access_integrations = [udtf_configs.get("external_access_integration")] + # JDBC uses Java, so no Python packages needed + packages = [] + + # Define the ingestion function + def ingestion_sproc(_session: Session): + df = _session.read.jdbc( + url=jdbc_url, + udtf_configs=udtf_configs, + **source_dict, + **dbapi_params, + ) + df.write.save_as_table(target_table, mode="overwrite") + return "Success" + + # Register as stored procedure with external access integration + from snowflake.snowpark.types import StringType + + sproc = session.sproc.register( + func=ingestion_sproc, + name="temp_jdbc_ingestion_sproc", + return_type=StringType(), + input_types=None, + replace=True, + is_permanent=False, + packages=packages, + external_access_integrations=external_access_integrations, + ) + + start_time = time.time() + + # Call the stored procedure + result = sproc() + + end_time = time.time() + elapsed = end_time - start_time + + print(f"✓ Completed in {elapsed:.2f} seconds") + print(f" Result: {result}") + return elapsed + + +def run_pyspark_ingestion( + session, + jdbc_url, + dbms, + dbms_params, + source_type, + source_value, + target_table, + dbapi_params, +): + """ + Method 7: PySpark JDBC ingestion + + Data is fetched via PySpark JDBC and written to Snowflake using Snowflake-Spark connector. + Runs on local Spark session with plain credentials (not Snowflake secrets). + + Args: + session: Snowflake session (for connection parameters) + jdbc_url: JDBC connection URL + dbms: Database type (for getting driver class) + dbms_params: Database connection parameters (for user/password) + source_type: "table" or "query" + source_value: Table name or SQL query string + target_table: Target Snowflake table name + dbapi_params: DBAPI parameters (will be translated to PySpark JDBC options) + """ + print(f"\n{'='*60}") + print("Running: PYSPARK JDBC INGESTION") + print(f"{'='*60}") + + from pyspark.sql import SparkSession + from connections import get_jdbc_driver_class + + # Build Spark session + spark_builder = SparkSession.builder.appName("PySparkDBAPITest") + + # Apply Spark configuration + for key, value in config.PYSPARK_SESSION_CONFIG.items(): + spark_builder = spark_builder.config(key, value) + + spark = spark_builder.getOrCreate() + + try: + start_time = time.time() + + # Build JDBC options + jdbc_options = { + "url": jdbc_url, + "driver": get_jdbc_driver_class(dbms), + } + + # Add credentials from dbms_params + # Different DBMS types use different parameter names + if dbms.lower() in ("mysql", "postgres", "postgresql", "oracle"): + jdbc_options["user"] = dbms_params.get("user") + jdbc_options["password"] = dbms_params.get("password") + elif dbms.lower() in ("mssql", "sqlserver"): + jdbc_options["user"] = dbms_params.get("user") + jdbc_options["password"] = dbms_params.get("password") + elif dbms.lower() in ("databricks", "dbx"): + # Databricks uses token-based auth in JDBC + jdbc_options["PWD"] = dbms_params.get("access_token") + jdbc_options["UID"] = "token" + + # Add source (table or query) + if source_type == "table": + jdbc_options["dbtable"] = source_value + elif source_type == "query": + # Wrap query in subquery + # Oracle JDBC doesn't support alias syntax, just use parentheses + if dbms.lower() == "oracle": + jdbc_options["dbtable"] = f"({source_value})" + else: + # Standard JDBC subquery with alias + jdbc_options["dbtable"] = f"({source_value}) as tmp" + else: + raise ValueError(f"Invalid source_type: {source_type}") + + # Translate dbapi_params to PySpark JDBC options + # Support multiple naming conventions: camelCase, snake_case, and Snowpark style + if "fetchsize" in dbapi_params or "fetch_size" in dbapi_params: + jdbc_options["fetchsize"] = dbapi_params.get( + "fetchsize", dbapi_params.get("fetch_size") + ) + if "numPartitions" in dbapi_params or "num_partitions" in dbapi_params: + jdbc_options["numPartitions"] = dbapi_params.get( + "numPartitions", dbapi_params.get("num_partitions") + ) + if ( + "partitionColumn" in dbapi_params + or "partition_column" in dbapi_params + or "column" in dbapi_params + ): + jdbc_options["partitionColumn"] = dbapi_params.get( + "partitionColumn", + dbapi_params.get("partition_column", dbapi_params.get("column")), + ) + if "lowerBound" in dbapi_params or "lower_bound" in dbapi_params: + jdbc_options["lowerBound"] = dbapi_params.get( + "lowerBound", dbapi_params.get("lower_bound") + ) + if "upperBound" in dbapi_params or "upper_bound" in dbapi_params: + jdbc_options["upperBound"] = dbapi_params.get( + "upperBound", dbapi_params.get("upper_bound") + ) + + # Read from source database + print(f"Reading from {dbms.upper()} via PySpark JDBC...") + df = spark.read.format("jdbc").options(**jdbc_options).load() + + # Optional: Apply repartitioning if configured + repartition_num = config.PYSPARK_SESSION_CONFIG.get("repartition_num") + if repartition_num: + print(f"Repartitioning to {repartition_num} partitions...") + df = df.repartition(repartition_num) + + # Get Snowflake connection parameters + sf_params = config.SNOWFLAKE_PARAMS + + # Write to Snowflake + print(f"Writing to Snowflake table: {target_table}...") + ( + df.write.format("net.snowflake.spark.snowflake") + .option("sfUrl", sf_params["host"]) + .option("sfUser", sf_params["user"]) + .option("sfPassword", sf_params["password"]) + .option("sfDatabase", sf_params["database"]) + .option("sfSchema", sf_params["schema"]) + .option("sfWarehouse", sf_params["warehouse"]) + .option("dbtable", target_table) + .mode("overwrite") + .save() + ) + + end_time = time.time() + elapsed = end_time - start_time + + print(f"✓ Completed in {elapsed:.2f} seconds") + return elapsed + + finally: + # Clean up Spark session + spark.stop() + + +def ensure_jdbc_driver_uploaded(session, dbms, jar_filename): + """ + Ensure JDBC driver JAR is uploaded to stage before test execution. + + Uploads the driver only if it doesn't already exist on the stage. + Looks for JAR in local drivers/ directory. + + Args: + session: Snowflake session + dbms: Database type (mysql, postgres, etc.) + jar_filename: Name of JAR file + + Returns: + Stage path to the JAR file (e.g., "@session_stage/mysql-connector.jar") + """ + from pathlib import Path + + # Get stage name + stage_name = session.get_session_stage() + + # Check if JAR already exists on stage + stage_jar_path = f"{stage_name}/{jar_filename}" + + try: + # List files on stage to check if JAR exists + result = session.sql(f"LIST {stage_name}").collect() + existing_files = [row["name"] for row in result] + + if any(jar_filename in f for f in existing_files): + print(f"✓ JDBC driver already on stage: {stage_jar_path}") + return stage_jar_path + except Exception: + pass # Stage might not exist yet or no permission to list + + # Look for JAR in local drivers/ directory + local_jar_path = Path(__file__).parent / "drivers" / jar_filename + + if not local_jar_path.exists(): + raise FileNotFoundError( + f"JDBC driver not found: {local_jar_path}\n" + f"Please download the {dbms.upper()} JDBC driver and place it in:\n" + f" {local_jar_path.parent}/\n" + f"Download from official sources." + ) + + # Upload JAR to stage + print(f"Uploading JDBC driver to stage: {jar_filename}") + session.file.put( + str(local_jar_path), + stage_name, + auto_compress=False, + overwrite=False, + ) + print(f"✓ Uploaded: {stage_jar_path}") + + return stage_jar_path + + +def run_test(test_config): + """ + Run a single test based on configuration. + + Args: + test_config: Dict with keys: + - dbms: database type + - table: source table name + - ingestion_method: 'local', 'udtf', 'local_sproc', 'udtf_sproc' + - dbapi_params: optional DBAPI parameters override + - udtf_configs: optional UDTF configs override + + Returns: + Elapsed time in seconds + """ + dbms = test_config["dbms"] + method = test_config["ingestion_method"] + + # Get source configuration - supports both old and new format + if "source" in test_config: + # New format: {"source": {"type": "table|query", "value": "..."}} + source_config = test_config["source"] + source_type = source_config["type"] + source_value = source_config["value"] + else: + # Legacy format: {"table": "..."} or {"query": "..."} + if "table" in test_config: + source_type = "table" + source_value = test_config["table"] + elif "query" in test_config: + source_type = "query" + source_value = test_config["query"] + else: + raise ValueError( + "Test config must specify 'source' or legacy 'table'/'query'" + ) + + if source_type not in ("table", "query"): + raise ValueError(f"source.type must be 'table' or 'query', got: {source_type}") + + print(f"\n{'#'*60}") + print(f"TEST: {dbms.upper()} - {method.upper()}") + print(f"Source Type: {source_type.upper()}") + print(f"Source Value: {source_value}") + print(f"{'#'*60}") + + # Get connection parameters based on DBMS type + dbms_params_map = { + "mysql": config.MYSQL_PARAMS, + "postgres": config.POSTGRES_PARAMS, + "postgresql": config.POSTGRES_PARAMS, + "mssql": config.MSSQL_PARAMS, + "sqlserver": config.MSSQL_PARAMS, + "oracle": config.ORACLE_PARAMS, + "databricks": config.DATABRICKS_PARAMS, + "dbx": config.DATABRICKS_PARAMS, + } + + dbms_params = dbms_params_map.get(dbms.lower()) + if not dbms_params: + raise ValueError(f"Unknown DBMS: {dbms}") + + # Create connection factory + create_connection = get_connection_factory(dbms, dbms_params) + + # Get DBAPI and UDTF parameters + dbapi_params = test_config.get("dbapi_params", config.DBAPI_PARAMS.copy()) + + # Get UDTF configs for the specific DBMS + dbms_key = dbms.lower() + if "udtf_configs" in test_config: + udtf_configs = test_config["udtf_configs"] + elif dbms_key in config.UDTF_CONFIGS: + udtf_configs = config.UDTF_CONFIGS[dbms_key].copy() + + if not udtf_configs: + raise ValueError(f"UDTF configs not found for {dbms}") + + # Create Snowflake session + session = Session.builder.configs(config.SNOWFLAKE_PARAMS).create() + + try: + # Generate target table name + target_table = f"TEST_{dbms.upper()}_{method.upper()}_{int(time.time())}" + + # Run appropriate ingestion method + # Pass source_type and source_value separately + if method == "local": + elapsed = run_local_ingestion( + session, + create_connection, + source_type, + source_value, + target_table, + dbapi_params, + ) + + elif method == "udtf": + elapsed = run_udtf_ingestion( + session, + create_connection, + source_type, + source_value, + target_table, + dbapi_params, + udtf_configs, + ) + + elif method == "local_sproc": + elapsed = run_local_ingestion_in_sproc( + session, + create_connection, + source_type, + source_value, + target_table, + dbapi_params, + udtf_configs, + dbms, + ) + + elif method == "udtf_sproc": + elapsed = run_udtf_ingestion_in_sproc( + session, + create_connection, + source_type, + source_value, + target_table, + dbapi_params, + udtf_configs, + dbms, + ) + + elif method in ("jdbc", "jdbc_sproc"): + # Generate JDBC URL + from connections import get_jdbc_url + + jdbc_url = get_jdbc_url(dbms, dbms_params) + + # Get JAR filename from config + jar_filename = config.JDBC_DRIVER_JARS.get(dbms_key) + if not jar_filename: + raise ValueError( + f"JDBC driver JAR not configured for {dbms}. " + f"Please add to config.JDBC_DRIVER_JARS." + ) + + # Ensure driver is uploaded (uploads only if not already on stage) + jar_path = ensure_jdbc_driver_uploaded(session, dbms, jar_filename) + + # Get secret + secret = config.JDBC_SECRETS.get(dbms_key) + if not secret: + raise ValueError( + f"JDBC secret not configured for {dbms}. " + f"Please add to config.JDBC_SECRETS or set environment variable." + ) + + # Build JDBC udtf_configs + jdbc_udtf_configs = { + "external_access_integration": udtf_configs.get( + "external_access_integration" + ), + "secret": secret, + "imports": [jar_path], + } + + if method == "jdbc": + elapsed = run_jdbc_ingestion( + session, + jdbc_url, + source_type, + source_value, + target_table, + dbapi_params, + jdbc_udtf_configs, + ) + else: # jdbc_sproc + elapsed = run_jdbc_ingestion_in_sproc( + session, + jdbc_url, + source_type, + source_value, + target_table, + dbapi_params, + jdbc_udtf_configs, + dbms, + ) + + elif method == "pyspark": + # Generate JDBC URL + from connections import get_jdbc_url + + jdbc_url = get_jdbc_url(dbms, dbms_params) + + # Verify JDBC drivers are available locally + jar_filename = config.JDBC_DRIVER_JARS.get(dbms_key) + if not jar_filename: + raise ValueError( + f"JDBC driver JAR not configured for {dbms}. " + f"Please add to config.JDBC_DRIVER_JARS." + ) + + # Check if JAR exists in local drivers/ directory + from pathlib import Path + + local_jar_path = Path(__file__).parent / "drivers" / jar_filename + if not local_jar_path.exists(): + raise FileNotFoundError( + f"JDBC driver not found: {local_jar_path}\n" + f"PySpark requires JDBC driver JARs to be in the local drivers/ directory.\n" + f"Please download the {dbms.upper()} JDBC driver and place it in:\n" + f" {local_jar_path.parent}/\n" + f"Download from official sources." + ) + + # Run PySpark ingestion + elapsed = run_pyspark_ingestion( + session, + jdbc_url, + dbms, + dbms_params, + source_type, + source_value, + target_table, + dbapi_params, + ) + + else: + raise ValueError(f"Unknown ingestion method: {method}") + + # Show target table info if configured + if config.SHOW_TARGET_TABLE_INFO: + try: + print(f"\n{'='*60}") + print("TARGET TABLE INFO") + print(f"{'='*60}") + + # Get row count + row_count = session.table(target_table).count() + print(f"Row count: {row_count}") + + # Show first row + print("\nFirst row:") + session.table(target_table).show(n=1) + + except Exception as info_error: + print(f"\n⚠ Warning: Could not retrieve table info: {info_error}") + + # Cleanup target table if configured + if config.CLEANUP_TARGET_TABLES: + try: + session.sql(f"DROP TABLE IF EXISTS {target_table}").collect() + print(f"\n✓ Cleaned up target table: {target_table}") + except Exception as cleanup_error: + print( + f"\n⚠ Warning: Could not clean up table {target_table}: {cleanup_error}" + ) + + return { + "dbms": dbms, + "method": method, + "source_type": source_type, + "source_value": source_value, + "target_table": target_table, + "elapsed_time": elapsed, + "dbapi_params": dbapi_params, + "status": "success", + } + + except Exception as e: + print(f"\n✗ ERROR: {str(e)}") + # Try cleanup even on failure if configured + if config.CLEANUP_TARGET_TABLES and "target_table" in locals(): + try: + session.sql(f"DROP TABLE IF EXISTS {target_table}").collect() + print(f"\n✓ Cleaned up target table: {target_table}") + except Exception: + pass # Silently ignore cleanup errors on failure + + return { + "dbms": dbms, + "method": method, + "source_type": source_type if "source_type" in locals() else None, + "source_value": source_value if "source_value" in locals() else None, + "dbapi_params": dbapi_params if "dbapi_params" in locals() else {}, + "elapsed_time": None, + "status": "failed", + "error": str(e), + } + + finally: + session.close() + + +def run_test_matrix(test_matrix): + """ + Run multiple tests from a test matrix. + + Args: + test_matrix: List of test configurations + + Returns: + List of test results + """ + results = [] + + print(f"\n{'='*60}") + print(f"RUNNING TEST MATRIX: {len(test_matrix)} tests") + print(f"{'='*60}") + + for i, test_config in enumerate(test_matrix, 1): + print(f"\n\nTest {i}/{len(test_matrix)}") + result = run_test(test_config) + results.append(result) + + # Print summary + print(f"\n\n{'='*110}") + print("TEST SUMMARY") + print(f"{'='*110}") + print( + f"{'Status':^8} {'DBMS':^12} {'Method':^15} {'Source':^8} {'Value':^20} {'Params':^25} {'Time':^10}" + ) + print("-" * 110) + + for result in results: + status_symbol = "✓" if result["status"] == "success" else "✗" + time_str = f"{result['elapsed_time']:.2f}s" if result["elapsed_time"] else "N/A" + source_type = result.get("source_type", "N/A") + source_value = result.get("source_value", "N/A") + dbapi_params = result.get("dbapi_params", {}) + + # Clean up multi-line queries and truncate + source_value = " ".join( + source_value.split() + ) # Collapse whitespace to single spaces + if len(source_value) > 20: + source_value = source_value[:17] + "..." + + # Format dbapi_params compactly + if dbapi_params: + params_str = json.dumps(dbapi_params, separators=(",", ":")) + if len(params_str) > 25: + params_str = params_str[:22] + "..." + else: + params_str = "{}" + + print( + f"{status_symbol:^8} {result['dbms']:^12} {result['method']:^15} {source_type:^8} {source_value:^20} {params_str:^25} {time_str:^10}" + ) + + successful = sum(1 for r in results if r["status"] == "success") + print("-" * 110) + print( + f"Total: {len(results)} | Success: {successful} | Failed: {len(results) - successful}" + ) + + # Export results to CSV if configured + if config.EXPORT_RESULTS_TO_CSV: + csv_path = export_results_to_csv(results) + print(f"\n✓ Results exported to: {csv_path}") + + return results + + +def export_results_to_csv(results, output_dir=None): + """ + Export test results to CSV file. + + Args: + results: List of test result dicts + output_dir: Output directory (defaults to current directory) + + Returns: + Path to CSV file + """ + if output_dir is None: + output_dir = Path(__file__).parent / "results" + else: + output_dir = Path(output_dir) + + # Create results directory if it doesn't exist + output_dir.mkdir(exist_ok=True) + + # Generate filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + csv_filename = f"dbapi_test_results_{timestamp}.csv" + csv_path = output_dir / csv_filename + + # Define CSV columns + fieldnames = [ + "timestamp", + "dbms", + "ingestion_method", + "source_type", + "source_value", + "target_table", + "elapsed_time_seconds", + "dbapi_params", + "status", + "error", + ] + + # Write to CSV + with open(csv_path, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + + for result in results: + # Convert dbapi_params to JSON string + dbapi_params_json = json.dumps(result.get("dbapi_params", {})) + + writer.writerow( + { + "timestamp": timestamp, + "dbms": result.get("dbms", ""), + "ingestion_method": result.get("method", ""), + "source_type": result.get("source_type", ""), + "source_value": result.get("source_value", ""), + "target_table": result.get("target_table", ""), + "elapsed_time_seconds": result.get("elapsed_time", ""), + "dbapi_params": dbapi_params_json, + "status": result.get("status", ""), + "error": result.get("error", ""), + } + ) + + return str(csv_path) diff --git a/tests/perf/data_source/scripts/base_db_setup.py b/tests/perf/data_source/scripts/base_db_setup.py deleted file mode 100644 index 23b9fc6ed9..0000000000 --- a/tests/perf/data_source/scripts/base_db_setup.py +++ /dev/null @@ -1,98 +0,0 @@ -# -# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. -# - -import string -import random -import pytz -import datetime -from abc import ABC, abstractmethod - - -ONE_MILLION = 1_000_000 -TEN_MILLION = 10_000_000 -ONE_HUNDRED_MILLION = 100_000_000 - - -class TestDBABC(ABC): - @abstractmethod - def __init__(self) -> None: - self._connection = None - pass - - def __getstate__(self): - """Return state values to be pickled.""" - state = self.__dict__.copy() - # Don't pickle connection - state["_connection"] = None - return state - - @property - def connection(self): - """Lazy connection creation""" - if self._connection is None: - self._connection = self.create_connection() - return self._connection - - @abstractmethod - def create_table(self): - pass - - @staticmethod - def generate_random_string(length=10): - return "".join(random.choices(string.ascii_letters + string.digits, k=length)) - - @staticmethod - def random_datetime_with_timezone(): - # Generate a random datetime - naive_datetime = datetime.datetime( - 2024, - random.randint(1, 12), # Month - random.randint(1, 28), # Day - random.randint(0, 23), # Hour - random.randint(0, 59), # Minute - ) - # Assign a random timezone - random_timezone = pytz.timezone(random.choice(pytz.all_timezones)) - timezone_aware_datetime = random_timezone.localize(naive_datetime) - return timezone_aware_datetime - - @staticmethod - def generate_random_data(): - raise NotImplementedError - - @abstractmethod - def insert_data(self, num_rows=1_000_000, table_name=None): - pass - - def _insert_data_with_sql(self, insert_sql, num_rows=1_000_000): - full_batches = num_rows // self.insert_batch_size - remaining_rows = num_rows % self.insert_batch_size - - with self.connection.cursor() as cursor: - # Insert full batches - for i in range(full_batches): - batch_data = [ - self.generate_random_data() for _ in range(self.insert_batch_size) - ] - cursor.executemany(insert_sql, batch_data) - self.connection.commit() - print( - f"Inserted batch {i + 1} with {self.insert_batch_size} rows successfully." - ) - - # Insert any remaining rows - if remaining_rows > 0: - batch_data = [ - self.generate_random_data() for _ in range(remaining_rows) - ] - cursor.executemany(insert_sql, batch_data) - self.connection.commit() - print(f"Inserted final batch with {remaining_rows} rows successfully.") - - print(f"Inserted total of {num_rows} rows successfully.") - - def close_connection(self): - if self.connection: - self.connection.close() - print("Connection closed.") diff --git a/tests/perf/data_source/scripts/dbtest_config.py b/tests/perf/data_source/scripts/dbtest_config.py deleted file mode 100644 index 667d37cd5d..0000000000 --- a/tests/perf/data_source/scripts/dbtest_config.py +++ /dev/null @@ -1,255 +0,0 @@ -# -# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. -# - -from typing import Dict, Type, Union - -from oracle_resource_setup import TestOracleDB -from sql_server_resource_setup import TestSQLServerDB - - -class DatabaseTestConfig: - def __init__( - self, - db_class: Type, - connection_params: Dict = None, - insert_row_count: int = None, - existing_table: str = None, - dbapi_parameters: Dict = None, - ) -> None: - if insert_row_count and existing_table: - raise ValueError( - "insert_row_count and existing_table can not be used at the same time," - "when insert_row_count, a new table will be created" - ) - self.db_class = db_class - self.connection_params = connection_params or {} - self.insert_row_count = insert_row_count or 1_000_000 - self.existing_table = existing_table - self.dbapi_parameters = dbapi_parameters or {} - - -def create_oracle_config( - connection_params: Dict = None, - insert_row_count: int = None, - existing_table: str = None, - fetch_size: int = None, - column: str = None, - lower_bound: Union[str, int] = None, - upper_bound: Union[str, int] = None, - num_partitions: int = None, -) -> DatabaseTestConfig: - """ - Helper method to create Oracle test configuration with default values. - - Args: - connection_params: Optional connection parameters, will use defaults if not provided - insert_row_count: Number of rows to insert if creating new table - existing_table: Name of existing table to use - fetch_size: DBAPI fetch_size parameter - column: column to perform partition on - lower_bound: lower bound of partition - upper_bound: upper bound of partition - num_partitions: number of partitions - Note: - column, lower_bound, upper_bound and num_partitions must be set together - """ - default_connection = { - "username": "SYSTEM", - "password": "test", - "host": "localhost", - "port": 1521, - "service_name": "FREEPDB1", - } - dbapi_parameters = {} - if fetch_size is not None: - dbapi_parameters["fetch_size"] = fetch_size - if column is not None: - dbapi_parameters["column"] = column - if lower_bound is not None: - dbapi_parameters["lower_bound"] = lower_bound - if upper_bound is not None: - dbapi_parameters["upper_bound"] = upper_bound - if num_partitions is not None: - dbapi_parameters["num_partitions"] = num_partitions - - config = DatabaseTestConfig( - db_class=TestOracleDB, - connection_params=connection_params or default_connection, - insert_row_count=insert_row_count, - existing_table=existing_table, - dbapi_parameters=dbapi_parameters, - ) - return config - - -def create_sql_server_config( - connection_params: Dict = None, - insert_row_count: int = None, - existing_table: str = None, - fetch_size: int = None, - column: str = None, - lower_bound: Union[str, int] = None, - upper_bound: Union[str, int] = None, - num_partitions: int = None, -) -> DatabaseTestConfig: - """ - Helper method to create SQL Server test configuration with default values. - - Args: - connection_params: Optional connection parameters, will use defaults if not provided - insert_row_count: Number of rows to insert if creating new table - existing_table: Name of existing table to use - fetch_size: DBAPI fetch_size parameter - column: column to perform partition on - lower_bound: lower bound of partition - upper_bound: upper bound of partition - num_partitions: number of partitions - Note: - column, lower_bound, upper_bound and num_partitions must be set together - """ - if existing_table and insert_row_count: - raise ValueError( - "existing_table and insert_row_count can not be used at the same time," - "when insert_row_count, a new table will be created" - ) - default_connection = { - "host": "127.0.0.1", - "port": 1433, - "database": "msdb", - "username": "sa", - "password": "Test12345()", - } - dbapi_parameters = {} - if fetch_size: - dbapi_parameters["fetch_size"] = fetch_size - if column: - dbapi_parameters["column"] = column - if lower_bound: - dbapi_parameters["lower_bound"] = lower_bound - if upper_bound: - dbapi_parameters["upper_bound"] = upper_bound - if num_partitions: - dbapi_parameters["num_partitions"] = num_partitions - - config = DatabaseTestConfig( - db_class=TestSQLServerDB, - connection_params=connection_params or default_connection, - insert_row_count=insert_row_count, - existing_table=existing_table, - dbapi_parameters=dbapi_parameters, - ) - return config - - -def create_pyspark_session_config(driver_extra_class_path, master="local", **kwargs): - """ - class_path is dir where the java jar, snowflake jar, parquet avro target db driver jar is placed - """ - return { - "spark.driver.extraClassPath": str(driver_extra_class_path), - "spark.master": master, - **kwargs, - } - - -def create_jdbc_config( - jdbc_url, - user, - password, - driver, - fetch_size=None, - partition_column=None, - num_partitions=None, - lower_bound=None, - upper_bound=None, -): - config = { - "url": jdbc_url, - "user": user, - "password": password, - "driver": driver, - } - if fetch_size is not None: - config["fetchsize"] = fetch_size - if partition_column is not None: - config["partitionColumn"] = partition_column - if num_partitions is not None: - config["numPartitions"] = num_partitions - if lower_bound is not None: - config["lowerBound"] = lower_bound - if upper_bound is not None: - config["upperBound"] = upper_bound - return config - - -DEFAULT_ORACLE_CONFIGS = [ - create_oracle_config( - existing_table="ALL_TYPE_TABLE", - fetch_size=fetch_size, - column="id", - lower_bound=0, - upper_bound=1000000, - num_partitions=num_partitions, - ) - for fetch_size, num_partitions in [ - (0, 0), - (10000, 0), - (0, 10), - (10000, 10), - ] -] -DEFAULT_SQLSERVER_CONFIGS = [ - create_sql_server_config( - existing_table="ALL_TYPE_TABLE", - fetch_size=fetch_size, - column="id", - lower_bound=0, - upper_bound=1000000, - num_partitions=num_partitions, - ) - for fetch_size, num_partitions in [ - (0, 0), - (10000, 0), - (0, 10), - (10000, 10), - ] -] - - -DEFAULT_PYSPARK_CONFIG = create_pyspark_session_config( - driver_extra_class_path="./jdbc_drivers/*" -) - - -PARALLELISM_OPTIMIZED_PYSPARK_CONFIG = create_pyspark_session_config( - driver_extra_class_path="./jdbc_drivers/*", - num_partitions=10, - repartition_num=32, # this is not a spark config, but used in the test dataframe operations - **{ - "spark.sql.shuffle.partitions": 16, # CHANGE ME TO MATCH THE SPEC OF THE TEST MACHINE - "spark.default.parallelism": 16, # CHANGE ME TO MATCH THE SPEC OF THE TEST MACHINE - "spark.executor.cores": 8, # CHANGE ME TO MATCH THE SPEC OF THE TEST MACHINE - "spark.executor.memory": "16g", # CHANGE ME TO MATCH THE SPEC OF TESTTHE MACHINE - "spark.executor.instances": 1, - }, -) - -DEFAULT_ORACLE_JDBC_CONFIG = create_jdbc_config( - jdbc_url="jdbc:oracle:thin:@//localhost:1521/FREEPDB1", - user="SYSTEM", - password="test", - driver="oracle.jdbc.driver.OracleDriver", - fetch_size=1000, - partition_column="id", - lower_bound=0, - upper_bound=1000000, - num_partitions=0, -) - -DEFAULT_SQLSERVER_JDBC_CONFIG = create_jdbc_config( - jdbc_url="jdbc:sqlserver://127.0.0.1:1433;TrustServerCertificate=true;databaseName=msdb", - user="sa", - password="Test12345()", - driver="com.microsoft.sqlserver.jdbc.SQLServerDriver", -) diff --git a/tests/perf/data_source/scripts/e2etest.py b/tests/perf/data_source/scripts/e2etest.py deleted file mode 100644 index 404c60675e..0000000000 --- a/tests/perf/data_source/scripts/e2etest.py +++ /dev/null @@ -1,299 +0,0 @@ -# -# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. -# - -import time -from copy import copy -from typing import Dict, List, Tuple -from datetime import datetime -import pandas as pd -import logging - -from snowflake.snowpark import Session -from dbtest_config import ( - DatabaseTestConfig, - create_sql_server_config, - create_oracle_config, - DEFAULT_PYSPARK_CONFIG, - PARALLELISM_OPTIMIZED_PYSPARK_CONFIG, - DEFAULT_ORACLE_JDBC_CONFIG, - DEFAULT_SQLSERVER_JDBC_CONFIG, -) - -from parameters import SNOWFLAKE_CONNECTION - -# Configure logging -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - - -class PerformanceMetrics: - def __init__(self) -> None: - self.start_time = None - self.read_time = 0 - self.write_time = 0 - self.total_time = 0 - - -class DBPerformanceTest: - def __init__(self, snowflake_connection: Dict) -> None: - self.snowflake_connection = snowflake_connection - self.results = [] - - def time_operation( - self, operation_name: str, func, *args, **kwargs - ) -> Tuple[float, any]: - """Execute operation and measure time""" - dbapi_parameters = kwargs.pop("dbapi_parameters", {}) - start = time.time() - result = func(*args, **kwargs, **dbapi_parameters) - duration = time.time() - start - logger.info(f"{operation_name} took {duration:.2f} seconds") - return duration, result - - def get_table_size(self, source_db, table_name: str) -> int: - """Get row count from existing table""" - with source_db.connection.cursor() as cursor: - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") - return cursor.fetchone()[0] - - def run_test_for_db(self, config: DatabaseTestConfig) -> List[Dict]: - """Run performance tests for a specific database type""" - print( - f"Starting Database Test:\n" - f"- Database Class: {config.db_class.__name__}\n" - f"- Insert Row Count: {config.insert_row_count}\n" - if config.insert_row_count - else "" f"- Existing Table: {config.existing_table or 'None'}\n" - if config.existing_table - else "" f"- DBAPI Parameters: {config.dbapi_parameters or 'None'}" - ) - test_results = [] - - # Initialize source database - source_db = config.db_class(**config.connection_params) - - try: - # Determine table name and size - if config.existing_table: - table_name = config.existing_table - size = self.get_table_size(source_db, table_name) - logger.info(f"Using existing table {table_name} with {size} rows") - else: - size = config.insert_row_count - table_name = f"{source_db.TABLE_NAME}_{size}" - # Setup test data - source_db.create_table(table_name=table_name, replace=True) - source_db.insert_data(num_rows=size, table_name=table_name) - - logger.info( - f"\nRunning test with {size} rows for {config.db_class.__name__}" - ) - metrics = PerformanceMetrics() - - # Create Snowflake session - session = Session.builder.configs(self.snowflake_connection).create() - snowflake_table = f"{table_name}_PERF_TEST" - - try: - # Measure read.dbapi() performance - metrics.read_time, df = self.time_operation( - "read.dbapi", - session.read.dbapi, - source_db.create_connection, - table=table_name, - dbapi_parameters=config.dbapi_parameters, - ) - - # Measure write.save_as_table() performance - metrics.write_time, _ = self.time_operation( - "write.save_as_table", - df.write.save_as_table, - snowflake_table, - mode="overwrite", - ) - - # Validate data - self._validate_data(source_db, session, table_name, snowflake_table) - - # Calculate total time - metrics.total_time = metrics.read_time + metrics.write_time - - # Store results - test_results.append( - { - "timestamp": datetime.now(), - "database_type": config.db_class.__name__, - "dbapi_parameters": config.dbapi_parameters, - "data_size": size, - "fetch_size": config.dbapi_parameters.get("fetch_size", 0), - "num_partitions": config.dbapi_parameters.get( - "num_partitions", 0 - ), - "read_time": metrics.read_time, - "write_time": metrics.write_time, - "total_time": metrics.total_time, - } - ) - - finally: - session.close() - - finally: - source_db.close_connection() - - return test_results - - def run_test_for_pyspark_jdbc(self, spark_config: Dict, jdbc_config: Dict) -> None: - snowflake_table_name = "pyspark_dbapi_perf_test" - snowflake_session = Session.builder.configs(self.snowflake_connection).create() - snowflake_session.sql(f"drop table if exists {snowflake_table_name}").collect() - from pyspark.sql import SparkSession - - spark = SparkSession.builder.appName("PySparkJDBCTest") - for k, v in spark_config.items(): - spark = spark.config(k, v) - spark = spark.getOrCreate() - logger.info("Running PySpark JDBC performance test.") - metrics = PerformanceMetrics() - metrics.start_time = time.time() - - try: - logger.info("Reading data from JDBC source.") - df = spark.read.format("jdbc").options(**jdbc_config).load() - metrics.read_time = time.time() - metrics.start_time - # Writing data back to Snowflake - logger.info("Writing data back to Snowflake.") - write_start = time.time() - - if spark_config.get("repartition_num"): - df = df.repartition(spark_config.get("repartition_num")) - - ( - df.write.format("net.snowflake.spark.snowflake") - .option("sfUrl", self.snowflake_connection["host"]) - .option("sfUser", self.snowflake_connection["user"]) - .option("sfPassword", self.snowflake_connection["password"]) - .option("sfDatabase", self.snowflake_connection["database"]) - .option("sfSchema", self.snowflake_connection["schema"]) - .option("sfWarehouse", self.snowflake_connection["warehouse"]) - .option("use_parquet_in_write", "true") - .option("dbtable", snowflake_table_name) - .mode("overwrite") - .save() - ) - metrics.write_time = time.time() - write_start - metrics.total_time = time.time() - metrics.start_time - - except Exception as e: - logger.error(f"Error during PySpark JDBC test: {e}") - snowflake_session.sql( - f"drop table if exists {snowflake_table_name}" - ).collect() - raise - - logger.info(f"Test completed in {metrics.total_time} seconds.") - self.results.append( - { - "database_type": jdbc_config.get("driver"), - "data_size": df.count(), - "read_time": metrics.read_time, - "write_time": metrics.write_time, - "total_time": metrics.total_time, - } - ) - - def _validate_data( - self, - source_db, - session, - source_table: str, - target_table: str, - ): - """Validate data between source and Snowflake""" - # Get count from source - with source_db.connection.cursor() as cursor: - cursor.execute(f"SELECT COUNT(*) FROM {source_table}") - assert cursor.fetchone()[0] == session.table(target_table).count() - - def save_results(self, filename: str = "performance_results.csv"): - """Save test results to CSV""" - df = pd.DataFrame(self.results) - df.to_csv(filename, index=False) - logger.info(f"\nResults saved to {filename}") - - -def pyspark_perf_test(table_name="ALL_TYPE_TABLE"): - # here the assumption is that source database is already set up - perf_test = DBPerformanceTest(SNOWFLAKE_CONNECTION) - - oracle_jdbc_config = copy(DEFAULT_ORACLE_JDBC_CONFIG) - # SELECT to avoid driver unsupported types - oracle_jdbc_config[ - "dbtable" - ] = f"(SELECT ID,NUMBER_COL,BINARY_FLOAT_COL,BINARY_DOUBLE_COL,VARCHAR2_COL,CHAR_COL,CLOB_COL,NCHAR_COL,NVARCHAR2_COL,NCLOB_COL,DATE_COL,TIMESTAMP_COL,TIMESTAMP_TZ_COL,TO_CHAR(TIMESTAMP_LTZ_COL),BLOB_COL,RAW_COL,GUID_COL FROM {table_name})" - - sqlserver_jdbc_config = copy(DEFAULT_SQLSERVER_JDBC_CONFIG) - # SELECT to avoid driver unsupported types - sqlserver_jdbc_config[ - "dbtable" - ] = f"(SELECT id,bigint_col,bit_col,decimal_col,float_col,int_col,money_col,real_col,smallint_col,smallmoney_col,tinyint_col,numeric_col,date_col,datetime2_col,datetime_col,smalldatetime_col,time_col,char_col,text_col,varchar_col,nchar_col,ntext_col,nvarchar_col,binary_col,varbinary_col,image_col,uniqueidentifier_col,xml_col,sysname_col FROM {table_name}) as tmp" - - # CHANGE ME TO TEST DIFFERENT CONFIGURATIONS - pyspark_configs = [DEFAULT_PYSPARK_CONFIG, PARALLELISM_OPTIMIZED_PYSPARK_CONFIG] - jdbc_configs = [oracle_jdbc_config, sqlserver_jdbc_config] - - for pyspark_config in pyspark_configs: - for jdbc_config in jdbc_configs: - perf_test.run_test_for_pyspark_jdbc(pyspark_config, jdbc_config) - - # Save all results - perf_test.save_results("pyspark_performance_results.csv") - - # Print summary - print("\nTest Summary:") - for result in perf_test.results: - print( - f"\nDatabase: {result['database_type']}, Size: {result['data_size']} rows:" - ) - print(f"Read time: {result['read_time']:.2f} seconds") - print(f"Write time: {result['write_time']:.2f} seconds") - print(f"Total time: {result['total_time']:.2f} seconds") - - -def snowpark_perf_test(): - # Initialize and run tests - perf_test = DBPerformanceTest(SNOWFLAKE_CONNECTION) - - # SAMPLE - sql_sever_config = create_sql_server_config(insert_row_count=10000, fetch_size=1000) - oracle_config = create_oracle_config(insert_row_count=10000, fetch_size=1000) - - # UPDATE THIS - TEST_MATRIX = [sql_sever_config, oracle_config] - - # Run tests for each database type - for db_config in TEST_MATRIX: - results = perf_test.run_test_for_db(db_config) - perf_test.results.extend(results) - - # Save all results - perf_test.save_results() - - # Print summary - print("\nTest Summary:") - for result in perf_test.results: - print( - f"\nDatabase: {result['database_type']}, Size: {result['data_size']} rows:" - ) - print(f"Read time: {result['read_time']:.2f} seconds") - print(f"Write time: {result['write_time']:.2f} seconds") - print(f"Total time: {result['total_time']:.2f} seconds") - print(f"DBAPI parameters: {result['dbapi_parameters']}") - - -if __name__ == "__main__": - pyspark_perf_test() - snowpark_perf_test() diff --git a/tests/perf/data_source/scripts/jdbc_drivers/README.md b/tests/perf/data_source/scripts/jdbc_drivers/README.md deleted file mode 100644 index b7d316f18d..0000000000 --- a/tests/perf/data_source/scripts/jdbc_drivers/README.md +++ /dev/null @@ -1,18 +0,0 @@ -## download the following JARs and put into the jdbc_drivers folder: - -1. (required for mssql) mssql-jdbc-12.8.1.jre11.jar: https://mvnrepository.com/artifact/com.microsoft.sqlserver/mssql-jdbc/12.8.1.jre11 -2. (required for oracle) ojdbc11.jar: https://www.oracle.com/database/technologies/appdev/jdbc-downloads.html -3. parquet-avro-1.10.1.jar: https://mvnrepository.com/artifact/org.apache.parquet/parquet-avro/1.10.1 -4. snowflake-jdbc-3.19.0.jar: https://mvnrepository.com/artifact/net.snowflake/snowflake-jdbc/3.19.0 -5. spark-snowflake_2.12-3.1.0.jar: https://mvnrepository.com/artifact/net.snowflake/spark-snowflake_2.13/3.1.0 - -## known issues: - -### Oracle -1. spark-snowflake_2.12-3.1.1 doesn't work with oracledb blob type data, should be binary data but it want binary data -solution: spark-snowflake_2.12-3.1.0 works -2. spark can not handle oracle database timestamp with time zone data type - -### MSSQL -1. SQL_VARIANT type is not supported by spark when loading -2. date type gets warning LogicalTypes: Ignoring invalid logical type for name: date diff --git a/tests/perf/data_source/scripts/oracle_resource_setup.py b/tests/perf/data_source/scripts/oracle_resource_setup.py deleted file mode 100644 index 1316282131..0000000000 --- a/tests/perf/data_source/scripts/oracle_resource_setup.py +++ /dev/null @@ -1,151 +0,0 @@ -# -# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. -# - -import oracledb -import random -import datetime -import uuid - -from base_db_setup import TestDBABC, ONE_MILLION - - -class TestOracleDB(TestDBABC): - def __init__( - self, - username="SYSTEM", - password="test", - host="localhost", - port=1521, - service_name="FREEPDB1", - table_name="ALL_TYPE_TABLE", - insert_batch_size=10000, - ) -> None: - self.USERNAME = username - self.PASSWORD = password - self.HOST = host - self.PORT = port - self.SERVICE_NAME = service_name - self.TABLE_NAME = table_name - self.insert_batch_size = insert_batch_size - self._connection = self.create_connection() - - def create_connection(self): - dsn = f"{self.HOST}:{self.PORT}/{self.SERVICE_NAME}" - connection = oracledb.connect( - user=self.USERNAME, password=self.PASSWORD, dsn=dsn - ) - return connection - - def create_table(self, table_name=None, replace=True): - to_create_table = table_name or self.TABLE_NAME - with self.connection.cursor() as cursor: - if replace: - try: - cursor.execute(f"DROP TABLE {to_create_table}") - print("Table dropped successfully.") - except oracledb.DatabaseError: - pass - cursor.execute( - f""" - CREATE TABLE {to_create_table} ( - id NUMBER GENERATED AS IDENTITY PRIMARY KEY, - number_col NUMBER(10,2), - binary_float_col BINARY_FLOAT, - binary_double_col BINARY_DOUBLE, - varchar2_col VARCHAR2(50), - char_col CHAR(10), - clob_col CLOB, - nchar_col NCHAR(10), - nvarchar2_col NVARCHAR2(50), - nclob_col NCLOB, - date_col DATE, - timestamp_col TIMESTAMP, - timestamp_tz_col TIMESTAMP WITH TIME ZONE, - timestamp_ltz_col TIMESTAMP WITH LOCAL TIME ZONE, - blob_col BLOB, - raw_col RAW(16), - guid_col RAW(16) DEFAULT SYS_GUID() - ) - """ - ) - print("Table created successfully.") - - @staticmethod - def generate_random_data(): - return ( - round(random.uniform(1, 10000), 2), - random.uniform(1, 10000), - random.uniform(1, 10000), - TestOracleDB.generate_random_string(50), - TestOracleDB.generate_random_string(10).ljust(10), - TestOracleDB.generate_random_string(1000), # Simulating large CLOB text - TestOracleDB.generate_random_string(10).ljust(10), - TestOracleDB.generate_random_string(50), - TestOracleDB.generate_random_string(1000), # Simulating large NCLOB text - datetime.datetime( - 2024, random.randint(1, 12), random.randint(1, 28) - ).date(), - datetime.datetime( - 2024, - random.randint(1, 12), - random.randint(1, 28), - random.randint(0, 23), - random.randint(0, 59), - ), - TestOracleDB.random_datetime_with_timezone(), - TestOracleDB.random_datetime_with_timezone(), - bytes(random.getrandbits(8) for _ in range(16)), - uuid.uuid4().bytes, - ) - - def insert_data(self, num_rows=1_000_000, table_name=None): - to_insert_table = table_name or self.TABLE_NAME - # List of columns for better maintainability - columns = [ - "number_col", - "binary_float_col", - "binary_double_col", - "varchar2_col", - "char_col", - "clob_col", - "nchar_col", - "nvarchar2_col", - "nclob_col", - "date_col", - "timestamp_col", - "timestamp_tz_col", - "timestamp_ltz_col", - "blob_col", - "raw_col", - ] - - # Generate column names and placeholders dynamically - column_list = ", ".join(columns) - placeholders = ", ".join([f":{i + 1}" for i in range(len(columns))]) - - # Use f-string formatting for clarity - insert_sql = f""" - INSERT INTO {to_insert_table} ( - {column_list} - ) VALUES ( - {placeholders} - ) - """ - - self._insert_data_with_sql(insert_sql, num_rows) - - -if __name__ == "__main__": - # for setup - test = TestOracleDB() - table_name = "ALL_TYPE_TABLE" - test.create_table(table_name=table_name, replace=True) - test.insert_data(ONE_MILLION, table_name=table_name) - ret = ( - test.connection.cursor() - .execute(f"select count(*) from {table_name}") - .fetchall() - ) - print(ret) - test.close_connection() diff --git a/tests/perf/data_source/scripts/sql_server_resource_setup.py b/tests/perf/data_source/scripts/sql_server_resource_setup.py deleted file mode 100644 index 73795c5ec5..0000000000 --- a/tests/perf/data_source/scripts/sql_server_resource_setup.py +++ /dev/null @@ -1,261 +0,0 @@ -# -# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. -# - -# UNSUPPORTED TYPES in pyodbc: -# 1. datetimeoffset_col DATETIMEOFFSET -# 2. geography_col GEOGRAPHY -# 3. geometry_col GEOMETRY -# workaround: https://github.com/mkleehammer/pyodbc/wiki/Using-an-Output-Converter-function -# TODO: SNOW-1945100 - -import pyodbc -import random -import datetime - -from base_db_setup import TestDBABC, ONE_MILLION - - -class TestSQLServerDB(TestDBABC): - def __init__( - self, - *, - host="127.0.0.1", - port=1433, - database="msdb", - username="sa", - password="Test12345()", - table_name="ALL_TYPE_TABLE", - insert_batch_size=10000, - ) -> None: - self.HOST = host - self.PORT = port - self.DATABASE = database - self.USERNAME = username - self.PASSWORD = password - self.TABLE_NAME = table_name - self.insert_batch_size = insert_batch_size or 10000 - self._connection = self.create_connection() - - def create_connection(self): - connection_str = ( - f"DRIVER={{ODBC Driver 18 for SQL Server}};" - f"SERVER={self.HOST},{self.PORT};" - f"DATABASE={self.DATABASE};" - f"UID={self.USERNAME};" - f"PWD={self.PASSWORD};" - "TrustServerCertificate=yes" - ) - connection = pyodbc.connect(connection_str) - return connection - - def create_table(self, table_name=None, replace=True): - to_create_table = table_name or self.TABLE_NAME - with self.connection.cursor() as cursor: - if replace: - try: - cursor.execute(f"DROP TABLE IF EXISTS {to_create_table}") - print("Table dropped successfully.") - except pyodbc.DatabaseError: - pass - cursor.execute( - f""" - CREATE TABLE {to_create_table} ( - id INT IDENTITY(1,1) PRIMARY KEY, - bigint_col BIGINT, - bit_col BIT, - decimal_col DECIMAL(18, 2), - float_col FLOAT, - int_col INT, - money_col MONEY, - real_col REAL, - smallint_col SMALLINT, - smallmoney_col SMALLMONEY, - tinyint_col TINYINT, - numeric_col NUMERIC(18, 2), - date_col DATE, - datetime2_col DATETIME2, - datetime_col DATETIME, - smalldatetime_col SMALLDATETIME, --- datetimeoffset_col DATETIMEOFFSET, - time_col TIME, - char_col CHAR(10), - text_col TEXT, - varchar_col VARCHAR(50), - nchar_col NCHAR(10), - ntext_col NTEXT, - nvarchar_col NVARCHAR(50), - binary_col BINARY(16), - varbinary_col VARBINARY(16), - image_col IMAGE, - sql_variant_col SQL_VARIANT, --- geography_col GEOGRAPHY, --- geometry_col GEOMETRY, - uniqueidentifier_col UNIQUEIDENTIFIER DEFAULT NEWID(), - xml_col XML, - sysname_col SYSNAME - ) - """ - ) - print("Table created successfully.") - - @staticmethod - def generate_random_data(): - return ( - random.randint(1, 1e18), # bigint - random.choice([0, 1]), # bit - round(random.uniform(1, 10000), 2), # decimal - random.uniform(1, 10000), # float - random.randint(1, 10000), # int - round(random.uniform(1, 10000), 2), # money - random.uniform(1, 10000), # real - random.randint(1, 32767), # smallint - round(random.uniform(1, 1000), 2), # smallmoney - random.randint(0, 255), # tinyint - round(random.uniform(1, 10000), 2), # numeric - datetime.date.today(), # date - datetime.datetime.now(), # datetime2 - datetime.datetime.now(), # datetime - # datetime.datetime.now(datetime.timezone.utc), # datetimeoffset - datetime.datetime.now(), # smalldatetime - datetime.datetime.now().time(), # time - TestSQLServerDB.generate_random_string(10).ljust(10), # char - TestSQLServerDB.generate_random_string(1000), # text - TestSQLServerDB.generate_random_string(50), # varchar - TestSQLServerDB.generate_random_string(10).ljust(10), # nchar - TestSQLServerDB.generate_random_string(1000), # ntext - TestSQLServerDB.generate_random_string(50), # nvarchar - bytes(random.getrandbits(8) for _ in range(16)), # binary - bytes(random.getrandbits(8) for _ in range(16)), # varbinary - bytes(random.getrandbits(8) for _ in range(16)), # image - bytes( - random.getrandbits(8) for _ in range(16) - ), # sql_variant (using uniqueidentifier as a mock variant) - # 'POINT(1 1)', # geography (mock string for simplicity) - # 'POINT(1 1)', # geometry (mock string for simplicity) - bytes(random.getrandbits(8) for _ in range(16)), # uniqueidentifier - "Test", # xml - "sysname_test", # sysname - ) - - def insert_null_data(self, num_rows=1, table_name=None): - to_insert_table = table_name or self.TABLE_NAME - # Define the column names as a list for better maintainability - columns = [ - "bigint_col", - "bit_col", - "decimal_col", - "float_col", - "int_col", - "money_col", - "real_col", - "smallint_col", - "smallmoney_col", - "tinyint_col", - "numeric_col", - "date_col", - "datetime2_col", - "datetime_col", - "smalldatetime_col", - # "datetimeoffset_col", - "time_col", - "char_col", - "text_col", - "varchar_col", - "nchar_col", - "ntext_col", - "nvarchar_col", - "binary_col", - "varbinary_col", - "image_col", - "sql_variant_col", - # "geography_col", - # "geometry_col", - "uniqueidentifier_col", - "xml_col", - "sysname_col", - ] - - # Dynamically construct the column string and placeholders - column_str = ", ".join(columns) - placeholders = ", ".join(["?"] * len(columns)) - - # Construct the SQL statement dynamically - insert_sql = f""" - INSERT INTO {to_insert_table} ( - {column_str} - ) VALUES ({placeholders}) - """ - with self.connection.cursor() as cursor: - for _ in range(num_rows): - # Generate a tuple with None for each column - null_data = [None] * (len(columns) - 1) - null_data.append("sysname_test") # sysname_col does not allow null - cursor.execute(insert_sql, tuple(null_data)) - self.connection.commit() - print(f"Inserted {num_rows} rows with NULL values successfully.") - - def insert_data(self, num_rows=1_000_000, table_name=None): - to_insert_table = table_name or self.TABLE_NAME - # Define the column names as a list for better maintainability - columns = [ - "bigint_col", - "bit_col", - "decimal_col", - "float_col", - "int_col", - "money_col", - "real_col", - "smallint_col", - "smallmoney_col", - "tinyint_col", - "numeric_col", - "date_col", - "datetime2_col", - "datetime_col", - "smalldatetime_col", - # "datetimeoffset_col", - "time_col", - "char_col", - "text_col", - "varchar_col", - "nchar_col", - "ntext_col", - "nvarchar_col", - "binary_col", - "varbinary_col", - "image_col", - "sql_variant_col", - # "geography_col", - # "geometry_col", - "uniqueidentifier_col", - "xml_col", - "sysname_col", - ] - - # Dynamically construct the column string and placeholders - column_str = ", ".join(columns) - placeholders = ", ".join(["?"] * len(columns)) - - # Construct the SQL statement dynamically - insert_sql = f""" - INSERT INTO {to_insert_table} ( - {column_str} - ) VALUES ({placeholders}) - """ - self._insert_data_with_sql(insert_sql, num_rows) - - -if __name__ == "__main__": - # for setup - test = TestSQLServerDB() - table_name = "ALL_TYPE_TABLE" - test.create_table(table_name=table_name, replace=True) - test.insert_data(ONE_MILLION, table_name=table_name) - ret = ( - test.connection.cursor() - .execute(f"select count(*) from {table_name}") - .fetchall() - ) - print(ret) - test.close_connection()