22
33from __future__ import annotations
44
5- import asyncio
65from typing import TYPE_CHECKING
76from unittest .mock import AsyncMock , MagicMock , patch
87
1716
1817
1918@pytest .mark .asyncio
20- async def test_diagnose_failed_job () -> None :
19+ @patch ("adaptive_scheduler._server_support.llm_manager.ChatOpenAI" )
20+ async def test_diagnose_failed_job (mock_chat_openai : MagicMock ) -> None :
2121 """Test that the diagnose_failed_job method returns a diagnosis."""
22+ mock_llm = mock_chat_openai .return_value
23+ mock_llm .agenerate = AsyncMock (
24+ return_value = MagicMock (
25+ generations = [[MagicMock (text = "diagnosis" )]],
26+ ),
27+ )
28+
2229 llm_manager = LLMManager ()
2330 job_id = "test_job"
24- with patch ("adaptive_scheduler._server_support.llm_manager.aiofiles.open" ) as mock_open :
31+ with patch (
32+ "adaptive_scheduler._server_support.llm_manager.aiofiles.open" ,
33+ ) as mock_open :
2534 mock_open .return_value .__aenter__ .return_value .read .return_value = (
2635 "This is a log file with an error."
2736 )
2837 diagnosis = await llm_manager .diagnose_failed_job (job_id )
29- assert isinstance ( diagnosis , str )
38+ assert diagnosis == "diagnosis"
3039
3140
3241@pytest .mark .asyncio
33- async def test_chat () -> None :
42+ @patch ("adaptive_scheduler._server_support.llm_manager.ChatOpenAI" )
43+ async def test_chat (mock_chat_openai : MagicMock ) -> None :
3444 """Test that the chat method returns a response."""
45+ mock_llm = mock_chat_openai .return_value
46+ mock_llm .agenerate = AsyncMock (
47+ return_value = MagicMock (
48+ generations = [[MagicMock (text = "response" )]],
49+ ),
50+ )
3551 llm_manager = LLMManager ()
3652 message = "Hello, world!"
3753 response = await llm_manager .chat (message )
38- assert isinstance ( response , str )
54+ assert response == "response"
3955
4056
4157@pytest .fixture
42- def llm_manager () -> LLMManager :
43- """An LLMManager instance."""
58+ @patch ("adaptive_scheduler._server_support.llm_manager.ChatOpenAI" )
59+ def llm_manager (mock_chat_openai : MagicMock ) -> LLMManager : # noqa: ARG001
60+ """An LLMManager instance with a mocked ChatOpenAI."""
4461 return LLMManager ()
4562
4663
4764@pytest .fixture
48- def run_manager () -> RunManager :
65+ @patch ("adaptive_scheduler._server_support.llm_manager.ChatOpenAI" )
66+ def run_manager (mock_chat_openai : MagicMock ) -> RunManager : # noqa: ARG001
4967 """A RunManager instance with a mocked scheduler."""
5068 scheduler = MagicMock ()
5169 learners = [MagicMock ()]
@@ -103,25 +121,26 @@ async def test_job_manager_diagnoses_failed_job_async(
103121async def test_llm_manager_cache (llm_manager : LLMManager ) -> None :
104122 """Test that the LLMManager caches diagnoses."""
105123 job_id = "test_job"
106- with (
107- patch .object (
108- llm_manager ,
109- "_read_log_file" ,
110- return_value = "log content" ,
124+ llm_manager .llm .agenerate = AsyncMock (
125+ return_value = MagicMock (
126+ generations = [[MagicMock (text = "diagnosis" )]],
111127 ),
112- patch . object (
113- llm_manager ,
114- "_simulate_llm_call" ,
115- return_value = "diagnosis " ,
116- ) as mock_llm_call ,
128+ )
129+ with patch . object (
130+ llm_manager ,
131+ "_read_log_file " ,
132+ return_value = "log content" ,
117133 ):
118134 await llm_manager .diagnose_failed_job (job_id )
119135 await llm_manager .diagnose_failed_job (job_id )
120- mock_llm_call .assert_called_once ()
136+ llm_manager . llm . agenerate .assert_called_once ()
121137
122138
123139@pytest .mark .asyncio
124- async def test_diagnose_failed_job_file_not_found () -> None :
140+ @patch ("adaptive_scheduler._server_support.llm_manager.ChatOpenAI" )
141+ async def test_diagnose_failed_job_file_not_found (
142+ mock_chat_openai : MagicMock , # noqa: ARG001
143+ ) -> None :
125144 """Test that the diagnose_failed_job method handles a missing log file."""
126145 llm_manager = LLMManager ()
127146 job_id = "test_job"
@@ -136,13 +155,22 @@ async def test_diagnose_failed_job_file_not_found() -> None:
136155@pytest .mark .asyncio
137156async def test_chat_history (llm_manager : LLMManager ) -> None :
138157 """Test that the chat history is maintained."""
158+ llm_manager .llm .agenerate = AsyncMock (
159+ return_value = MagicMock (
160+ generations = [[MagicMock (text = "response" )]],
161+ ),
162+ )
139163 await llm_manager .chat ("Hello" )
140164 await llm_manager .chat ("How are you?" )
141165 assert len (llm_manager ._chat_history ) == 4
142166
143167
144168@pytest .mark .asyncio
145- async def test_read_log_file (tmp_path : Path ) -> None :
169+ @patch ("adaptive_scheduler._server_support.llm_manager.ChatOpenAI" )
170+ async def test_read_log_file (
171+ mock_chat_openai : MagicMock , # noqa: ARG001
172+ tmp_path : Path ,
173+ ) -> None :
146174 """Test that the _read_log_file method reads a file asynchronously."""
147175 llm_manager = LLMManager ()
148176 log_content = "This is a test log file."
@@ -151,70 +179,3 @@ async def test_read_log_file(tmp_path: Path) -> None:
151179
152180 read_content = await llm_manager ._read_log_file (str (log_file ))
153181 assert read_content == log_content
154-
155-
156- @pytest .mark .asyncio
157- async def test_chat_widget_callbacks (run_manager : RunManager ) -> None :
158- """Test the async callbacks of the chat_widget."""
159- from adaptive_scheduler .widgets import chat_widget
160-
161- # Mock the chat_widget dependencies
162- with (
163- patch ("ipywidgets.Text" ) as mock_text ,
164- patch ("ipywidgets.Textarea" ) as mock_textarea ,
165- patch ("ipywidgets.Dropdown" ) as mock_dropdown ,
166- patch ("ipywidgets.VBox" ),
167- patch ("adaptive_scheduler.widgets._add_title" ),
168- ):
169- # Create instances of the mocked widgets
170- text_input = mock_text .return_value
171- chat_history = mock_textarea .return_value
172- failed_job_dropdown = mock_dropdown .return_value
173-
174- # Call the widget function to set up the callbacks
175- chat_widget (run_manager )
176-
177- # --- Test on_submit ---
178- # Get the on_submit wrapper from the mock
179- on_submit_wrapper = text_input .on_submit .call_args [0 ][0 ]
180-
181- # Get a reference to the value mock before it's changed
182- chat_history_value_mock = chat_history .value
183- chat_history_value_mock .__iadd__ .return_value = chat_history_value_mock
184-
185- # Mock the chat method
186- with patch .object (
187- run_manager .llm_manager ,
188- "chat" ,
189- return_value = "Test response" ,
190- ) as mock_chat :
191- # Simulate the submission
192- text_input .value = "Test message"
193- on_submit_wrapper (text_input )
194- await asyncio .sleep (0 ) # allow the task to run
195-
196- # Assertions
197- mock_chat .assert_called_once_with ("Test message" )
198- iadd_calls = chat_history_value_mock .__iadd__ .call_args_list
199- assert "You: Test message" in iadd_calls [0 ].args [0 ]
200- assert "LLM: Test response" in iadd_calls [1 ].args [0 ]
201-
202- # --- Test on_failed_job_change ---
203- # Get the on_failed_job_change wrapper from the mock
204- on_failed_job_change_wrapper = failed_job_dropdown .observe .call_args [0 ][0 ]
205-
206- # Mock the diagnose_failed_job method
207- with patch .object (
208- run_manager .llm_manager ,
209- "diagnose_failed_job" ,
210- return_value = "Test diagnosis" ,
211- ) as mock_diagnose :
212- # Simulate the dropdown change
213- change = {"new" : "test_job_id" }
214- on_failed_job_change_wrapper (change )
215- await asyncio .sleep (0 ) # allow the task to run
216-
217- # Assertions
218- mock_diagnose .assert_called_once_with ("test_job_id" )
219- assert "Diagnosis for job test_job_id" in chat_history .value
220- assert "Test diagnosis" in chat_history .value
0 commit comments