Skip to content
This repository was archived by the owner on Dec 11, 2025. It is now read-only.

Commit 3187a57

Browse files
authored
Merge pull request #35 from Not-Diamond/t7-fix-system-prompt-o1
o1 fixes
2 parents 8cad6ac + af54ff7 commit 3187a57

4 files changed

Lines changed: 60 additions & 4 deletions

File tree

notdiamond/llms/client.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,19 @@
3333
MissingLLMConfigs,
3434
)
3535
from notdiamond.llms.config import LLMConfig
36+
from notdiamond.llms.providers import is_o1_model
3637
from notdiamond.llms.request import (
3738
amodel_select,
3839
create_preference_id,
3940
model_select,
4041
report_latency,
4142
)
4243
from notdiamond.metrics.metric import Metric
43-
from notdiamond.prompts import _curly_escape, inject_system_prompt
44+
from notdiamond.prompts import (
45+
_curly_escape,
46+
inject_system_prompt,
47+
o1_system_prompt_translate,
48+
)
4449
from notdiamond.types import NDApiKeyValidator
4550

4651
LOGGER = logging.getLogger(__name__)
@@ -883,11 +888,13 @@ def invoke(
883888
messages, best_llm.system_prompt
884889
)
885890

891+
messages = o1_system_prompt_translate(messages, best_llm)
892+
886893
self.call_callbacks("on_model_select", best_llm, best_llm.model)
887894

888895
llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
889896

890-
if self.tools:
897+
if self.tools and not is_o1_model(best_llm):
891898
llm = llm.bind_tools(self.tools)
892899

893900
if response_model is not None:
@@ -1082,11 +1089,13 @@ async def ainvoke(
10821089
messages, best_llm.system_prompt
10831090
)
10841091

1092+
messages = o1_system_prompt_translate(messages, best_llm)
1093+
10851094
self.call_callbacks("on_model_select", best_llm, best_llm.model)
10861095

10871096
llm = self._llm_from_config(best_llm, callbacks=self.callbacks)
10881097

1089-
if self.tools:
1098+
if self.tools and not is_o1_model(best_llm):
10901099
llm = llm.bind_tools(self.tools)
10911100

10921101
if response_model is not None:
@@ -1515,6 +1524,9 @@ def _llm_from_config(
15151524
"ChatOpenAI",
15161525
provider.provider,
15171526
)
1527+
if is_o1_model(provider):
1528+
passed_kwargs["temperature"] = 1.0
1529+
15181530
return ChatOpenAI(
15191531
openai_api_key=provider.api_key,
15201532
model_name=provider.model,

notdiamond/llms/providers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,12 @@ class NDLLMProviders(Enum):
175175

176176
def __new__(cls, provider, model):
177177
return LLMConfig(provider=provider, model=model)
178+
179+
180+
def is_o1_model(llm: LLMConfig):
181+
return llm in (
182+
NDLLMProviders.O1_PREVIEW,
183+
NDLLMProviders.O1_PREVIEW_2024_09_12,
184+
NDLLMProviders.O1_MINI,
185+
NDLLMProviders.O1_MINI_2024_09_12,
186+
)

notdiamond/prompts.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
import re
21
import logging
2+
import re
33
from typing import Dict, List
44

5+
from notdiamond.llms.config import LLMConfig
6+
from notdiamond.llms.providers import is_o1_model
7+
58
LOGGER = logging.getLogger(__name__)
69
LOGGER.setLevel(logging.INFO)
710

@@ -32,3 +35,19 @@ def _curly_escape(text: str) -> str:
3235
This function will not escape double curly braces or non-alphabetic characters.
3336
"""
3437
return re.sub(r"(?<!{){([a-zA-Z])}(?!})", r"{{\1}}", text)
38+
39+
40+
def o1_system_prompt_translate(
41+
messages: List[Dict[str, str]], llm: LLMConfig
42+
) -> List[Dict[str, str]]:
43+
if is_o1_model(llm):
44+
translated_messages = []
45+
for msg in messages:
46+
if msg["role"] == "system":
47+
translated_messages.append(
48+
{"role": "user", "content": msg["content"]}
49+
)
50+
else:
51+
translated_messages.append(msg)
52+
return translated_messages
53+
return messages

tests/test_llm_calls/test_openai.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,19 @@ def test_response_model(self, response_model, provider):
7474
assert isinstance(result, response_model)
7575
assert result.setup
7676
assert result.punchline
77+
78+
79+
def test_o1_with_system_prompt():
80+
provider = NDLLMProviders.O1_MINI
81+
nd_llm = NotDiamond(
82+
llm_configs=[provider], latency_tracking=False, hash_content=True
83+
)
84+
result, session_id, _ = nd_llm.invoke(
85+
[
86+
{"role": "system", "content": "You are a funny AI"},
87+
{"role": "user", "content": "Tell me a joke"},
88+
],
89+
)
90+
91+
assert session_id != "NO-SESSION-ID"
92+
assert len(result.content) > 0

0 commit comments

Comments
 (0)