Skip to content
This repository was archived by the owner on Dec 11, 2025. It is now read-only.

Commit af54ff7

Browse files
committed
Moving is_o1_model to providers.py
1 parent 8bbec7a commit af54ff7

3 files changed

Lines changed: 15 additions & 17 deletions

File tree

notdiamond/llms/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MissingLLMConfigs,
3434
)
3535
from notdiamond.llms.config import LLMConfig
36+
from notdiamond.llms.providers import is_o1_model
3637
from notdiamond.llms.request import (
3738
amodel_select,
3839
create_preference_id,
@@ -42,7 +43,6 @@
4243
from notdiamond.metrics.metric import Metric
4344
from notdiamond.prompts import (
4445
_curly_escape,
45-
_is_o1_model,
4646
inject_system_prompt,
4747
o1_system_prompt_translate,
4848
)
@@ -894,7 +894,7 @@ def invoke(
894894

895895
llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
896896

897-
if self.tools and not _is_o1_model(best_llm):
897+
if self.tools and not is_o1_model(best_llm):
898898
llm = llm.bind_tools(self.tools)
899899

900900
if response_model is not None:
@@ -1095,7 +1095,7 @@ async def ainvoke(
10951095

10961096
llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
10971097

1098-
if self.tools and not _is_o1_model(best_llm):
1098+
if self.tools and not is_o1_model(best_llm):
10991099
llm = llm.bind_tools(self.tools)
11001100

11011101
if response_model is not None:
@@ -1524,7 +1524,7 @@ def _llm_from_config(
15241524
"ChatOpenAI",
15251525
provider.provider,
15261526
)
1527-
if _is_o1_model(provider):
1527+
if is_o1_model(provider):
15281528
passed_kwargs["temperature"] = 1.0
15291529

15301530
return ChatOpenAI(

notdiamond/llms/providers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,12 @@ class NDLLMProviders(Enum):
175175

176176
def __new__(cls, provider, model):
177177
return LLMConfig(provider=provider, model=model)
178+
179+
180+
def is_o1_model(llm: LLMConfig):
181+
return llm in (
182+
NDLLMProviders.O1_PREVIEW,
183+
NDLLMProviders.O1_PREVIEW_2024_09_12,
184+
NDLLMProviders.O1_MINI,
185+
NDLLMProviders.O1_MINI_2024_09_12,
186+
)

notdiamond/prompts.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Dict, List
44

55
from notdiamond.llms.config import LLMConfig
6-
from notdiamond.llms.providers import NDLLMProviders
6+
from notdiamond.llms.providers import is_o1_model
77

88
LOGGER = logging.getLogger(__name__)
99
LOGGER.setLevel(logging.INFO)
@@ -37,21 +37,10 @@ def _curly_escape(text: str) -> str:
3737
return re.sub(r"(?<!{){([a-zA-Z])}(?!})", r"{{\1}}", text)
3838

3939

40-
def _is_o1_model(llm: LLMConfig):
41-
if llm in (
42-
NDLLMProviders.O1_PREVIEW,
43-
NDLLMProviders.O1_PREVIEW_2024_09_12,
44-
NDLLMProviders.O1_MINI,
45-
NDLLMProviders.O1_MINI_2024_09_12,
46-
):
47-
return True
48-
return False
49-
50-
5140
def o1_system_prompt_translate(
5241
messages: List[Dict[str, str]], llm: LLMConfig
5342
) -> List[Dict[str, str]]:
54-
if _is_o1_model(llm):
43+
if is_o1_model(llm):
5544
translated_messages = []
5645
for msg in messages:
5746
if msg["role"] == "system":

0 commit comments

Comments
 (0)