|
3 | 3 | Supports both traditional ODBC connections and passwordless Azure AD authentication |
4 | 4 | """ |
5 | 5 |
|
| 6 | +import logging |
6 | 7 | import os |
7 | 8 | import struct |
8 | | -import pyodbc |
9 | | -import logging |
10 | | -from typing import Optional, List, Dict, Any |
11 | 9 | from contextlib import contextmanager |
| 10 | +from typing import Any, Dict, List, Optional |
| 11 | + |
| 12 | +import pyodbc |
12 | 13 | from azure.identity import DefaultAzureCredential |
| 14 | +from azure.keyvault.secrets import SecretClient |
13 | 15 |
|
14 | 16 | # Configure logging |
15 | 17 | logger = logging.getLogger(__name__) |
@@ -73,26 +75,89 @@ def from_env(cls) -> 'SqlHelper': |
73 | 75 | - SQL_USERNAME |
74 | 76 | - SQL_PASSWORD |
75 | 77 | """ |
| 78 | + key_vault_name = os.environ.get("KEY_VAULT_NAME") |
| 79 | + secret_name = os.environ.get("SECRET_NAME") |
| 80 | + |
| 81 | + if key_vault_name and secret_name: |
| 82 | + return cls.from_key_vault(key_vault_name, secret_name) |
| 83 | + |
76 | 84 | client_id = os.environ.get("AZURE_CLIENT_ID") |
77 | 85 | client_secret = os.environ.get("AZURE_CLIENT_SECRET") |
78 | 86 | tenant_id = os.environ.get("AZURE_TENANT_ID") |
79 | 87 | server = os.environ.get("SQL_SERVER") |
80 | 88 | database = os.environ.get("SQL_DATABASE") |
81 | 89 | username = os.environ.get("SQL_USERNAME") |
82 | 90 | password = os.environ.get("SQL_PASSWORD") |
83 | | - |
| 91 | + |
84 | 92 | if not any([client_id, client_secret, tenant_id, server, database, username, password]): |
85 | 93 | raise ValueError("You properly need to define environment variables.") |
86 | | - |
| 94 | + |
87 | 95 | logger.info("Environment variables loaded successfully") |
88 | | - |
| 96 | + |
89 | 97 | return cls( |
90 | 98 | server=server, |
91 | 99 | database=database, |
92 | 100 | username=username, |
93 | 101 | password=password, |
94 | 102 | use_azure_credential=all([client_id, client_secret, tenant_id]) |
95 | 103 | ) |
| 104 | + |
| 105 | + @classmethod |
| 106 | + def from_key_vault(cls, vault_name: str, secret_name: str) -> 'SqlHelper': |
| 107 | + """ |
| 108 | + Create a SqlHelper instance by reading the connection string from Azure Key Vault. |
| 109 | + |
| 110 | + """ |
| 111 | + vault_url = f"https://{vault_name}.vault.azure.net" |
| 112 | + credential = DefaultAzureCredential(exclude_interactive_browser_credential=False) |
| 113 | + client = SecretClient(vault_url=vault_url, credential=credential) |
| 114 | + |
| 115 | + logger.info(f"Retrieving secret [{secret_name}] from Key Vault [{vault_name}]...") |
| 116 | + secret = client.get_secret(secret_name) |
| 117 | + |
| 118 | + if not secret.value: |
| 119 | + raise ValueError(f"Secret [{secret_name}] in Key Vault [{vault_name}] has no value") |
| 120 | + |
| 121 | + logger.info(f"Secret [{secret_name}] retrieved successfully from Key Vault [{vault_name}]") |
| 122 | + return cls.from_connection_string(secret.value) |
| 123 | + |
| 124 | + @classmethod |
| 125 | + def from_connection_string(cls, connection_string: str) -> 'SqlHelper': |
| 126 | + """ |
| 127 | + Create a SqlHelper instance from a connection string. |
| 128 | + |
| 129 | + This is useful when the connection string is stored in an environment variable |
| 130 | + (e.g., resolved by Azure App Service from Key Vault via @Microsoft.KeyVault(SecretUri=...)). |
| 131 | + |
| 132 | + """ |
| 133 | + parts = {} |
| 134 | + for part in connection_string.split(';'): |
| 135 | + if '=' in part: |
| 136 | + key, value = part.split('=', 1) |
| 137 | + parts[key.strip()] = value.strip() |
| 138 | + |
| 139 | + server = parts.get('Server', '').replace('tcp:', '').replace(',1433', '') |
| 140 | + database = parts.get('Database') |
| 141 | + username = parts.get('User ID') |
| 142 | + password = parts.get('Password') |
| 143 | + |
| 144 | + if not all([server, database, username, password]): |
| 145 | + raise ValueError( |
| 146 | + f"Could not parse all required parameters from connection string. " |
| 147 | + f"Found - Server: {bool(server)}, Database: {bool(database)}, " |
| 148 | + f"Username: {bool(username)}, Password: {bool(password)}" |
| 149 | + ) |
| 150 | + |
| 151 | + logger.info("Connection string parsed successfully") |
| 152 | + logger.info(f"Server: {server}, Database: {database}, Username: {username}") |
| 153 | + |
| 154 | + return cls( |
| 155 | + server=server, |
| 156 | + database=database, |
| 157 | + username=username, |
| 158 | + password=password, |
| 159 | + use_azure_credential=False |
| 160 | + ) |
96 | 161 |
|
97 | 162 | def _build_connection_string(self) -> str: |
98 | 163 | """Build the ODBC connection string.""" |
|
0 commit comments