Skip to content
This repository was archived by the owner on Dec 11, 2025. It is now read-only.

Commit 8cad6ac

Browse files
authored
Merge pull request #34 from Not-Diamond/a9-langchain-ndkwargs
Adding support for optional nd_kwargs as Runnable params.
2 parents fdfa362 + 933a859 commit 8cad6ac

2 files changed

Lines changed: 81 additions & 35 deletions

File tree

notdiamond/toolkit/langchain.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,15 @@ def __init__(
4949
nd_llm_configs: Optional[List] = None,
5050
nd_api_key: Optional[str] = None,
5151
nd_client: Optional[Any] = None,
52+
nd_kwargs: Optional[Dict[str, Any]] = None,
5253
):
54+
"""
55+
Params:
56+
nd_llm_configs: List of LLM configs to use.
57+
nd_api_key: Not Diamond API key.
58+
nd_client: Not Diamond client.
59+
nd_kwargs: Keyword arguments to pass directly to model_select.
60+
"""
5361
if not nd_client:
5462
if not nd_api_key or not nd_llm_configs:
5563
raise ValueError(
@@ -81,19 +89,21 @@ def __init__(
8189
self.client = nd_client
8290
self.api_key = nd_client.api_key
8391
self.llm_configs = nd_client.llm_configs
92+
self.nd_kwargs = nd_kwargs or dict()
8493

8594
def _model_select(self, input: LanguageModelInput) -> str:
8695
messages = _convert_input_to_message_dicts(input)
96+
print(self.nd_kwargs)
8797
_, provider = self.client.chat.completions.model_select(
88-
messages=messages
98+
messages=messages, **self.nd_kwargs
8999
)
90100
provider_str = _nd_provider_to_langchain_provider(str(provider))
91101
return provider_str
92102

93103
async def _amodel_select(self, input: LanguageModelInput) -> str:
94104
messages = _convert_input_to_message_dicts(input)
95105
_, provider = await self.client.chat.completions.amodel_select(
96-
messages=messages
106+
messages=messages, **self.nd_kwargs
97107
)
98108
provider_str = _nd_provider_to_langchain_provider(str(provider))
99109
return provider_str
@@ -154,15 +164,28 @@ def __init__(
154164
nd_llm_configs: Optional[List] = None,
155165
nd_api_key: Optional[str] = None,
156166
nd_client: Optional[Any] = None,
167+
nd_kwargs: Optional[Dict[str, Any]] = None,
157168
**kwargs: Optional[Dict[Any, Any]],
158169
) -> None:
170+
"""
171+
Params:
172+
nd_llm_configs: List of LLM configs to use.
173+
nd_api_key: Not Diamond API key.
174+
nd_client: Not Diamond client.
175+
nd_kwargs: Keyword arguments to pass directly to model_select.
176+
"""
177+
_nd_kwargs = {
178+
kw: kwargs[kw] for kw in kwargs.keys() if kw.startswith("nd_")
179+
}
180+
if nd_kwargs:
181+
_nd_kwargs.update(nd_kwargs)
182+
159183
self._ndrunnable = NotDiamondRunnable(
160184
nd_api_key=nd_api_key,
161185
nd_llm_configs=nd_llm_configs,
162186
nd_client=nd_client,
187+
nd_kwargs=_nd_kwargs,
163188
)
164-
_nd_kwargs = {kw for kw in kwargs.keys() if kw.startswith("nd_")}
165-
166189
_routed_fields = ["model", "model_provider"]
167190
if configurable_fields is None:
168191
configurable_fields = []

tests/test_toolkit/langchain/test_unit.py

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -44,47 +44,61 @@ def nd_client(llm_configs: List[Any]) -> Any:
4444

4545
@pytest.fixture
4646
def not_diamond_runnable(nd_client: Any) -> NotDiamondRunnable:
47-
return NotDiamondRunnable(nd_client=nd_client)
47+
return NotDiamondRunnable(
48+
nd_client=nd_client, nd_kwargs={"tradeoff": "cost"}
49+
)
4850

4951

5052
@pytest.fixture
5153
def not_diamond_routed_runnable(nd_client: Any) -> NotDiamondRoutedRunnable:
52-
routed_runnable = NotDiamondRoutedRunnable(nd_client=nd_client)
54+
routed_runnable = NotDiamondRoutedRunnable(
55+
nd_client=nd_client, nd_kwargs={"tradeoff": "cost"}
56+
)
5357
routed_runnable._configurable_model = MagicMock(spec=_ConfigurableModel)
5458
return routed_runnable
5559

5660

5761
class TestNotDiamondRunnable:
5862
def test_model_select(
59-
self, not_diamond_runnable: NotDiamondRunnable, llm_configs: List
63+
self,
64+
not_diamond_runnable: NotDiamondRunnable,
65+
llm_configs: List,
66+
nd_client,
6067
) -> None:
61-
actual_select = not_diamond_runnable._model_select("Hello, world!")
68+
prompt = "Hello, world!"
69+
actual_select = not_diamond_runnable._model_select(prompt)
6270
assert str(actual_select) in [
6371
_nd_provider_to_langchain_provider(str(config))
6472
for config in llm_configs
6573
]
74+
assert nd_client.model_select.called_with(prompt, tradeoff="cost")
6675

6776
@pytest.mark.asyncio
6877
async def test_amodel_select(
69-
self, not_diamond_runnable: NotDiamondRunnable, llm_configs: List
78+
self,
79+
not_diamond_runnable: NotDiamondRunnable,
80+
llm_configs: List,
81+
nd_client,
7082
) -> None:
71-
actual_select = await not_diamond_runnable._amodel_select(
72-
"Hello, world!"
73-
)
83+
prompt = "Hello, world!"
84+
actual_select = await not_diamond_runnable._amodel_select(prompt)
7485
assert str(actual_select) in [
7586
_nd_provider_to_langchain_provider(str(config))
7687
for config in llm_configs
7788
]
89+
assert nd_client.amodel_select.called_with(prompt, tradeoff="cost")
7890

7991

8092
class TestNotDiamondRoutedRunnable:
8193
def test_invoke(
82-
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable
94+
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client
8395
) -> None:
84-
not_diamond_routed_runnable.invoke("Hello, world!")
96+
prompt = "Hello, world!"
97+
not_diamond_routed_runnable.invoke(prompt)
8598
assert (
8699
not_diamond_routed_runnable._configurable_model.invoke.called # type: ignore[attr-defined]
87100
), f"{not_diamond_routed_runnable._configurable_model}"
101+
assert nd_client.model_select.called_with(prompt, tradeoff="cost")
88102

89103
# Check the call list
90104
call_list = (
@@ -95,40 +109,44 @@ def test_invoke(
95109
assert args[0] == "Hello, world!"
96110

97111
def test_stream(
98-
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable
112+
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client
99113
) -> None:
100-
for result in not_diamond_routed_runnable.stream("Hello, world!"):
114+
prompt = "Hello, world!"
115+
for result in not_diamond_routed_runnable.stream(prompt):
101116
assert result is not None
102117
assert (
103118
not_diamond_routed_runnable._configurable_model.stream.called # type: ignore[attr-defined]
104119
), f"{not_diamond_routed_runnable._configurable_model}"
120+
assert nd_client.model_select.called_with(prompt, tradeoff="cost")
105121

106122
def test_batch(
107-
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable
123+
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client
108124
) -> None:
109-
not_diamond_routed_runnable.batch(
110-
["Hello, world!", "How are you today?"]
111-
)
125+
prompts = ["Hello, world!", "How are you today?"]
126+
not_diamond_routed_runnable.batch(prompts)
112127
assert (
113128
not_diamond_routed_runnable._configurable_model.batch.called # type: ignore[attr-defined]
114129
), f"{not_diamond_routed_runnable._configurable_model}"
130+
assert nd_client.model_select.called_with(prompts, tradeoff="cost")
115131

116132
# Check the call list
117133
call_list = (
118134
not_diamond_routed_runnable._configurable_model.batch.call_args_list # type: ignore[attr-defined]
119135
)
120136
assert len(call_list) == 1
121137
args, kwargs = call_list[0]
122-
assert args[0] == ["Hello, world!", "How are you today?"]
138+
assert args[0] == prompts
123139

124140
@pytest.mark.asyncio
125141
async def test_ainvoke(
126-
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable
142+
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client
127143
) -> None:
128-
await not_diamond_routed_runnable.ainvoke("Hello, world!")
144+
prompt = "Hello, world!"
145+
await not_diamond_routed_runnable.ainvoke(prompt)
129146
assert (
130147
not_diamond_routed_runnable._configurable_model.ainvoke.called # type: ignore[attr-defined]
131148
), f"{not_diamond_routed_runnable._configurable_model}"
149+
assert nd_client.amodel_select.called_with(prompt, tradeoff="cost")
132150

133151
# Check the call list
134152
call_list = (
@@ -140,34 +158,34 @@ async def test_ainvoke(
140158

141159
@pytest.mark.asyncio
142160
async def test_astream(
143-
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable
161+
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client
144162
) -> None:
145-
async for result in not_diamond_routed_runnable.astream(
146-
"Hello, world!"
147-
):
163+
prompt = "Hello, world!"
164+
async for result in not_diamond_routed_runnable.astream(prompt):
148165
assert result is not None
149166
assert (
150167
not_diamond_routed_runnable._configurable_model.astream.called # type: ignore[attr-defined]
151168
), f"{not_diamond_routed_runnable._configurable_model}"
169+
assert nd_client.amodel_select.called_with(prompt, tradeoff="cost")
152170

153171
@pytest.mark.asyncio
154172
async def test_abatch(
155-
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable
173+
self, not_diamond_routed_runnable: NotDiamondRoutedRunnable, nd_client
156174
) -> None:
157-
await not_diamond_routed_runnable.abatch(
158-
["Hello, world!", "How are you today?"]
159-
)
175+
prompts = ["Hello, world!", "How are you today?"]
176+
await not_diamond_routed_runnable.abatch(prompts)
160177
assert (
161178
not_diamond_routed_runnable._configurable_model.abatch.called # type: ignore[attr-defined]
162179
), f"{not_diamond_routed_runnable._configurable_model}"
180+
assert nd_client.amodel_select.called_with(prompts, tradeoff="cost")
163181

164182
# Check the call list
165183
call_list = (
166184
not_diamond_routed_runnable._configurable_model.abatch.call_args_list # type: ignore[attr-defined]
167185
)
168186
assert len(call_list) == 1
169187
args, kwargs = call_list[0]
170-
assert args[0] == ["Hello, world!", "How are you today?"]
188+
assert args[0] == prompts
171189

172190
def test_invokable_mock(self) -> None:
173191
target_model = "openai/gpt-4o"
@@ -184,15 +202,19 @@ def test_invokable_mock(self) -> None:
184202

185203
mock_client = MagicMock(spec=LLM)
186204

205+
test_prompt = "Test prompt"
187206
with patch(
188207
"notdiamond.toolkit.langchain.init_chat_model", autospec=True
189208
) as mock_method:
190209
mock_method.return_value = mock_client
191-
runnable = NotDiamondRoutedRunnable(nd_client=nd_client)
192-
runnable.invoke("Test prompt")
210+
runnable = NotDiamondRoutedRunnable(
211+
nd_client=nd_client, nd_kwargs={"tradeoff": "cost"}
212+
)
213+
runnable.invoke(test_prompt)
193214
assert (
194215
mock_client.invoke.called # type: ignore[attr-defined]
195216
), f"{mock_client}"
217+
assert nd_client.model_select.called_with(test_prompt, tradeoff="cost")
196218

197219
mock_client.reset_mock()
198220

@@ -203,10 +225,11 @@ def test_invokable_mock(self) -> None:
203225
runnable = NotDiamondRoutedRunnable(
204226
nd_api_key="sk-...", nd_llm_configs=[target_model]
205227
)
206-
runnable.invoke("Test prompt")
228+
runnable.invoke(test_prompt)
207229
assert (
208230
mock_client.invoke.called # type: ignore[attr-defined]
209231
), f"{mock_client}"
232+
assert nd_client.model_select.called_with(test_prompt, tradeoff="cost")
210233

211234
def test_init_perplexity(self) -> None:
212235
target_model = "perplexity/llama-3.1-sonar-large-128k-online"

0 commit comments

Comments
 (0)