@@ -42,21 +42,21 @@ def __init__(
4242 self ._scalar_label_field = "label"
4343
4444 self .with_scalar_labels = with_scalar_labels
45+
46+ # Initialize client with new SDK pattern
47+ self .client = tpuf .Turbopuffer (api_key = self .api_key , base_url = self .api_base_url )
48+
4549 if drop_old :
4650 log .info (f"Drop old. delete the namespace: { self .namespace } " )
47- tpuf .api_key = self .api_key
48- tpuf .api_base_url = self .api_base_url
49- ns = tpuf .Namespace (self .namespace )
51+ ns = self .client .namespace (self .namespace )
5052 try :
5153 ns .delete_all ()
5254 except Exception as e :
5355 log .warning (f"Failed to delete all. Error: { e } " )
5456
5557 @contextmanager
5658 def init (self ):
57- tpuf .api_key = self .api_key
58- tpuf .api_base_url = self .api_base_url
59- self .ns = tpuf .Namespace (self .namespace )
59+ self .ns = self .client .namespace (self .namespace )
6060 yield
6161
6262 def optimize (self , data_size : int | None = None ):
@@ -78,7 +78,7 @@ def insert_embeddings(
7878 try :
7979 if self .with_scalar_labels :
8080 self .ns .write (
81- upsert_columns = {
81+ columns = {
8282 self ._scalar_id_field : metadata ,
8383 self ._vector_field : embeddings ,
8484 self ._scalar_label_field : labels_data ,
@@ -87,7 +87,7 @@ def insert_embeddings(
8787 )
8888 else :
8989 self .ns .write (
90- upsert_columns = {
90+ columns = {
9191 self ._scalar_id_field : metadata ,
9292 self ._vector_field : embeddings ,
9393 },
@@ -104,19 +104,19 @@ def search_embedding(
104104 timeout : int | None = None ,
105105 ) -> list [int ]:
106106 res = self .ns .query (
107- rank_by = [ "vector" , "ANN" , query ] ,
107+ rank_by = ( "vector" , "ANN" , query ) ,
108108 top_k = k ,
109109 filters = self .expr ,
110110 )
111- return [row .id for row in res .rows ]
111+ return [row .id for row in res .rows ] if res . rows is not None else []
112112
113113 def prepare_filter (self , filters : Filter ):
114114 if filters .type == FilterOp .NonFilter :
115115 self .expr = None
116116 elif filters .type == FilterOp .NumGE :
117- self .expr = [ self ._scalar_id_field , "Gte" , filters .int_value ]
117+ self .expr = ( self ._scalar_id_field , "Gte" , filters .int_value )
118118 elif filters .type == FilterOp .StrEqual :
119- self .expr = [ self ._scalar_label_field , "Eq" , filters .label_value ]
119+ self .expr = ( self ._scalar_label_field , "Eq" , filters .label_value )
120120 else :
121121 msg = f"Not support Filter for TurboPuffer - { filters } "
122122 raise ValueError (msg )
0 commit comments