|
33 | 33 | MissingLLMConfigs, |
34 | 34 | ) |
35 | 35 | from notdiamond.llms.config import LLMConfig |
| 36 | +from notdiamond.llms.providers import is_o1_model |
36 | 37 | from notdiamond.llms.request import ( |
37 | 38 | amodel_select, |
38 | 39 | create_preference_id, |
39 | 40 | model_select, |
40 | 41 | report_latency, |
41 | 42 | ) |
42 | 43 | from notdiamond.metrics.metric import Metric |
43 | | -from notdiamond.prompts import _curly_escape, inject_system_prompt |
| 44 | +from notdiamond.prompts import ( |
| 45 | + _curly_escape, |
| 46 | + inject_system_prompt, |
| 47 | + o1_system_prompt_translate, |
| 48 | +) |
44 | 49 | from notdiamond.types import NDApiKeyValidator |
45 | 50 |
|
46 | 51 | LOGGER = logging.getLogger(__name__) |
@@ -883,11 +888,13 @@ def invoke( |
883 | 888 | messages, best_llm.system_prompt |
884 | 889 | ) |
885 | 890 |
|
| 891 | + messages = o1_system_prompt_translate(messages, best_llm) |
| 892 | + |
886 | 893 | self.call_callbacks("on_model_select", best_llm, best_llm.model) |
887 | 894 |
|
888 | 895 | llm = self._llm_from_config(best_llm, callbacks=self.callbacks) |
889 | 896 |
|
890 | | - if self.tools: |
| 897 | + if self.tools and not is_o1_model(best_llm): |
891 | 898 | llm = llm.bind_tools(self.tools) |
892 | 899 |
|
893 | 900 | if response_model is not None: |
@@ -1082,11 +1089,13 @@ async def ainvoke( |
1082 | 1089 | messages, best_llm.system_prompt |
1083 | 1090 | ) |
1084 | 1091 |
|
| 1092 | + messages = o1_system_prompt_translate(messages, best_llm) |
| 1093 | + |
1085 | 1094 | self.call_callbacks("on_model_select", best_llm, best_llm.model) |
1086 | 1095 |
|
1087 | 1096 | llm = self._llm_from_config(best_llm, callbacks=self.callbacks) |
1088 | 1097 |
|
1089 | | - if self.tools: |
| 1098 | + if self.tools and not is_o1_model(best_llm): |
1090 | 1099 | llm = llm.bind_tools(self.tools) |
1091 | 1100 |
|
1092 | 1101 | if response_model is not None: |
@@ -1515,6 +1524,9 @@ def _llm_from_config( |
1515 | 1524 | "ChatOpenAI", |
1516 | 1525 | provider.provider, |
1517 | 1526 | ) |
| 1527 | + if is_o1_model(provider): |
| 1528 | + passed_kwargs["temperature"] = 1.0 |
| 1529 | + |
1518 | 1530 | return ChatOpenAI( |
1519 | 1531 | openai_api_key=provider.api_key, |
1520 | 1532 | model_name=provider.model, |
|
0 commit comments