Skip to content

Commit 49bbb87

Browse files
committed
.
1 parent 880ecbe commit 49bbb87

3 files changed

Lines changed: 86 additions & 120 deletions

File tree

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import random
4+
from typing import TYPE_CHECKING
55

66
import aiofiles
7+
from langchain.chat_models import ChatOpenAI
8+
from langchain.schema import HumanMessage, SystemMessage
79

810
from .base_manager import BaseManager
911

12+
if TYPE_CHECKING:
13+
from langchain.schema import BaseMessage
14+
1015

1116
class LLMManager(BaseManager):
1217
"""A manager for handling interactions with a language model."""
1318

14-
def __init__(self) -> None:
19+
def __init__(self, model_name: str = "gpt-4") -> None:
1520
super().__init__()
21+
self.llm = ChatOpenAI(model_name=model_name)
1622
self._diagnoses_cache: dict[str, str] = {}
17-
self._chat_history: list[dict[str, str]] = []
23+
self._chat_history: list[BaseMessage] = []
1824

1925
async def _manage(self) -> None:
2026
"""The main loop for the manager."""
@@ -47,31 +53,24 @@ async def diagnose_failed_job(self, job_id: str) -> str:
4753
log_content = await self._read_log_file(log_path)
4854
if log_content == "Log file not found.":
4955
return log_content
50-
# In a real implementation, this would be an API call to an LLM
51-
diagnosis = await self._simulate_llm_call(
52-
f"Analyze the following log and determine the cause of failure:\n\n{log_content}",
53-
)
56+
57+
messages: list[BaseMessage] = [
58+
SystemMessage(
59+
content="You are a helpful assistant that analyzes job failure logs.",
60+
),
61+
HumanMessage(
62+
content=f"Analyze the following log and determine the cause of failure:\n\n{log_content}",
63+
),
64+
]
65+
response = await self.llm.agenerate([messages])
66+
diagnosis = response.generations[0][0].text
5467
self._diagnoses_cache[job_id] = diagnosis
5568
return diagnosis
5669

5770
async def chat(self, message: str) -> str:
5871
"""Handles a chat message and returns a response."""
59-
self._chat_history.append({"role": "user", "content": message})
60-
# In a real implementation, this would be an API call to an LLM
61-
response = await self._simulate_llm_call(str(self._chat_history))
62-
self._chat_history.append({"role": "assistant", "content": response})
63-
return response
64-
65-
async def _simulate_llm_call(self, prompt: str) -> str: # noqa: ARG002
66-
"""Simulates a call to a language model."""
67-
await asyncio.sleep(
68-
random.uniform(0.1, 0.5), # noqa: S311
69-
) # Simulate network latency
70-
responses = [
71-
"It seems like there was a `FileNotFoundError`. Check if the input files are correctly specified.",
72-
"The job failed due to a `MemoryError`. Try requesting more memory for your job.",
73-
"I see a `ValueError` in the logs. It seems like an invalid argument was passed to a function.",
74-
"The simulation diverged. You might want to adjust the simulation parameters.",
75-
"I'm not sure what went wrong. Could you provide more details?",
76-
]
77-
return random.choice(responses) # noqa: S311
72+
self._chat_history.append(HumanMessage(content=message))
73+
response = await self.llm.agenerate([self._chat_history])
74+
result = response.generations[0][0].text
75+
self._chat_history.append(SystemMessage(content=result))
76+
return result

adaptive_scheduler/widgets.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -959,9 +959,12 @@ async def on_submit(sender: ipyw.Text) -> None:
959959
sender.value = ""
960960
if run_manager.llm_manager is None:
961961
return
962-
response = await run_manager.llm_manager.chat(message)
963-
chat_history.value += f"You: {message}\n"
964-
chat_history.value += f"LLM: {response}\n"
962+
try:
963+
response = await run_manager.llm_manager.chat(message)
964+
chat_history.value += f"You: {message}\n"
965+
chat_history.value += f"LLM: {response}\n"
966+
except Exception as e: # noqa: BLE001
967+
chat_history.value += f"Error: {e}\n"
965968

966969
def on_submit_wrapper(sender: ipyw.Text) -> None:
967970
task = asyncio.create_task(on_submit(sender))
@@ -986,8 +989,11 @@ async def on_failed_job_change(change: dict[str, Any]) -> None:
986989
job_id = change["new"]
987990
if run_manager.llm_manager is None:
988991
return
989-
diagnosis = await run_manager.llm_manager.diagnose_failed_job(job_id)
990-
chat_history.value = f"Diagnosis for job {job_id}:\n{diagnosis}\n"
992+
try:
993+
diagnosis = await run_manager.llm_manager.diagnose_failed_job(job_id)
994+
chat_history.value = f"Diagnosis for job {job_id}:\n{diagnosis}\n"
995+
except Exception as e: # noqa: BLE001
996+
chat_history.value = f"Error: {e}\n"
991997

