|
| 1 | +import asyncio |
1 | 2 | import logging |
2 | 3 | from abc import ABC, abstractmethod |
3 | | -from typing import Optional, Sequence |
| 4 | +from typing import Optional, Self, Sequence |
4 | 5 |
|
5 | 6 | from chromadb import EmbeddingFunction |
6 | 7 | from numpy.typing import NDArray |
@@ -117,3 +118,35 @@ async def drop(self, collection_path: str): |
117 | 118 | Delete a collection from the database. |
118 | 119 | """ |
119 | 120 | pass |
| 121 | + |
| 122 | + def _check_new_config(self, new_config: Config) -> bool: |
| 123 | + """ |
| 124 | + Verify that the `new_config` is a valid one for updating. |
| 125 | + """ |
| 126 | + assert isinstance(new_config, Config), "`new_config` is not a `Config` object." |
| 127 | + return ( |
| 128 | + new_config.db_type == self._configs.db_type |
| 129 | + and new_config.db_params == self._configs.db_params |
| 130 | + ) |
| 131 | + |
| 132 | + def update_config(self, new_config: Config) -> Self: |
| 133 | + assert self._check_new_config(new_config), ( |
| 134 | + "The new config has different database configs." |
| 135 | + ) |
| 136 | + |
| 137 | + # no need to make this one async |
| 138 | + try: |
| 139 | + loop = asyncio.get_running_loop() |
| 140 | + except RuntimeError: |
| 141 | + loop = asyncio.new_event_loop() |
| 142 | + asyncio.set_event_loop(loop) |
| 143 | + |
| 144 | + self._configs = loop.run_until_complete(self._configs.merge_from(new_config)) |
| 145 | + return self |
| 146 | + |
| 147 | + def replace_config(self, new_config: Config) -> Self: |
| 148 | + assert self._check_new_config(new_config), ( |
| 149 | + "The new config has different database configs." |
| 150 | + ) |
| 151 | + self._configs = new_config |
| 152 | + return self |
0 commit comments