@@ -44,47 +44,61 @@ def nd_client(llm_configs: List[Any]) -> Any:
4444
4545@pytest .fixture
4646def not_diamond_runnable (nd_client : Any ) -> NotDiamondRunnable :
47- return NotDiamondRunnable (nd_client = nd_client )
47+ return NotDiamondRunnable (
48+ nd_client = nd_client , nd_kwargs = {"tradeoff" : "cost" }
49+ )
4850
4951
5052@pytest .fixture
5153def not_diamond_routed_runnable (nd_client : Any ) -> NotDiamondRoutedRunnable :
52- routed_runnable = NotDiamondRoutedRunnable (nd_client = nd_client )
54+ routed_runnable = NotDiamondRoutedRunnable (
55+ nd_client = nd_client , nd_kwargs = {"tradeoff" : "cost" }
56+ )
5357 routed_runnable ._configurable_model = MagicMock (spec = _ConfigurableModel )
5458 return routed_runnable
5559
5660
5761class TestNotDiamondRunnable :
5862 def test_model_select (
59- self , not_diamond_runnable : NotDiamondRunnable , llm_configs : List
63+ self ,
64+ not_diamond_runnable : NotDiamondRunnable ,
65+ llm_configs : List ,
66+ nd_client ,
6067 ) -> None :
61- actual_select = not_diamond_runnable ._model_select ("Hello, world!" )
68+ prompt = "Hello, world!"
69+ actual_select = not_diamond_runnable ._model_select (prompt )
6270 assert str (actual_select ) in [
6371 _nd_provider_to_langchain_provider (str (config ))
6472 for config in llm_configs
6573 ]
74+ assert nd_client .model_select .called_with (prompt , tradeoff = "cost" )
6675
6776 @pytest .mark .asyncio
6877 async def test_amodel_select (
69- self , not_diamond_runnable : NotDiamondRunnable , llm_configs : List
78+ self ,
79+ not_diamond_runnable : NotDiamondRunnable ,
80+ llm_configs : List ,
81+ nd_client ,
7082 ) -> None :
71- actual_select = await not_diamond_runnable ._amodel_select (
72- "Hello, world!"
73- )
83+ prompt = "Hello, world!"
84+ actual_select = await not_diamond_runnable ._amodel_select (prompt )
7485 assert str (actual_select ) in [
7586 _nd_provider_to_langchain_provider (str (config ))
7687 for config in llm_configs
7788 ]
89+ assert nd_client .amodel_select .called_with (prompt , tradeoff = "cost" )
7890
7991
8092class TestNotDiamondRoutedRunnable :
8193 def test_invoke (
82- self , not_diamond_routed_runnable : NotDiamondRoutedRunnable
94+ self , not_diamond_routed_runnable : NotDiamondRoutedRunnable , nd_client
8395 ) -> None :
84- not_diamond_routed_runnable .invoke ("Hello, world!" )
96+ prompt = "Hello, world!"
97+ not_diamond_routed_runnable .invoke (prompt )
8598 assert (
8699 not_diamond_routed_runnable ._configurable_model .invoke .called # type: ignore[attr-defined]
87100 ), f"{ not_diamond_routed_runnable ._configurable_model } "
101+ assert nd_client .model_select .called_with (prompt , tradeoff = "cost" )
88102
89103 # Check the call list
90104 call_list = (
@@ -95,40 +109,44 @@ def test_invoke(
95109 assert args [0 ] == "Hello, world!"
96110
97111 def test_stream (
98- self , not_diamond_routed_runnable : NotDiamondRoutedRunnable
112+ self , not_diamond_routed_runnable : NotDiamondRoutedRunnable , nd_client
99113 ) -> None :
100- for result in not_diamond_routed_runnable .stream ("Hello, world!" ):
114+ prompt = "Hello, world!"
115+ for result in not_diamond_routed_runnable .stream (prompt ):
101116 assert result is not None
102117 assert (
103118 not_diamond_routed_runnable ._configurable_model .stream .called # type: ignore[attr-defined]
104119 ), f"{ not_diamond_routed_runnable ._configurable_model } "
120+ assert nd_client .model_select .called_with (prompt , tradeoff = "cost" )
105121
106122 def test_batch (
107- self , not_diamond_routed_runnable : NotDiamondRoutedRunnable
123+ self , not_diamond_routed_runnable : NotDiamondRoutedRunnable , nd_client
108124 ) -> None :
109- not_diamond_routed_runnable .batch (
110- ["Hello, world!" , "How are you today?" ]
111- )
125+ prompts = ["Hello, world!" , "How are you today?" ]
126+ not_diamond_routed_runnable .batch (prompts )
112127 assert (
113128 not_diamond_routed_runnable ._configurable_model .batch .called # type: ignore[attr-defined]
114129 ), f"{ not_diamond_routed_runnable ._configurable_model } "
130+ assert nd_client .model_select .called_with (prompts , tradeoff = "cost" )
115131
116132 # Check the call list
117133 call_list = (
118134 not_diamond_routed_runnable ._configurable_model .batch .call_args_list # type: ignore[attr-defined]
119135 )
120136 assert len (call_list ) == 1
121137 args , kwargs = call_list [0 ]
122- assert args [0 ] == [ "Hello, world!" , "How are you today?" ]
138+ assert args [0 ] == prompts
123139
124140 @pytest .mark .asyncio
125141 async def test_ainvoke (
126- self , not_diamond_routed_runnable : NotDiamondRoutedRunnable
142+ self , not_diamond_routed_runnable : NotDiamondRoutedRunnable , nd_client
127143 ) -> None :
128- await not_diamond_routed_runnable .ainvoke ("Hello, world!" )
144+ prompt = "Hello, world!"
145+ await not_diamond_routed_runnable .ainvoke (prompt )
129146 assert (
130147 not_diamond_routed_runnable ._configurable_model .ainvoke .called # type: ignore[attr-defined]
131148 ), f"{ not_diamond_routed_runnable ._configurable_model } "
149+ assert nd_client .amodel_select .called_with (prompt , tradeoff = "cost" )
132150
133151 # Check the call list
134152 call_list = (
@@ -140,34 +158,34 @@ async def test_ainvoke(
140158
141159 @pytest .mark .asyncio
142160 async def test_astream (
143- self , not_diamond_routed_runnable : NotDiamondRoutedRunnable
161+ self , not_diamond_routed_runnable : NotDiamondRoutedRunnable , nd_client
144162 ) -> None :
145- async for result in not_diamond_routed_runnable .astream (
146- "Hello, world!"
147- ):
163+ prompt = "Hello, world!"
164+ async for result in not_diamond_routed_runnable .astream (prompt ):
148165 assert result is not None
149166 assert (
150167 not_diamond_routed_runnable ._configurable_model .astream .called # type: ignore[attr-defined]
151168 ), f"{ not_diamond_routed_runnable ._configurable_model } "
169+ assert nd_client .amodel_select .called_with (prompt , tradeoff = "cost" )
152170
153171 @pytest .mark .asyncio
154172 async def test_abatch (
155- self , not_diamond_routed_runnable : NotDiamondRoutedRunnable
173+ self , not_diamond_routed_runnable : NotDiamondRoutedRunnable , nd_client
156174 ) -> None :
157- await not_diamond_routed_runnable .abatch (
158- ["Hello, world!" , "How are you today?" ]
159- )
175+ prompts = ["Hello, world!" , "How are you today?" ]
176+ await not_diamond_routed_runnable .abatch (prompts )
160177 assert (
161178 not_diamond_routed_runnable ._configurable_model .abatch .called # type: ignore[attr-defined]
162179 ), f"{ not_diamond_routed_runnable ._configurable_model } "
180+ assert nd_client .amodel_select .called_with (prompts , tradeoff = "cost" )
163181
164182 # Check the call list
165183 call_list = (
166184 not_diamond_routed_runnable ._configurable_model .abatch .call_args_list # type: ignore[attr-defined]
167185 )
168186 assert len (call_list ) == 1
169187 args , kwargs = call_list [0 ]
170- assert args [0 ] == [ "Hello, world!" , "How are you today?" ]
188+ assert args [0 ] == prompts
171189
172190 def test_invokable_mock (self ) -> None :
173191 target_model = "openai/gpt-4o"
@@ -184,15 +202,19 @@ def test_invokable_mock(self) -> None:
184202
185203 mock_client = MagicMock (spec = LLM )
186204
205+ test_prompt = "Test prompt"
187206 with patch (
188207 "notdiamond.toolkit.langchain.init_chat_model" , autospec = True
189208 ) as mock_method :
190209 mock_method .return_value = mock_client
191- runnable = NotDiamondRoutedRunnable (nd_client = nd_client )
192- runnable .invoke ("Test prompt" )
210+ runnable = NotDiamondRoutedRunnable (
211+ nd_client = nd_client , nd_kwargs = {"tradeoff" : "cost" }
212+ )
213+ runnable .invoke (test_prompt )
193214 assert (
194215 mock_client .invoke .called # type: ignore[attr-defined]
195216 ), f"{ mock_client } "
217+ assert nd_client .model_select .called_with (test_prompt , tradeoff = "cost" )
196218
197219 mock_client .reset_mock ()
198220
@@ -203,10 +225,11 @@ def test_invokable_mock(self) -> None:
203225 runnable = NotDiamondRoutedRunnable (
204226 nd_api_key = "sk-..." , nd_llm_configs = [target_model ]
205227 )
206- runnable .invoke ("Test prompt" )
228+ runnable .invoke (test_prompt )
207229 assert (
208230 mock_client .invoke .called # type: ignore[attr-defined]
209231 ), f"{ mock_client } "
232+ assert nd_client .model_select .called_with (test_prompt , tradeoff = "cost" )
210233
211234 def test_init_perplexity (self ) -> None :
212235 target_model = "perplexity/llama-3.1-sonar-large-128k-online"
0 commit comments