-
Notifications
You must be signed in to change notification settings - Fork 140
Expand file tree
/
Copy pathtest_telemetry_e2e.py
More file actions
354 lines (289 loc) · 14.7 KB
/
test_telemetry_e2e.py
File metadata and controls
354 lines (289 loc) · 14.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
E2E test for telemetry - verifies telemetry behavior with different scenarios
"""
import time
import threading
import logging
from contextlib import contextmanager
from unittest.mock import patch
import pytest
from concurrent.futures import wait
import databricks.sql as sql
from databricks.sql.telemetry.telemetry_client import (
TelemetryClient,
TelemetryClientFactory,
)
# Module-level logger for this test module; TelemetryTestBase.connection
# logs the arguments it connects with through it.
log = logging.getLogger(__name__)
class TelemetryTestBase:
    """Simplified test base class for telemetry e2e tests.

    Before every test, the autouse ``get_details`` fixture copies the
    ``connection_details`` fixture onto ``self.arguments``. Subclasses then use
    ``connection_params()`` / ``connection()`` to talk to the server.
    """

    # Connection kwargs whose values must never be written to logs.
    _SECRET_PARAMS = frozenset({"access_token", "password", "client_secret"})

    @pytest.fixture(autouse=True)
    def get_details(self, connection_details):
        """Store a copy of the ``connection_details`` fixture on the instance."""
        self.arguments = connection_details.copy()

    def connection_params(self):
        """Return the base ``sql.connect`` kwargs built from ``self.arguments``.

        ``host`` and ``http_path`` are required (a missing key raises
        ``KeyError``). ``access_token`` is optional and falls back to ``None``.
        """
        return {
            "server_hostname": self.arguments["host"],
            "http_path": self.arguments["http_path"],
            "access_token": self.arguments.get("access_token"),
        }

    @contextmanager
    def connection(self, extra_params=()):
        """Open a connection, yield it, and always close it on exit.

        Args:
            extra_params: a mapping or an iterable of ``(key, value)`` pairs
                merged over ``connection_params()``. Its keys win on conflict.

        Yields:
            The object returned by ``sql.connect``.
        """
        connection_params = dict(self.connection_params(), **dict(extra_params))
        # Redact credentials so tokens never end up in CI/test logs.
        loggable_params = {
            key: ("***" if key in self._SECRET_PARAMS and value else value)
            for key, value in connection_params.items()
        }
        log.info("Connecting with args: %s", loggable_params)
        conn = sql.connect(**connection_params)
        try:
            yield conn
        finally:
            conn.close()
@pytest.mark.serial
class TestTelemetryE2E(TelemetryTestBase):
    """E2E tests for telemetry scenarios.

    These tests must run serially because the host-level telemetry client is
    shared: ``TelemetryClientFactory`` keeps class-level state (executor,
    flush thread, initialized flag), which the teardown fixture resets.
    """

    @pytest.fixture(autouse=True)
    def telemetry_setup_teardown(self):
        """Reset shared telemetry state after each test.

        This is teardown only: the test runs first. Then the telemetry
        factory's executor and flush thread are stopped, and the feature-flag
        cache is cleared so state cannot leak into the next test. The
        feature-flag cleanup still runs if the telemetry cleanup raises.
        """
        try:
            yield
        finally:
            try:
                if TelemetryClientFactory._executor:
                    # wait=True lets in-flight telemetry requests finish before
                    # the next test starts.
                    TelemetryClientFactory._executor.shutdown(wait=True)
                    TelemetryClientFactory._executor = None
                TelemetryClientFactory._stop_flush_thread()
                TelemetryClientFactory._initialized = False
            finally:
                # Clear feature flags cache to prevent state leakage between tests
                from databricks.sql.common.feature_flag import (
                    FeatureFlagsContextFactory,
                )

                with FeatureFlagsContextFactory._lock:
                    FeatureFlagsContextFactory._context_map.clear()
                    if FeatureFlagsContextFactory._executor:
                        FeatureFlagsContextFactory._executor.shutdown(wait=False)
                        FeatureFlagsContextFactory._executor = None

    @pytest.fixture
    def telemetry_interceptors(self):
        """Build wrappers that record telemetry traffic without altering it.

        Returns:
            A 4-tuple ``(captured_events, captured_futures, export_wrapper,
            callback_wrapper)``. Each wrapper records its argument and then
            delegates to the original ``TelemetryClient`` method. The test
            installs the wrappers (see ``_patched_telemetry``).

        ``captured_futures`` only grows when the request callback fires. That
        callback may run on another thread, hence the lock.
        """
        capture_lock = threading.Lock()
        captured_events = []
        captured_futures = []
        original_export = TelemetryClient._export_event
        original_callback = TelemetryClient._telemetry_request_callback

        def export_wrapper(self_client, event):
            with capture_lock:
                captured_events.append(event)
            return original_export(self_client, event)

        def callback_wrapper(self_client, future, sent_count):
            with capture_lock:
                captured_futures.append(future)
            original_callback(self_client, future, sent_count)

        return captured_events, captured_futures, export_wrapper, callback_wrapper

    # ==================== INTERNAL HELPERS ====================

    @contextmanager
    def _patched_telemetry(self, export_wrapper, callback_wrapper):
        """Install the interceptor wrappers on ``TelemetryClient`` for the block."""
        with patch.object(
            TelemetryClient, "_export_event", export_wrapper
        ), patch.object(
            TelemetryClient, "_telemetry_request_callback", callback_wrapper
        ):
            yield

    @staticmethod
    def _await_telemetry_responses(captured_futures, min_responses=0,
                                   settle_seconds=2.0, timeout_seconds=10.0,
                                   poll_interval=0.1):
        """Wait for telemetry requests to complete and return the done futures.

        The wait has three steps:
        1. Sleep ``settle_seconds`` as a grace period for asynchronous submission.
        2. Poll for up to ``timeout_seconds`` until at least ``min_responses``
           futures have been recorded.
        3. Wait on a snapshot of the recorded futures for whatever time remains.

        Args:
            captured_futures: the list appended to by the interceptor callback.
            min_responses: how many recorded futures to wait for. 0 skips
                polling, which reduces to a plain sleep followed by a wait.
            settle_seconds: the initial fixed grace period.
            timeout_seconds: the upper bound on polling plus waiting after the
                grace period.
            poll_interval: the sleep between polls.

        Returns:
            The set of completed futures from the snapshot.
        """
        time.sleep(settle_seconds)
        deadline = time.monotonic() + timeout_seconds
        # Futures are recorded by the request callback, not at submission, so
        # waiting on the list alone cannot cover requests still in flight.
        while len(captured_futures) < min_responses and time.monotonic() < deadline:
            time.sleep(poll_interval)
        remaining = max(0.0, deadline - time.monotonic())
        # Snapshot: the callback may still be appending from another thread.
        done, _ = wait(list(captured_futures), timeout=remaining)
        return done

    def _assert_common_fields(self, events):
        """Assert system config and connection params on every event."""
        for event in events:
            self.assert_system_config(event)
            self.assert_connection_params(event, self.arguments["http_path"])

    @staticmethod
    def _execute_range_query(cursor):
        """Run the 1000-row query used by the cloud-fetch tests and check its size."""
        cursor.execute("SELECT * FROM range(1000)")
        result = cursor.fetchall()
        assert len(result) == 1000

    # ==================== ASSERTION HELPERS ====================

    def assert_system_config(self, event):
        """Assert system configuration fields are present and non-empty."""
        sys_config = event.entry.sql_driver_log.system_configuration
        assert sys_config is not None
        # Check all required fields are non-empty
        for field in ['driver_name', 'driver_version', 'os_name', 'os_version',
                      'os_arch', 'runtime_name', 'runtime_version', 'runtime_vendor',
                      'locale_name', 'char_set_encoding']:
            value = getattr(sys_config, field)
            assert value, f"{field} should not be None or empty"
        assert sys_config.driver_name == "Databricks SQL Python Connector"

    def assert_connection_params(self, event, expected_http_path=None):
        """Assert connection parameters; optionally pin the exact http_path."""
        conn_params = event.entry.sql_driver_log.driver_connection_params
        assert conn_params is not None
        assert conn_params.http_path
        assert conn_params.host_info is not None
        assert conn_params.auth_mech is not None
        if expected_http_path:
            assert conn_params.http_path == expected_http_path
        if conn_params.socket_timeout is not None:
            assert conn_params.socket_timeout > 0

    def assert_statement_execution(self, event):
        """Assert statement execution details and a non-negative latency."""
        sql_op = event.entry.sql_driver_log.sql_operation
        assert sql_op is not None
        assert sql_op.statement_type is not None
        assert sql_op.execution_result is not None
        assert hasattr(sql_op, "retry_count")
        if sql_op.retry_count is not None:
            assert sql_op.retry_count >= 0
        latency = event.entry.sql_driver_log.operation_latency_ms
        assert latency is not None and latency >= 0

    def assert_error_info(self, event, expected_error_name=None):
        """Assert error information; optionally pin the exact error name."""
        error_info = event.entry.sql_driver_log.error_info
        assert error_info is not None
        assert error_info.error_name
        assert error_info.stack_trace
        if expected_error_name:
            assert error_info.error_name == expected_error_name

    def verify_events(self, captured_events, captured_futures, expected_count):
        """Assert exactly ``expected_count`` events were exported and sent.

        For a non-zero count, this also waits (bounded) for that many request
        callbacks, checks that every response is 2xx, and validates the common
        fields of every event.
        """
        if expected_count == 0:
            assert len(captured_events) == 0, f"Expected 0 events, got {len(captured_events)}"
            assert len(captured_futures) == 0, f"Expected 0 responses, got {len(captured_futures)}"
            return
        assert len(captured_events) == expected_count, \
            f"Expected {expected_count} events, got {len(captured_events)}"
        done = self._await_telemetry_responses(
            captured_futures, min_responses=expected_count
        )
        assert len(done) == expected_count, \
            f"Expected {expected_count} responses, got {len(done)}"
        for future in done:
            response = future.result()
            assert 200 <= response.status < 300, \
                f"Telemetry request returned HTTP {response.status}"
        # Assert common fields for all events
        self._assert_common_fields(captured_events)

    # ==================== PARAMETERIZED TESTS ====================

    @pytest.mark.parametrize("enable_telemetry,force_enable,expected_count,test_id", [
        (True, False, 2, "enable_on_force_off"),
        (False, True, 2, "enable_off_force_on"),
        (False, False, 0, "both_off"),
        (None, None, 2, "default_behavior"),
    ])
    def test_telemetry_flags(self, telemetry_interceptors, enable_telemetry,
                             force_enable, expected_count, test_id):
        """Test telemetry behavior with different flag combinations.

        A flag value of ``None`` means the flag is not passed, so the driver's
        default applies. ``test_id`` only labels the parametrized case.
        """
        captured_events, captured_futures, export_wrapper, callback_wrapper = (
            telemetry_interceptors
        )
        extra_params = {"telemetry_batch_size": 1}
        if enable_telemetry is not None:
            extra_params["enable_telemetry"] = enable_telemetry
        if force_enable is not None:
            extra_params["force_enable_telemetry"] = force_enable
        with self._patched_telemetry(export_wrapper, callback_wrapper):
            with self.connection(extra_params=extra_params) as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    cursor.fetchone()
            # Give time for async telemetry submission after connection closes
            time.sleep(0.5)
            self.verify_events(captured_events, captured_futures, expected_count)
            # Assert statement execution on latency event (if events exist)
            if expected_count > 0:
                self.assert_statement_execution(captured_events[-1])

    @pytest.mark.parametrize("query,expected_error", [
        ("SELECT * FROM WHERE INVALID SYNTAX 12345", "ServerOperationError"),
        ("SELECT * FROM non_existent_table_xyz_12345", None),
    ])
    def test_sql_errors(self, telemetry_interceptors, query, expected_error):
        """Test telemetry captures error information for different SQL errors"""
        captured_events, captured_futures, export_wrapper, callback_wrapper = (
            telemetry_interceptors
        )
        with self._patched_telemetry(export_wrapper, callback_wrapper):
            with self.connection(extra_params={
                "force_enable_telemetry": True,
                "telemetry_batch_size": 1,
            }) as conn:
                with conn.cursor() as cursor:
                    with pytest.raises(Exception):
                        cursor.execute(query)
                        cursor.fetchone()
            self._await_telemetry_responses(captured_futures)
            assert len(captured_events) >= 1
            # Find event with error_info
            error_event = next(
                (e for e in captured_events if e.entry.sql_driver_log.error_info),
                None,
            )
            assert error_event is not None, "No telemetry event carried error_info"
            self._assert_common_fields([error_event])
            self.assert_error_info(error_event, expected_error)

    def test_metadata_operation(self, telemetry_interceptors):
        """Test telemetry for metadata operations (getCatalogs)"""
        captured_events, captured_futures, export_wrapper, callback_wrapper = (
            telemetry_interceptors
        )
        with self._patched_telemetry(export_wrapper, callback_wrapper):
            with self.connection(extra_params={
                "force_enable_telemetry": True,
                "telemetry_batch_size": 1,
            }) as conn:
                with conn.cursor() as cursor:
                    catalogs = cursor.catalogs()
                    catalogs.fetchall()
            self._await_telemetry_responses(captured_futures)
            assert len(captured_events) >= 1
            self._assert_common_fields(captured_events)

    def test_direct_results(self, telemetry_interceptors):
        """Test telemetry with direct results (use_cloud_fetch=False)"""
        captured_events, captured_futures, export_wrapper, callback_wrapper = (
            telemetry_interceptors
        )
        with self._patched_telemetry(export_wrapper, callback_wrapper):
            with self.connection(extra_params={
                "force_enable_telemetry": True,
                "telemetry_batch_size": 1,
                "use_cloud_fetch": False,
            }) as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 100")
                    result = cursor.fetchall()
                    assert len(result) == 1
                    assert result[0][0] == 100
            self._await_telemetry_responses(captured_futures)
            assert len(captured_events) >= 2
            self._assert_common_fields(captured_events)
            self.assert_statement_execution(captured_events[-1])

    @pytest.mark.parametrize("close_type", [
        "context_manager",
        "explicit_cursor",
        "explicit_connection",
        "implicit_fetchall",
    ])
    def test_cloudfetch_with_different_close_patterns(self, telemetry_interceptors,
                                                      close_type):
        """Test telemetry with cloud fetch using different resource closing patterns"""
        captured_events, captured_futures, export_wrapper, callback_wrapper = (
            telemetry_interceptors
        )
        with self._patched_telemetry(export_wrapper, callback_wrapper):
            if close_type == "explicit_connection":
                # Test explicit connection close (no context manager). The
                # finally releases the connection even if an assertion fails.
                conn = sql.connect(
                    **self.connection_params(),
                    force_enable_telemetry=True,
                    telemetry_batch_size=1,
                    use_cloud_fetch=True,
                )
                try:
                    self._execute_range_query(conn.cursor())
                finally:
                    conn.close()
            else:
                # Other patterns use connection context manager
                with self.connection(extra_params={
                    "force_enable_telemetry": True,
                    "telemetry_batch_size": 1,
                    "use_cloud_fetch": True,
                }) as conn:
                    if close_type == "context_manager":
                        with conn.cursor() as cursor:
                            self._execute_range_query(cursor)
                    elif close_type == "explicit_cursor":
                        cursor = conn.cursor()
                        self._execute_range_query(cursor)
                        cursor.close()
                    elif close_type == "implicit_fetchall":
                        # Cursor deliberately left open; closing the
                        # connection is what releases it.
                        self._execute_range_query(conn.cursor())
            self._await_telemetry_responses(captured_futures)
            assert len(captured_events) >= 2
            self._assert_common_fields(captured_events)
            self.assert_statement_execution(captured_events[-1])