@@ -85,8 +85,12 @@ class Config:
8585 overlap_ratio : float = 0.2
8686 query_multiplier : int = - 1
8787 query_exclude : list [PathLike ] = field (default_factory = list )
88- reranker : Optional [str ] = "cross-encoder/ms-marco-MiniLM-L-6-v2"
89- reranker_params : dict [str , Any ] = field (default_factory = dict )
88+ reranker : Optional [str ] = "CrossEncoderReranker"
89+ reranker_params : dict [str , Any ] = field (
90+ default_factory = lambda : {
91+ "model_name_or_path" : "cross-encoder/ms-marco-MiniLM-L-6-v2"
92+ }
93+ )
9094 check_item : Optional [str ] = None
9195 use_absolute_path : bool = False
9296 include : list [QueryInclude ] = field (
@@ -100,6 +104,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
100104 """
101105 Raise IOError if db_path is not valid.
102106 """
107+ default_config = Config ()
103108 db_path = config_dict .get ("db_path" )
104109 host = config_dict .get ("host" ) or "localhost"
105110 port = config_dict .get ("port" ) or 8000
@@ -112,25 +117,35 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
112117 return Config (
113118 ** {
114119 "embedding_function" : config_dict .get (
115- "embedding_function" , "SentenceTransformerEmbeddingFunction"
120+ "embedding_function" , default_config .embedding_function
121+ ),
122+ "embedding_params" : config_dict .get (
123+ "embedding_params" , default_config .embedding_params
116124 ),
117- "embedding_params" : config_dict .get ("embedding_params" , {}),
118125 "host" : host ,
119126 "port" : port ,
120127 "db_path" : db_path ,
121128 "db_log_path" : os .path .expanduser (
122- config_dict .get ("db_log_path" , "~/.local/share/vectorcode/" )
129+ config_dict .get ("db_log_path" , default_config .db_log_path )
130+ ),
131+ "chunk_size" : config_dict .get ("chunk_size" , default_config .chunk_size ),
132+ "overlap_ratio" : config_dict .get (
133+ "overlap_ratio" , default_config .overlap_ratio
134+ ),
135+ "query_multiplier" : config_dict .get (
136+ "query_multiplier" , default_config .query_multiplier
137+ ),
138+ "reranker" : config_dict .get ("reranker" , default_config .reranker ),
139+ "reranker_params" : config_dict .get (
140+ "reranker_params" , default_config .reranker_params
141+ ),
142+ "db_settings" : config_dict .get (
143+ "db_settings" , default_config .db_settings
123144 ),
124- "chunk_size" : config_dict .get ("chunk_size" , 2500 ),
125- "overlap_ratio" : config_dict .get ("overlap_ratio" , 0.2 ),
126- "query_multiplier" : config_dict .get ("query_multiplier" , - 1 ),
127- "reranker" : config_dict .get (
128- "reranker" , "cross-encoder/ms-marco-MiniLM-L-6-v2"
145+ "hnsw" : config_dict .get ("hnsw" , default_config .hnsw ),
146+ "chunk_filters" : config_dict .get (
147+ "chunk_filters" , default_config .chunk_filters
129148 ),
130- "reranker_params" : config_dict .get ("reranker_params" , {}),
131- "db_settings" : config_dict .get ("db_settings" , None ),
132- "hnsw" : config_dict .get ("hnsw" , {}),
133- "chunk_filters" : config_dict .get ("chunk_filters" , {}),
134149 }
135150 )
136151
0 commit comments