992998
def on_failed_job_change_wrapper(change: dict[str, Any]) -> None:
993999
task = asyncio.create_task(on_failed_job_change(change))

tests/test_llm_manager.py

Lines changed: 50 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import asyncio
65
from typing import TYPE_CHECKING
76
from unittest.mock import AsyncMock, MagicMock, patch
87

@@ -17,35 +16,54 @@
1716

1817

1918
@pytest.mark.asyncio
20-
async def test_diagnose_failed_job() -> None:
19+
@patch("adaptive_scheduler._server_support.llm_manager.ChatOpenAI")
20+
async def test_diagnose_failed_job(mock_chat_openai: MagicMock) -> None:
2121
"""Test that the diagnose_failed_job method returns a diagnosis."""
22+
mock_llm = mock_chat_openai.return_value
23+
mock_llm.agenerate = AsyncMock(
24+
return_value=MagicMock(
25+
generations=[[MagicMock(text="diagnosis")]],
26+
),
27+
)
28+
2229
llm_manager = LLMManager()
2330
job_id = "test_job"
24-
with patch("adaptive_scheduler._server_support.llm_manager.aiofiles.open") as mock_open:
31+
with patch(
32+
"adaptive_scheduler._server_support.llm_manager.aiofiles.open",
33+
) as mock_open:
2534
mock_open.return_value.__aenter__.return_value.read.return_value = (
2635
"This is a log file with an error."
2736
)
2837
diagnosis = await llm_manager.diagnose_failed_job(job_id)
29-
assert isinstance(diagnosis, str)
38+
assert diagnosis == "diagnosis"
3039

3140

3241
@pytest.mark.asyncio
33-
async def test_chat() -> None:
42+
@patch("adaptive_scheduler._server_support.llm_manager.ChatOpenAI")
43+
async def test_chat(mock_chat_openai: MagicMock) -> None:
3444
"""Test that the chat method returns a response."""
45+
mock_llm = mock_chat_openai.return_value
46+
mock_llm.agenerate = AsyncMock(
47+
return_value=MagicMock(
48+
generations=[[MagicMock(text="response")]],
49+
),
50+
)
3551
llm_manager = LLMManager()
3652
message = "Hello, world!"
3753
response = await llm_manager.chat(message)
38-
assert isinstance(response, str)
54+
assert response == "response"
3955

4056

4157
@pytest.fixture
42-
def llm_manager() -> LLMManager:
43-
"""An LLMManager instance."""
58+
@patch("adaptive_scheduler._server_support.llm_manager.ChatOpenAI")
59+
def llm_manager(mock_chat_openai: MagicMock) -> LLMManager: # noqa: ARG001
60+
"""An LLMManager instance with a mocked ChatOpenAI."""
4461
return LLMManager()
4562

4663

