Skip to content

Commit 8911264

Browse files
committed
feat(e2e): add streaming endpoint E2E tests
Add E2E tests for Server-Sent Events (SSE) streaming functionality: - StreamEvent creation and serialization - Batch result streaming - Heartbeat generation - QueryStreamer configuration - SSE response structure - Integration tests with realistic data - Error scenario tests Implements Issue #43 Phase 3.
1 parent f1f8ba9 commit 8911264

1 file changed

Lines changed: 384 additions & 0 deletions

File tree

tests/e2e/test_streaming_e2e.py

Lines changed: 384 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
"""
2+
Streaming endpoint E2E tests for Arctic Text2SQL Agent.
3+
4+
These tests verify the Server-Sent Events (SSE) streaming functionality
5+
for query generation and result streaming.
6+
"""
7+
8+
import asyncio
9+
import contextlib
10+
import json
11+
12+
import pytest
13+
14+
from app.streaming import (
15+
QueryStreamer,
16+
StreamEvent,
17+
StreamEventType,
18+
create_sse_response,
19+
heartbeat_generator,
20+
stream_results,
21+
)
22+
from tests.e2e.seed_data import ExpectedValues
23+
24+
# =============================================================================
25+
# StreamEvent Tests
26+
# =============================================================================
27+
28+
29+
@pytest.mark.e2e
30+
@pytest.mark.e2e_streaming
31+
class TestStreamEvent:
32+
"""Tests for StreamEvent creation and serialization."""
33+
34+
def test_event_creation(self) -> None:
35+
"""Test StreamEvent can be created with all event types."""
36+
for event_type in StreamEventType:
37+
event = StreamEvent(
38+
event_type=event_type,
39+
data={"test": "data"},
40+
)
41+
assert event.event_type == event_type
42+
assert event.data == {"test": "data"}
43+
44+
def test_event_to_dict(self) -> None:
45+
"""Test StreamEvent serializes to dict correctly."""
46+
event = StreamEvent(
47+
event_type=StreamEventType.QUERY_START,
48+
data={"query": "test query"},
49+
timestamp=1234567890.0,
50+
)
51+
52+
result = event.to_dict()
53+
54+
assert result["event"] == "query_start"
55+
assert result["data"] == {"query": "test query"}
56+
assert result["timestamp"] == 1234567890.0
57+
58+
def test_event_to_json(self) -> None:
59+
"""Test StreamEvent serializes to JSON correctly."""
60+
event = StreamEvent(
61+
event_type=StreamEventType.SQL_GENERATED,
62+
data={"sql": "SELECT * FROM users"},
63+
)
64+
65+
json_str = event.to_json()
66+
parsed = json.loads(json_str)
67+
68+
assert parsed["event"] == "sql_generated"
69+
assert parsed["data"]["sql"] == "SELECT * FROM users"
70+
assert "timestamp" in parsed
71+
72+
73+
# =============================================================================
74+
# Stream Results Tests
75+
# =============================================================================
76+
77+
78+
@pytest.mark.e2e
79+
@pytest.mark.e2e_streaming
80+
class TestStreamResults:
81+
"""Tests for result streaming functionality."""
82+
83+
@pytest.mark.asyncio
84+
async def test_stream_empty_results(self) -> None:
85+
"""Test streaming empty result set."""
86+
events: list[StreamEvent] = []
87+
88+
async for event in stream_results([], batch_size=10):
89+
events.append(event)
90+
91+
# Should emit complete event even for empty results
92+
assert len(events) == 1
93+
assert events[0].event_type == StreamEventType.RESULT_COMPLETE
94+
assert events[0].data["total_rows"] == 0
95+
96+
@pytest.mark.asyncio
97+
async def test_stream_single_batch(self) -> None:
98+
"""Test streaming results in single batch."""
99+
results = [{"id": i, "name": f"item_{i}"} for i in range(5)]
100+
events: list[StreamEvent] = []
101+
102+
async for event in stream_results(results, batch_size=10):
103+
events.append(event)
104+
105+
# One batch + complete event
106+
assert len(events) == 2
107+
108+
# First should be batch
109+
assert events[0].event_type == StreamEventType.RESULT_BATCH
110+
assert events[0].data["batch_start"] == 0
111+
assert events[0].data["batch_end"] == 5
112+
assert len(events[0].data["rows"]) == 5
113+
114+
# Last should be complete
115+
assert events[1].event_type == StreamEventType.RESULT_COMPLETE
116+
assert events[1].data["total_rows"] == 5
117+
118+
@pytest.mark.asyncio
119+
async def test_stream_multiple_batches(self) -> None:
120+
"""Test streaming results in multiple batches."""
121+
results = [{"id": i} for i in range(25)]
122+
events: list[StreamEvent] = []
123+
124+
async for event in stream_results(results, batch_size=10):
125+
events.append(event)
126+
127+
# 3 batches + complete event
128+
assert len(events) == 4
129+
130+
# Verify batch boundaries
131+
batch_events = [
132+
e for e in events if e.event_type == StreamEventType.RESULT_BATCH
133+
]
134+
assert len(batch_events) == 3
135+
136+
assert batch_events[0].data["batch_start"] == 0
137+
assert batch_events[0].data["batch_end"] == 10
138+
139+
assert batch_events[1].data["batch_start"] == 10
140+
assert batch_events[1].data["batch_end"] == 20
141+
142+
assert batch_events[2].data["batch_start"] == 20
143+
assert batch_events[2].data["batch_end"] == 25
144+
145+
@pytest.mark.asyncio
146+
async def test_stream_preserves_row_data(self) -> None:
147+
"""Test streaming preserves complete row data."""
148+
original_row = {
149+
"id": 1,
150+
"name": "Test Customer",
151+
"email": "test@example.com",
152+
"amount": 99.99,
153+
"nested": {"key": "value"},
154+
}
155+
results = [original_row]
156+
157+
events: list[StreamEvent] = []
158+
async for event in stream_results(results, batch_size=10):
159+
events.append(event)
160+
161+
batch_event = events[0]
162+
streamed_row = batch_event.data["rows"][0]
163+
164+
assert streamed_row == original_row
165+
166+
167+
# =============================================================================
168+
# Heartbeat Tests
169+
# =============================================================================
170+
171+
172+
@pytest.mark.e2e
173+
@pytest.mark.e2e_streaming
174+
class TestHeartbeat:
175+
"""Tests for heartbeat generation."""
176+
177+
@pytest.mark.asyncio
178+
async def test_heartbeat_generator(self) -> None:
179+
"""Test heartbeat generator produces events at interval."""
180+
events: list[StreamEvent] = []
181+
182+
async def collect_heartbeats() -> None:
183+
count = 0
184+
async for event in heartbeat_generator(interval=0.1):
185+
events.append(event)
186+
count += 1
187+
if count >= 3:
188+
break
189+
190+
# Run with timeout
191+
with contextlib.suppress(asyncio.TimeoutError):
192+
await asyncio.wait_for(collect_heartbeats(), timeout=1.0)
193+
194+
assert len(events) >= 2
195+
for event in events:
196+
assert event.event_type == StreamEventType.HEARTBEAT
197+
assert "timestamp" in event.data
198+
199+
200+
# =============================================================================
201+
# QueryStreamer Tests
202+
# =============================================================================
203+
204+
205+
@pytest.mark.e2e
206+
@pytest.mark.e2e_streaming
207+
class TestQueryStreamer:
208+
"""Tests for QueryStreamer class."""
209+
210+
def test_streamer_initialization(self) -> None:
211+
"""Test QueryStreamer initializes with correct defaults."""
212+
streamer = QueryStreamer()
213+
assert streamer._batch_size == 100
214+
assert streamer._heartbeat_interval == 15.0
215+
216+
def test_streamer_custom_config(self) -> None:
217+
"""Test QueryStreamer accepts custom configuration."""
218+
streamer = QueryStreamer(batch_size=50, heartbeat_interval=5.0)
219+
assert streamer._batch_size == 50
220+
assert streamer._heartbeat_interval == 5.0
221+
222+
223+
# =============================================================================
224+
# SSE Response Tests
225+
# =============================================================================
226+
227+
228+
@pytest.mark.e2e
229+
@pytest.mark.e2e_streaming
230+
class TestSSEResponse:
231+
"""Tests for SSE response creation."""
232+
233+
@pytest.mark.asyncio
234+
async def test_create_sse_response_structure(self) -> None:
235+
"""Test SSE response has correct structure."""
236+
237+
async def simple_generator():
238+
yield StreamEvent(
239+
event_type=StreamEventType.QUERY_START,
240+
data={"query": "test"},
241+
)
242+
yield StreamEvent(
243+
event_type=StreamEventType.QUERY_COMPLETE,
244+
data={"total_time_ms": 100},
245+
)
246+
247+
response = create_sse_response(simple_generator())
248+
249+
# EventSourceResponse should be created
250+
assert response is not None
251+
assert hasattr(response, "body_iterator")
252+
253+
254+
# =============================================================================
255+
# Integration Tests with Real Data
256+
# =============================================================================
257+
258+
259+
@pytest.mark.e2e
260+
@pytest.mark.e2e_streaming
261+
class TestStreamingIntegration:
262+
"""Integration tests for streaming with realistic data."""
263+
264+
@pytest.mark.asyncio
265+
async def test_stream_customer_results(self) -> None:
266+
"""Test streaming customer-like result data."""
267+
# Simulate customer query results
268+
customers = [
269+
{
270+
"id": i,
271+
"name": f"Customer {i}",
272+
"email": f"customer{i}@example.com",
273+
"state": ["California", "New York", "Texas"][i % 3],
274+
}
275+
for i in range(ExpectedValues.TOTAL_CUSTOMERS)
276+
]
277+
278+
events: list[StreamEvent] = []
279+
async for event in stream_results(customers, batch_size=5):
280+
events.append(event)
281+
282+
# Should have batches + complete
283+
batch_events = [
284+
e for e in events if e.event_type == StreamEventType.RESULT_BATCH
285+
]
286+
complete_event = next(
287+
(e for e in events if e.event_type == StreamEventType.RESULT_COMPLETE),
288+
None,
289+
)
290+
291+
assert len(batch_events) == 2 # 10 customers / 5 batch_size = 2 batches
292+
assert complete_event is not None
293+
assert complete_event.data["total_rows"] == ExpectedValues.TOTAL_CUSTOMERS
294+
295+
@pytest.mark.asyncio
296+
async def test_stream_order_results(self) -> None:
297+
"""Test streaming order-like result data."""
298+
# Simulate order query results
299+
orders = [
300+
{
301+
"id": i,
302+
"customer_id": (i % 10) + 1,
303+
"amount": float(100 + i * 10),
304+
"status": ["pending", "completed", "shipped"][i % 3],
305+
}
306+
for i in range(ExpectedValues.TOTAL_ORDERS)
307+
]
308+
309+
events: list[StreamEvent] = []
310+
async for event in stream_results(orders, batch_size=10):
311+
events.append(event)
312+
313+
complete_event = next(
314+
(e for e in events if e.event_type == StreamEventType.RESULT_COMPLETE),
315+
None,
316+
)
317+
318+
assert complete_event is not None
319+
assert complete_event.data["total_rows"] == ExpectedValues.TOTAL_ORDERS
320+
321+
@pytest.mark.asyncio
322+
async def test_stream_aggregation_results(self) -> None:
323+
"""Test streaming aggregation results."""
324+
# Simulate GROUP BY results
325+
aggregated = [
326+
{"status": "completed", "count": 8, "total_amount": 2500.00},
327+
{"status": "pending", "count": 4, "total_amount": 1200.00},
328+
{"status": "shipped", "count": 3, "total_amount": 800.00},
329+
]
330+
331+
events: list[StreamEvent] = []
332+
async for event in stream_results(aggregated, batch_size=10):
333+
events.append(event)
334+
335+
# Small result set should be single batch
336+
assert len(events) == 2 # 1 batch + complete
337+
338+
batch_event = events[0]
339+
assert batch_event.event_type == StreamEventType.RESULT_BATCH
340+
assert len(batch_event.data["rows"]) == 3
341+
342+
343+
# =============================================================================
344+
# Error Scenario Tests
345+
# =============================================================================
346+
347+
348+
@pytest.mark.e2e
349+
@pytest.mark.e2e_streaming
350+
class TestStreamingErrors:
351+
"""Tests for streaming error scenarios."""
352+
353+
def test_query_error_event(self) -> None:
354+
"""Test query error event structure."""
355+
event = StreamEvent(
356+
event_type=StreamEventType.QUERY_ERROR,
357+
data={
358+
"error": "Connection timeout",
359+
"error_type": "DatabaseConnectionError",
360+
"stage": "execution",
361+
},
362+
)
363+
364+
result = event.to_dict()
365+
366+
assert result["event"] == "query_error"
367+
assert result["data"]["error"] == "Connection timeout"
368+
assert result["data"]["error_type"] == "DatabaseConnectionError"
369+
370+
def test_all_event_types_serializable(self) -> None:
371+
"""Test all event types can be serialized to JSON."""
372+
for event_type in StreamEventType:
373+
event = StreamEvent(
374+
event_type=event_type,
375+
data={"test_key": "test_value", "number": 123, "nested": {"a": 1}},
376+
)
377+
378+
# Should not raise
379+
json_str = event.to_json()
380+
parsed = json.loads(json_str)
381+
382+
assert parsed["event"] == event_type.value
383+
assert "data" in parsed
384+
assert "timestamp" in parsed

0 commit comments

Comments
 (0)