|
1 | 1 | import asyncio |
2 | | -import warnings |
3 | 2 | from typing import List, Union |
4 | 3 |
|
5 | 4 | from ldai import log |
6 | 5 | from ldai.models import AICompletionConfig, LDMessage |
7 | 6 | from ldai.providers.model_runner import ModelRunner |
8 | 7 | from ldai.providers.runner import Runner |
9 | | -from ldai.providers.types import JudgeResult, ManagedResult, ModelResponse, RunnerResult |
| 8 | +from ldai.providers.types import JudgeResult, ManagedResult, RunnerResult |
10 | 9 | from ldai.tracker import LDAIConfigTracker |
11 | 10 |
|
12 | 11 |
|
@@ -48,86 +47,38 @@ async def run(self, prompt: str) -> ManagedResult: |
48 | 47 | config_messages = self._ai_config.messages or [] |
49 | 48 | all_messages = config_messages + self._messages |
50 | 49 |
|
51 | | - result: Union[RunnerResult, ModelResponse] = await tracker.track_metrics_of_async( |
| 50 | + result: RunnerResult = await tracker.track_metrics_of_async( |
52 | 51 | lambda r: r.metrics, |
53 | 52 | lambda: self._invoke_runner(all_messages), |
54 | 53 | ) |
55 | 54 |
|
56 | | - # Support both new RunnerResult and legacy ModelResponse |
57 | | - if isinstance(result, RunnerResult): |
58 | | - content = result.content |
59 | | - raw = result.raw |
60 | | - parsed = result.parsed |
61 | | - assistant_message = LDMessage(role='assistant', content=content) |
62 | | - else: |
63 | | - content = result.message.content |
64 | | - raw = getattr(result, 'raw', None) |
65 | | - parsed = getattr(result, 'parsed', None) |
66 | | - assistant_message = result.message |
| 55 | + assistant_message = LDMessage(role='assistant', content=result.content) |
67 | 56 |
|
68 | 57 | input_text = '\r\n'.join(m.content for m in self._messages) if self._messages else '' |
69 | 58 |
|
70 | | - evaluations_task = self._track_judge_results(tracker, input_text, content) |
| 59 | + evaluations_task = self._track_judge_results(tracker, input_text, result.content) |
71 | 60 |
|
72 | 61 | self._messages.append(assistant_message) |
73 | 62 |
|
74 | 63 | return ManagedResult( |
75 | | - content=content, |
| 64 | + content=result.content, |
76 | 65 | metrics=tracker.get_summary(), |
77 | | - raw=raw, |
78 | | - parsed=parsed, |
| 66 | + raw=result.raw, |
| 67 | + parsed=result.parsed, |
79 | 68 | evaluations=evaluations_task, |
80 | 69 | ) |
81 | 70 |
|
82 | | - async def _invoke_runner( |
83 | | - self, all_messages: List[LDMessage] |
84 | | - ) -> Union[RunnerResult, ModelResponse]: |
| 71 | + async def _invoke_runner(self, all_messages: List[LDMessage]) -> RunnerResult: |
85 | 72 | """ |
86 | 73 | Delegate to the runner. Supports both the new ``Runner`` protocol |
87 | 74 | (``run(messages) → RunnerResult``) and the legacy ``ModelRunner`` |
88 | | - (``invoke_model(messages) → ModelResponse``). |
| 75 | + (``invoke_model(messages) → RunnerResult``). |
89 | 76 | """ |
90 | 77 | if isinstance(self._model_runner, Runner): |
91 | 78 | return await self._model_runner.run(all_messages) |
92 | 79 | # Legacy ModelRunner path |
93 | 80 | return await self._model_runner.invoke_model(all_messages) # type: ignore[union-attr] |
94 | 81 |
|
95 | | - async def invoke(self, prompt: str) -> ModelResponse: |
96 | | - """ |
97 | | - Invoke the model with a prompt string. |
98 | | -
|
99 | | - .. deprecated:: |
100 | | - Use :meth:`run` instead. This method will be removed in a future |
101 | | - release once the migration to :class:`ManagedResult` is complete. |
102 | | -
|
103 | | - :param prompt: The user prompt to send to the model |
104 | | - :return: ModelResponse containing the model's response and metrics |
105 | | - """ |
106 | | - warnings.warn( |
107 | | - "ManagedModel.invoke() is deprecated. Use run() instead.", |
108 | | - DeprecationWarning, |
109 | | - stacklevel=2, |
110 | | - ) |
111 | | - tracker = self._ai_config.create_tracker() |
112 | | - |
113 | | - user_message = LDMessage(role='user', content=prompt) |
114 | | - self._messages.append(user_message) |
115 | | - |
116 | | - config_messages = self._ai_config.messages or [] |
117 | | - all_messages = config_messages + self._messages |
118 | | - |
119 | | - response: ModelResponse = await tracker.track_metrics_of_async( |
120 | | - lambda result: result.metrics, |
121 | | - lambda: self._model_runner.invoke_model(all_messages), # type: ignore[union-attr] |
122 | | - ) |
123 | | - |
124 | | - input_text = '\r\n'.join(m.content for m in self._messages) if self._messages else '' |
125 | | - output_text = response.message.content |
126 | | - response.evaluations = self._track_judge_results(tracker, input_text, output_text) |
127 | | - |
128 | | - self._messages.append(response.message) |
129 | | - return response |
130 | | - |
131 | 82 | def _track_judge_results( |
132 | 83 | self, |
133 | 84 | tracker: LDAIConfigTracker, |
|
0 commit comments