4764
@pytest.fixture
48-
def run_manager() -> RunManager:
65+
@patch("adaptive_scheduler._server_support.llm_manager.ChatOpenAI")
66+
def run_manager(mock_chat_openai: MagicMock) -> RunManager: # noqa: ARG001
4967
"""A RunManager instance with a mocked scheduler."""
5068
scheduler = MagicMock()
5169
learners = [MagicMock()]
@@ -103,25 +121,26 @@ async def test_job_manager_diagnoses_failed_job_async(
103121
async def test_llm_manager_cache(llm_manager: LLMManager) -> None:
104122
"""Test that the LLMManager caches diagnoses."""
105123
job_id = "test_job"
106-
with (
107-
patch.object(
108-
llm_manager,
109-
"_read_log_file",
110-
return_value="log content",
124+
llm_manager.llm.agenerate = AsyncMock(
125+
return_value=MagicMock(
126+
generations=[[MagicMock(text="diagnosis")]],
111127
),
112-
patch.object(
113-
llm_manager,
114-
"_simulate_llm_call",
115-
return_value="diagnosis",
116-
) as mock_llm_call,
128+
)
129+
with patch.object(
130+
llm_manager,
131+
"_read_log_file",
132+
return_value="log content",
117133
):
118134
await llm_manager.diagnose_failed_job(job_id)
119135
await llm_manager.diagnose_failed_job(job_id)
120-
mock_llm_call.assert_called_once()
136+
llm_manager.llm.agenerate.assert_called_once()
121137

122138

123139
@pytest.mark.asyncio
124-
async def test_diagnose_failed_job_file_not_found() -> None:
140+
@patch("adaptive_scheduler._server_support.llm_manager.ChatOpenAI")
141+
async def test_diagnose_failed_job_file_not_found(
142+
mock_chat_openai: MagicMock, # noqa: ARG001
143+
) -> None:
125144
"""Test that the diagnose_failed_job method handles a missing log file."""
126145
llm_manager = LLMManager()
127146
job_id = "test_job"
@@ -136,13 +155,22 @@ async def test_diagnose_failed_job_file_not_found() -> None:
136155
@pytest.mark.asyncio
137156
async def test_chat_history(llm_manager: LLMManager) -> None:
138157
"""Test that the chat history is maintained."""
158+
llm_manager.llm.agenerate = AsyncMock(
159+
return_value=MagicMock(
160+
generations=[[MagicMock(text="response")]],
161+
),
162+
)
139163
await llm_manager.chat("Hello")
140164
await llm_manager.chat("How are you?")
141165
assert len(llm_manager._chat_history) == 4
142166

143167

144168
@pytest.mark.asyncio
145-
async def test_read_log_file(tmp_path: Path) -> None:
169+
@patch("adaptive_scheduler._server_support.llm_manager.ChatOpenAI")
170+
async def test_read_log_file(
171+
mock_chat_openai: MagicMock, # noqa: ARG001
172+
tmp_path: Path,
173+
) -> None:
146174
"""Test that the _read_log_file method reads a file asynchronously."""
147175
llm_manager = LLMManager()
148176
log_content = "This is a test log file."
@@ -151,70 +179,3 @@ async def test_read_log_file(tmp_path: Path) -> None:
151179

152180
read_content = await llm_manager._read_log_file(str(log_file))
153181
assert read_content == log_content
154-
155-
156-
@pytest.mark.asyncio
157-
async def test_chat_widget_callbacks(run_manager: RunManager) -> None:
158-
"""Test the async callbacks of the chat_widget."""
159-
from adaptive_scheduler.widgets import chat_widget
160-
161-
# Mock the chat_widget dependencies
162-
with (
163-
patch("ipywidgets.Text") as mock_text,
164-
patch("ipywidgets.Textarea") as mock_textarea,
165-
patch("ipywidgets.Dropdown") as mock_dropdown,
166-
patch("ipywidgets.VBox"),
167-
patch("adaptive_scheduler.widgets._add_title"),
168-
):
169-
# Create instances of the mocked widgets
170-
text_input = mock_text.return_value
171-
chat_history = mock_textarea.return_value
172-
failed_job_dropdown = mock_dropdown.return_value
173-
174-
# Call the widget function to set up the callbacks
175-
chat_widget(run_manager)
176-
177-
# --- Test on_submit ---
178-
# Get the on_submit wrapper from the mock
179-
on_submit_wrapper = text_input.on_submit.call_args[0][0]
180-
181-
# Get a reference to the value mock before it's changed
182-
chat_history_value_mock = chat_history.value
183-
chat_history_value_mock.__iadd__.return_value = chat_history_value_mock
184-
185-
# Mock the chat method
186-
with patch.object(
187-
run_manager.llm_manager,
188-
"chat",
189-
return_value="Test response",
190-
) as mock_chat:
191-
# Simulate the submission
192-
text_input.value = "Test message"
193-
on_submit_wrapper(text_input)
194-
await asyncio.sleep(0) # allow the task to run
195-
196-
# Assertions
197-
mock_chat.assert_called_once_with("Test message")
198-
iadd_calls = chat_history_value_mock.__iadd__.call_args_list
199-
assert "You: Test message" in iadd_calls[0].args[0]
200-
assert "LLM: Test response" in iadd_calls[1].args[0]
201-
202-
# --- Test on_failed_job_change ---
203-
# Get the on_failed_job_change wrapper from the mock
204-
on_failed_job_change_wrapper = failed_job_dropdown.observe.call_args[0][0]
205-
206-
# Mock the diagnose_failed_job method
207-
with patch.object(
208-
run_manager.llm_manager,
209-
"diagnose_failed_job",
210-
return_value="Test diagnosis",
211-
) as mock_diagnose:
212-
# Simulate the dropdown change
213-
change = {"new": "test_job_id"}
214-
on_failed_job_change_wrapper(change)
215-
await asyncio.sleep(0) # allow the task to run
216-
217-
# Assertions
218-
mock_diagnose.assert_called_once_with("test_job_id")
219-
assert "Diagnosis for job test_job_id" in chat_history.value
220-
assert "Test diagnosis" in chat_history.value

0 commit comments

Comments
 (0)