-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathchat_model_factory.py
More file actions
146 lines (131 loc) · 5.42 KB
/
chat_model_factory.py
File metadata and controls
146 lines (131 loc) · 5.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Chat model factory with legacy/new implementation switching.
The ``EnabledNewLlmClients`` feature flag is sourced from ``uipath_agents`` and
passed through as the ``use_new_llm_clients`` argument of :func:`get_chat_model`.
- ``use_new_llm_clients=True`` (default): routes to the new
``uipath_langchain_client`` factory.
- ``use_new_llm_clients=False``: routes to the legacy in-repo clients under
:mod:`uipath_langchain.chat._legacy`, preserving behavior exactly as it was
before the ``uipath_langchain_client`` migration.
"""
from typing import Any, Final
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseChatModel
from uipath_langchain_client.base_client import UiPathBaseChatModel
from uipath_langchain_client.factory import get_chat_model as get_chat_model_factory
from uipath_langchain_client.settings import (
ApiFlavor,
RoutingMode,
UiPathBaseSettings,
VendorType,
)
# Sentinel distinguishing "caller did not pass this argument" from an explicit
# value (including None); arguments equal to this object are not forwarded to
# the underlying factory (see the `is not _UNSET` filter in get_chat_model).
_UNSET: Final[Any] = object()
# Default request timeout, in seconds.
DEFAULT_TIMEOUT_SECONDS: Final[float] = 300.0
# Default max output tokens — matches the historical UiPathRequestMixin default
# (see get_chat_model docstring).
DEFAULT_MAX_TOKENS: Final[int] = 1000
# Default sampling temperature (0.0 = most deterministic sampling).
DEFAULT_TEMPERATURE: Final[float] = 0.0
# Default number of retries for failed requests.
DEFAULT_MAX_RETRIES: Final[int] = 3
def get_chat_model(
    model: str,
    *,
    byo_connection_id: str | None = None,
    client_settings: UiPathBaseSettings | None = None,
    routing_mode: RoutingMode | str = RoutingMode.PASSTHROUGH,
    vendor_type: VendorType | str | None = None,
    api_flavor: ApiFlavor | str | None = None,
    custom_class: type[UiPathBaseChatModel] | None = None,
    temperature: float | None = DEFAULT_TEMPERATURE,
    max_tokens: int | None = DEFAULT_MAX_TOKENS,
    timeout: float | None = DEFAULT_TIMEOUT_SECONDS,
    max_retries: int | None = DEFAULT_MAX_RETRIES,
    callbacks: Callbacks = _UNSET,
    agenthub_config: str | None = None,
    use_new_llm_clients: bool = True,
    **kwargs: Any,
) -> BaseChatModel:
    """Build a configured chat model, choosing the new or legacy factory.

    Args:
        model: Model name, e.g. ``"gpt-4o"`` or ``"claude-3-sonnet"``.
        byo_connection_id: Optional Integration Service connection ID.
        client_settings: Optional override for the default
            ``uipath_langchain_client`` settings.
        routing_mode: ``PASSTHROUGH`` (vendor-specific) or ``NORMALIZED``.
        vendor_type: Restrict model lookup to one vendor; auto-detected if
            not given.
        api_flavor: Vendor-specific API flavor (OpenAI Responses, Bedrock
            Converse, ...); auto-detected if not given.
        custom_class: Explicit ``UiPathBaseChatModel`` subclass to use in
            place of the auto-detected one.
        temperature: Sampling temperature (default 0.0). ``None`` is
            forwarded as an explicit unset value.
        max_tokens: Output token cap (default 1000, the historical
            ``UiPathRequestMixin`` default). ``None`` is forwarded so the
            underlying client applies its own default or no limit.
        timeout: Request timeout in seconds (default 300).
        max_retries: Retry budget (default 3).
        callbacks: LangChain callback handlers or a callback manager to
            attach to the returned model. Only forwarded when the caller
            explicitly supplies it; the legacy factory ignores it.
        agenthub_config: AgentHub config header value. Mandatory on the
            legacy path; simply forwarded on the new path.
        use_new_llm_clients: When True (default), use the
            ``uipath_langchain_client`` factory; when False, use the legacy
            in-repo clients.
        **kwargs: Passed through to whichever factory is selected. The
            legacy factory accepts ``disable_streaming``; the new factory
            treats extras as model kwargs for the LangChain constructor.

    Returns:
        A configured ``BaseChatModel``.
    """
    if use_new_llm_clients:
        # Forward each tunable only when it is not the "unset" sentinel;
        # explicit None values still pass through.
        forwarded: dict[str, Any] = {}
        for name, value in (
            ("temperature", temperature),
            ("max_tokens", max_tokens),
            ("timeout", timeout),
            ("max_retries", max_retries),
            ("callbacks", callbacks),
        ):
            if value is not _UNSET:
                forwarded[name] = value
        return get_chat_model_factory(
            model,
            byo_connection_id=byo_connection_id,
            client_settings=client_settings,
            routing_mode=routing_mode,
            vendor_type=vendor_type,
            api_flavor=api_flavor,
            custom_class=custom_class,
            agenthub_config=agenthub_config,
            **forwarded,
            **kwargs,
        )
    # Legacy path: timeout/max_retries/callbacks are not part of the legacy
    # factory's interface and are not forwarded.
    return _legacy_chat_model(
        model,
        temperature=temperature,
        max_tokens=max_tokens,
        agenthub_config=agenthub_config,
        byo_connection_id=byo_connection_id,
        **kwargs,
    )
def _legacy_chat_model(
    model: str,
    *,
    temperature: float | None,
    max_tokens: int | None,
    agenthub_config: str | None,
    byo_connection_id: str | None,
    **kwargs: Any,
) -> BaseChatModel:
    """Instantiate a chat model through the legacy in-repo factory.

    Args:
        model: Model name to instantiate.
        temperature: Sampling temperature, or ``None``.
        max_tokens: Output token cap, or ``None``.
        agenthub_config: AgentHub config header value; must not be ``None``.
        byo_connection_id: Optional Integration Service connection ID.
        **kwargs: Extra options accepted by the legacy factory
            (e.g. ``disable_streaming``).

    Returns:
        A configured ``BaseChatModel`` from the legacy clients.

    Raises:
        ValueError: If ``agenthub_config`` is ``None``.
    """
    if agenthub_config is None:
        raise ValueError("agenthub_config is required when use_new_llm_clients=False")

    # Imported lazily so the legacy module is only loaded when this path runs.
    from uipath_langchain.chat._legacy.chat_model_factory import (
        get_chat_model as legacy_factory,
    )

    return legacy_factory(
        model,
        temperature,
        max_tokens,
        agenthub_config,
        byo_connection_id,
        **kwargs,
    )