"""Test script to verify async LLM calls are non-blocking."""

import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch

from stagehand.llm.client import LLMClient
from stagehand.llm.inference import observe, extract


async def simulate_slow_llm_response(
    delay: float = 1.0,
    content: str = '{"elements": []}',
) -> MagicMock:
    """Simulate a slow LLM API response.

    Args:
        delay: Seconds to suspend via ``asyncio.sleep`` before responding.
        content: Text placed in ``choices[0].message.content``. Defaults to
            an empty ``elements`` JSON payload.

    Returns:
        A ``MagicMock`` exposing ``usage.prompt_tokens`` (100),
        ``usage.completion_tokens`` (50) and a single-entry ``choices`` list
        whose ``message.content`` is ``content``.
    """
    # asyncio.sleep suspends only this coroutine, so other tasks on the same
    # event loop keep running while this "request" is pending.
    await asyncio.sleep(delay)
    return MagicMock(
        usage=MagicMock(prompt_tokens=100, completion_tokens=50),
        choices=[MagicMock(message=MagicMock(content=content))],
    )


async def test_parallel_execution():
    """Check that concurrent ``observe`` calls overlap instead of serializing.

    A fresh ``LLMClient`` has its ``create_response`` replaced by a stub that
    awaits a 1-second simulated response. Three ``observe`` calls are run
    through ``asyncio.gather`` and timed, then the same three calls are
    awaited one after another and timed again.

    Raises:
        AssertionError: If the gathered run does not return exactly three
            results, a result lacks ``"elements"``, or the gathered run is
            not under half the sequential duration.
    """
    print("\n 🧪 Testing parallel async execution...")

    # MagicMock creates info/debug/error child mocks on first access, so no
    # explicit attribute wiring is needed.
    mock_logger = MagicMock()

    llm_client = LLMClient(
        stagehand_logger=mock_logger,
        default_model="gpt-4o",
    )

    # Stub out the client's async entry point with a fixed 1-second delay.
    # Accepts any call shape so it does not depend on how observe invokes it.
    async def mock_create_response(*args, **kwargs):
        return await simulate_slow_llm_response(1.0)

    llm_client.create_response = mock_create_response

    requests = [
        ("Find button 1", "DOM content 1"),
        ("Find button 2", "DOM content 2"),
        ("Find button 3", "DOM content 3"),
    ]

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    results = await asyncio.gather(
        *(
            observe(instruction, dom, llm_client, logger=mock_logger)
            for instruction, dom in requests
        )
    )
    parallel_time = time.perf_counter() - start_time

    print(f"✅ Parallel execution of 3 calls took: {parallel_time:.2f} s")
    print(" Expected ~1s (running in parallel), not 3s (sequential)")

    # Verify results
    assert len(results) == 3
    for i, result in enumerate(results, 1):
        assert "elements" in result
        print(f" Result {i} : {result} ")

    # Sequential baseline: each call is awaited to completion before the
    # next one starts.
    print("\n 🧪 Testing sequential execution for comparison...")
    start_time = time.perf_counter()
    for instruction, dom in requests:
        await observe(instruction, dom, llm_client, logger=mock_logger)
    sequential_time = time.perf_counter() - start_time

    print(f"✅ Sequential execution of 3 calls took: {sequential_time:.2f} s")
    print(" Expected ~3s (running sequentially)")

    # Parallel should be significantly faster
    assert parallel_time < sequential_time * 0.5, "Parallel execution should be much faster than sequential"

    print("\n 🎉 Async implementation is working correctly!")
    print(f" Parallel speedup: {sequential_time / parallel_time:.2f} x faster")


async def test_real_llm_async():
    """Call ``LLMClient.create_response`` with ``litellm.acompletion`` stubbed.

    For the duration of the call, ``litellm.acompletion`` is replaced by a
    coroutine that sleeps 0.1 s and returns a canned mock response, so the
    client's own async code path is exercised against mock data.

    NOTE(review): the patch only affects lookups made through the ``litellm``
    module attribute; confirm ``LLMClient`` does not bind ``acompletion`` by
    name at import time, otherwise the stub is bypassed.

    Raises:
        AssertionError: If ``create_response`` returns ``None``.
    """
    print("\n 🧪 Testing with real LiteLLM (using mock responses)...")

    import litellm

    # Stand-in for litellm.acompletion; accepts any call shape.
    async def mock_acompletion(*args, **kwargs):
        await asyncio.sleep(0.1)  # Small delay to simulate API call
        return MagicMock(
            usage=MagicMock(prompt_tokens=100, completion_tokens=50),
            choices=[
                MagicMock(
                    message=MagicMock(
                        content='{"elements": [{"selector": "#test"}]}'
                    )
                )
            ],
        )

    # patch.object restores the original attribute when the block exits,
    # even if an assertion inside fails.
    with patch.object(litellm, "acompletion", new=mock_acompletion):
        mock_logger = MagicMock()

        llm_client = LLMClient(
            stagehand_logger=mock_logger,
            default_model="gpt-4o",
        )

        # Test that the actual async call works
        response = await llm_client.create_response(
            messages=[{"role": "user", "content": "test"}],
            model="gpt-4o",
        )

        assert response is not None
        print("✅ Real LiteLLM async call successful")
        print(f" Response: {response.choices[0].message.content} ")


async def main():
    """Run all verification coroutines in order and print a summary banner.

    Raises:
        Exception: Whatever the first failing test raised, re-raised after a
            failure line has been printed.
    """
    banner = "=" * 50
    print(banner)
    print("ASYNC IMPLEMENTATION VERIFICATION")
    print(banner)

    try:
        await test_parallel_execution()
        await test_real_llm_async()
    except Exception as e:
        # A bare ``assert`` yields an empty message, so include the exception
        # type to keep the failure line informative.
        print(f"\n ❌ Test failed: {type(e).__name__}: {e} ")
        raise
    else:
        print("\n " + banner)
        print("✅ ALL TESTS PASSED - ASYNC IS WORKING!")
        print(banner)


# Script entry point: run the async test suite on a new event loop.
if __name__ == "__main__":
    asyncio.run(main())