-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmanaged_model.py
More file actions
121 lines (95 loc) · 4.02 KB
/
managed_model.py
File metadata and controls
121 lines (95 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import asyncio
from typing import List
from ldai import log
from ldai.models import AICompletionConfig, LDMessage
from ldai.providers.runner import Runner
from ldai.providers.types import JudgeResult, ManagedResult, RunnerResult
from ldai.tracker import LDAIConfigTracker
class ManagedModel:
    """
    LaunchDarkly managed wrapper for AI model invocations.

    Holds a Runner. Handles conversation management, judge evaluation
    dispatch, and tracking automatically via ``create_tracker()``.

    Obtain an instance via ``LDAIClient.create_model()``.
    """

    def __init__(
        self,
        ai_config: AICompletionConfig,
        model_runner: Runner,
    ):
        """
        :param ai_config: Completion config supplying the config messages,
            the tracker factory (``create_tracker()``) and the judge
            ``evaluator``.
        :param model_runner: Runner that performs the actual model call.
        """
        self._ai_config = ai_config
        self._model_runner = model_runner
        # Conversation history (user/assistant turns). Config messages are
        # never stored here; they are prepended on every call instead.
        self._messages: List[LDMessage] = []
        # Strong references to in-flight judge-tracking tasks. The event loop
        # only keeps weak references to tasks, so a task whose handle the
        # caller discards could otherwise be garbage-collected before it has
        # finished tracking the judge results.
        self._judge_tasks: "set[asyncio.Task]" = set()

    async def run(self, prompt: str) -> ManagedResult:
        """
        Run the model with a prompt string.

        Appends the prompt to the conversation history as a user message,
        prepends the config's messages (``ai_config.messages``), delegates to
        the runner while tracking metrics, and appends the response to the
        history as an assistant message.

        If the tracked runner call raises, the user message remains in the
        history with no matching assistant message.

        :param prompt: The user prompt to send to the model
        :return: ManagedResult containing the model's response, metric summary,
            and a task that resolves to the judge evaluation results
        """
        tracker = self._ai_config.create_tracker()
        self._messages.append(LDMessage(role='user', content=prompt))
        # Snapshot of config messages + history, taken before the await so the
        # runner sees exactly this turn's context.
        # NOTE(review): concurrent run() calls on one instance would interleave
        # entries in self._messages across the await below.
        all_messages = self.get_messages(include_config_messages=True)

        result: RunnerResult = await tracker.track_metrics_of_async(
            lambda r: r.metrics,
            lambda: self._model_runner.run(all_messages),
        )

        assistant_message = LDMessage(role='assistant', content=result.content)
        # Judge input is the stored history (config messages excluded) up to
        # and including this prompt; the response is passed separately.
        input_text = '\r\n'.join(m.content for m in self._messages)
        evaluations_task = self._track_judge_results(tracker, input_text, result.content)
        self._messages.append(assistant_message)

        return ManagedResult(
            content=result.content,
            metrics=tracker.get_summary(),
            raw=result.raw,
            parsed=result.parsed,
            evaluations=evaluations_task,
        )

    def _track_judge_results(
        self,
        tracker: LDAIConfigTracker,
        input_text: str,
        output_text: str,
    ) -> asyncio.Task[List[JudgeResult]]:
        """
        Start judge evaluation of one exchange and track its results.

        ``run()`` does not await the returned task. Successful results are
        reported through ``tracker.track_judge_result``. Failed results and
        tracking errors are logged as warnings rather than raised. An
        exception raised by the evaluator itself is not caught; it surfaces
        when the returned task is awaited.

        :param tracker: Tracker created for the ``run()`` call being evaluated.
        :param input_text: Judge input text (``run()`` passes the joined
            conversation history).
        :param output_text: The model's response content.
        :return: Task resolving to the full list of judge results, successful
            and failed alike.
        """
        evaluator_task = self._ai_config.evaluator.evaluate(input_text, output_text)

        async def _run_and_track(eval_task: asyncio.Task) -> List[JudgeResult]:
            results = await eval_task
            for r in results:
                if not r.success:
                    log.warning("Judge evaluation failed: %s", r.error_message)
                    continue
                try:
                    tracker.track_judge_result(r)
                except Exception as exc:
                    # Best-effort: a tracking failure must not discard the
                    # judge results returned to the caller.
                    log.warning("Failed to track judge result: %s", exc)
            return results

        task = asyncio.create_task(_run_and_track(evaluator_task))
        # Keep the task alive until it completes, even if the caller drops it.
        self._judge_tasks.add(task)
        task.add_done_callback(self._judge_tasks.discard)
        return task

    def get_messages(self, include_config_messages: bool = False) -> List[LDMessage]:
        """
        Get all messages in the conversation history.

        A new list is returned in both cases, so mutating it does not change
        the stored history (the message objects themselves are shared).

        :param include_config_messages: When True, prepends the config's
            messages (``ai_config.messages``) to the history.
        :return: List of conversation messages.
        """
        if include_config_messages:
            return (self._ai_config.messages or []) + self._messages
        return list(self._messages)

    def append_messages(self, messages: List[LDMessage]) -> None:
        """
        Append messages to the conversation history without invoking the model.

        :param messages: Messages to append, in order.
        """
        self._messages.extend(messages)

    def get_model_runner(self) -> Runner:
        """
        Return the underlying runner for advanced use.

        :return: The Runner instance.
        """
        return self._model_runner

    def get_config(self) -> AICompletionConfig:
        """Return the AI completion config."""
        return self._ai_config