11from datetime import datetime
2+ from typing import Any
23
34import pytest
45from lance_context .api import Context , _coerce_vector , _normalize_record , _normalize_search_hit
@@ -8,6 +9,10 @@ class DummyInner:
89 def __init__ (self ) -> None :
910 self .search_calls : list [tuple [list [float ], int | None ]] = []
1011 self .list_calls : list [tuple [int | None , int | None ]] = []
12+ self .add_calls : list [tuple [str , Any , str | None , list [float ] | None ]] = []
13+
14+ def add (self , role : str , content : Any , data_type : str | None , embedding : list [float ] | None ):
15+ self .add_calls .append ((role , content , data_type , embedding ))
1116
1217 def search (self , vector : list [float ], limit : int | None ):
1318 self .search_calls .append ((vector , limit ))
@@ -140,3 +145,48 @@ def test_context_list_default_args():
140145 ctx .list ()
141146
142147 assert dummy .list_calls == [(None , None )]
148+
149+
150+ def test_context_add_with_embedding ():
151+ ctx = Context .__new__ (Context )
152+ dummy = DummyInner ()
153+ ctx ._inner = dummy # type: ignore[attr-defined]
154+
155+ embedding = [0.1 , 0.2 , 0.3 ]
156+ ctx .add ("user" , "hello" , embedding = embedding )
157+
158+ assert len (dummy .add_calls ) == 1
159+ role , content , data_type , passed_embedding = dummy .add_calls [0 ]
160+ assert role == "user"
161+ assert content == "hello"
162+ assert data_type is None
163+ assert passed_embedding == [0.1 , 0.2 , 0.3 ]
164+
165+
166+ def test_context_add_without_embedding ():
167+ ctx = Context .__new__ (Context )
168+ dummy = DummyInner ()
169+ ctx ._inner = dummy # type: ignore[attr-defined]
170+
171+ ctx .add ("assistant" , "world" )
172+
173+ assert len (dummy .add_calls ) == 1
174+ role , content , data_type , passed_embedding = dummy .add_calls [0 ]
175+ assert role == "assistant"
176+ assert content == "world"
177+ assert passed_embedding is None
178+
179+
180+ def test_context_add_with_content_type_and_embedding ():
181+ ctx = Context .__new__ (Context )
182+ dummy = DummyInner ()
183+ ctx ._inner = dummy # type: ignore[attr-defined]
184+
185+ embedding = [0.5 , 0.6 ]
186+ ctx .add ("system" , "prompt" , content_type = "text/markdown" , embedding = embedding )
187+
188+ assert len (dummy .add_calls ) == 1
189+ role , content , data_type , passed_embedding = dummy .add_calls [0 ]
190+ assert role == "system"
191+ assert data_type == "text/markdown"
192+ assert passed_embedding == [0.5 , 0.6 ]
0 commit comments