From 4c06297c5e280b46ffbf76770013c579080e3ed5 Mon Sep 17 00:00:00 2001 From: tino097 Date: Sun, 12 Oct 2025 22:23:44 +0200 Subject: [PATCH 001/125] Add error handling when cleaning up the temporary dir --- mindsdb/__main__.py | 139 +++++++++++++++++++++++++++++++++----------- 1 file changed, 105 insertions(+), 34 deletions(-) diff --git a/mindsdb/__main__.py b/mindsdb/__main__.py index 27d34161a8d..d7b9e17f722 100644 --- a/mindsdb/__main__.py +++ b/mindsdb/__main__.py @@ -35,7 +35,12 @@ ) from mindsdb.utilities.ps import is_pid_listen_port, get_child_pids import mindsdb.interfaces.storage.db as db -from mindsdb.utilities.fs import clean_process_marks, clean_unlinked_process_marks, create_pid_file, delete_pid_file +from mindsdb.utilities.fs import ( + clean_process_marks, + clean_unlinked_process_marks, + create_pid_file, + delete_pid_file, +) from mindsdb.utilities.context import context as ctx from mindsdb.utilities.auth import register_oauth_client, get_aws_meta_data from mindsdb.utilities.sentry import sentry_sdk # noqa: F401 @@ -100,7 +105,9 @@ def request_restart_attempt(self) -> bool: self._restarts_time.append(current_time_seconds) if self.max_restart_interval_seconds > 0: self._restarts_time = [ - x for x in self._restarts_time if x >= (current_time_seconds - self.max_restart_interval_seconds) + x + for x in self._restarts_time + if x >= (current_time_seconds - self.max_restart_interval_seconds) ] if len(self._restarts_time) > self.max_restart_count: return False @@ -117,11 +124,16 @@ def should_restart(self) -> bool: if config.is_cloud: return False if sys.platform in ("linux", "darwin"): - return self.restart_on_failure and self.process.exitcode == -signal.SIGKILL.value + return ( + self.restart_on_failure + and self.process.exitcode == -signal.SIGKILL.value + ) else: if self.max_restart_count == 0: # to prevent infinity restarts, max_restart_count should be > 0 - logger.warning("In the current OS, it is not possible to use `max_restart_count=0`") + logger.warning( + "In the current OS, it is not possible to use `max_restart_count=0`" + ) return False return self.restart_on_failure @@ -155,12 +167,23 @@ def close_api_gracefully(trunc_processes_struct): def clean_mindsdb_tmp_dir(): """Clean the MindsDB tmp dir at exit.""" - temp_dir = config["paths"]["tmp"] - for file in temp_dir.iterdir(): - if file.is_dir(): - shutil.rmtree(file) - else: - file.unlink() + try: + temp_dir = config["paths"]["tmp"] + if not temp_dir.exists(): + return + + for file in temp_dir.iterdir(): + try: + if file.is_dir(): + # https://docs.python.org/3/library/shutil.html#shutil.rmtree + shutil.rmtree(file, ignore_errors=True) + else: + # https://docs.python.org/3/library/pathlib.html#pathlib.Path.unlink + file.unlink(missing_ok=True) + except Exception as e: + logger.error(f"Failed to clean {file}: {e}") + except Exception as e: + logger.error(f"Failed to clean MindsDB tmp dir: {e}") def set_error_model_status_by_pids(unexisting_pids: List[int]): @@ -175,18 +198,24 @@ def set_error_model_status_by_pids(unexisting_pids: List[int]): db.session.query(db.Predictor) .filter( db.Predictor.deleted_at.is_(None), - db.Predictor.status.not_in([db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR]), + db.Predictor.status.not_in( + [db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR] + ), ) .all() ) for predictor_record in predictor_records: - predictor_process_id = (predictor_record.training_metadata or {}).get("process_id") + predictor_process_id = (predictor_record.training_metadata or {}).get( + "process_id" + ) if predictor_process_id in unexisting_pids: predictor_record.status = db.PREDICTOR_STATUS.ERROR if isinstance(predictor_record.data, dict) is False: predictor_record.data = {} if "error" not in predictor_record.data: - predictor_record.data["error"] = "The training process was terminated for unknown reasons" + predictor_record.data["error"] = ( + "The training process was terminated for unknown reasons" + ) flag_modified(predictor_record, "data") db.session.commit() @@ -199,7 +228,9 @@ def set_error_model_status_for_unfinished(): db.session.query(db.Predictor) .filter( db.Predictor.deleted_at.is_(None), - db.Predictor.status.not_in([db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR]), + db.Predictor.status.not_in( + [db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR] + ), ) .all() ) @@ -227,7 +258,11 @@ def create_permanent_integrations(): NOTE: this is intentional to avoid importing integration_controller """ integration_name = "files" - existing = db.session.query(db.Integration).filter_by(name=integration_name, company_id=None).first() + existing = ( + db.session.query(db.Integration) + .filter_by(name=integration_name, company_id=None) + .first() + ) if existing is not None: return integration_record = db.Integration( @@ -240,7 +275,9 @@ def create_permanent_integrations(): try: db.session.commit() except Exception: - logger.exception(f"Failed to create permanent integration '{integration_name}' in the internal database.") + logger.exception( + f"Failed to create permanent integration '{integration_name}' in the internal database." + ) db.session.rollback() @@ -266,7 +303,9 @@ def validate_default_project() -> None: func.lower(db.Project.name) == func.lower(new_default_project_name), ).first() if existing_project is None: - logger.critical(f"A project with the name '{new_default_project_name}' does not exist") + logger.critical( + f"A project with the name '{new_default_project_name}' does not exist" + ) sys.exit(1) existing_project.metadata_ = {"is_default": True} @@ -279,7 +318,9 @@ def validate_default_project() -> None: func.lower(db.Project.name) == func.lower(new_default_project_name), ).first() if existing_project is not None: - logger.critical(f"A project with the name '{new_default_project_name}' already exists") + logger.critical( + f"A project with the name '{new_default_project_name}' already exists" + ) sys.exit(1) current_default_project.name = new_default_project_name db.session.commit() @@ -301,7 +342,9 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: ) trunc_process_data.process.start() except Exception as e: - logger.exception(f"Failed to start '{trunc_process_data.name}' API process due to unexpected error:") + logger.exception( + f"Failed to start '{trunc_process_data.name}' API process due to unexpected error:" + ) close_api_gracefully(trunc_processes_struct) raise e @@ -311,7 +354,8 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: # warn if less than 1Gb of free RAM if psutil.virtual_memory().available < (1 << 30): logger.warning( - "The system is running low on memory. " + "This may impact the stability and performance of the program." + "The system is running low on memory. " + + "This may impact the stability and performance of the program." ) ctx.set_default() @@ -409,7 +453,9 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: migrate.migrate_to_head() except Exception: - logger.exception("Failed to apply database migrations. This may prevent MindsDB from operating correctly:") + logger.exception( + "Failed to apply database migrations. This may prevent MindsDB from operating correctly:" + ) validate_default_project() @@ -431,9 +477,12 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: port=http_api_config["port"], args=(config.cmd_args.verbose,), restart_on_failure=http_api_config.get("restart_on_failure", False), - max_restart_count=http_api_config.get("max_restart_count", TrunkProcessData.max_restart_count), + max_restart_count=http_api_config.get( + "max_restart_count", TrunkProcessData.max_restart_count + ), max_restart_interval_seconds=http_api_config.get( - "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds + "max_restart_interval_seconds", + TrunkProcessData.max_restart_interval_seconds, ), ), TrunkProcessEnum.MYSQL: TrunkProcessData( @@ -442,9 +491,12 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: port=mysql_api_config["port"], args=(config.cmd_args.verbose,), restart_on_failure=mysql_api_config.get("restart_on_failure", False), - max_restart_count=mysql_api_config.get("max_restart_count", TrunkProcessData.max_restart_count), + max_restart_count=mysql_api_config.get( + "max_restart_count", TrunkProcessData.max_restart_count + ), max_restart_interval_seconds=mysql_api_config.get( - "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds + "max_restart_interval_seconds", + TrunkProcessData.max_restart_interval_seconds, ), ), TrunkProcessEnum.POSTGRES: TrunkProcessData( @@ -454,13 +506,19 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: args=(config.cmd_args.verbose,), ), TrunkProcessEnum.JOBS: TrunkProcessData( - name=TrunkProcessEnum.JOBS.value, entrypoint=start_scheduler, args=(config.cmd_args.verbose,) + name=TrunkProcessEnum.JOBS.value, + entrypoint=start_scheduler, + args=(config.cmd_args.verbose,), ), TrunkProcessEnum.TASKS: TrunkProcessData( - name=TrunkProcessEnum.TASKS.value, entrypoint=start_tasks, args=(config.cmd_args.verbose,) + name=TrunkProcessEnum.TASKS.value, + entrypoint=start_tasks, + args=(config.cmd_args.verbose,), ), TrunkProcessEnum.ML_TASK_QUEUE: TrunkProcessData( - name=TrunkProcessEnum.ML_TASK_QUEUE.value, entrypoint=start_ml_task_queue, args=(config.cmd_args.verbose,) + name=TrunkProcessEnum.ML_TASK_QUEUE.value, + entrypoint=start_ml_task_queue, + args=(config.cmd_args.verbose,), ), TrunkProcessEnum.LITELLM: TrunkProcessData( name=TrunkProcessEnum.LITELLM.value, @@ -468,9 +526,12 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: port=litellm_api_config.get("port", 8000), args=(config.cmd_args.verbose,), restart_on_failure=litellm_api_config.get("restart_on_failure", False), - max_restart_count=litellm_api_config.get("max_restart_count", TrunkProcessData.max_restart_count), + max_restart_count=litellm_api_config.get( + "max_restart_count", TrunkProcessData.max_restart_count + ), max_restart_interval_seconds=litellm_api_config.get( - "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds + "max_restart_interval_seconds", + TrunkProcessData.max_restart_interval_seconds, ), ), } @@ -493,7 +554,10 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: create_pid_file() for trunc_process_data in trunc_processes_struct.values(): - if trunc_process_data.started is True or trunc_process_data.need_to_run is False: + if ( + trunc_process_data.started is True + or trunc_process_data.need_to_run is False + ): continue start_process(trunc_process_data) # Set status for APIs without ports (they don't go through wait_api_start) @@ -523,7 +587,8 @@ async def wait_apis_start(): trunc_process_data.port, ) for trunc_process_data in trunc_processes_struct.values() - if trunc_process_data.port is not None and trunc_process_data.need_to_run is True + if trunc_process_data.port is not None + and trunc_process_data.need_to_run is True ] for future in asyncio.as_completed(futures): api_name, port, started = await future @@ -546,11 +611,17 @@ async def join_process(trunc_process_data: TrunkProcessData): finally: if trunc_process_data.should_restart: if trunc_process_data.request_restart_attempt(): - logger.warning(f"{trunc_process_data.name} API: stopped unexpectedly, restarting") + logger.warning( + f"{trunc_process_data.name} API: stopped unexpectedly, restarting" + ) trunc_process_data.process = None if trunc_process_data.name == TrunkProcessEnum.HTTP.value: # do not open GUI on HTTP API restart - trunc_process_data.args = (config.cmd_args.verbose, None, True) + trunc_process_data.args = ( + config.cmd_args.verbose, + None, + True, + ) start_process(trunc_process_data) api_name, port, started = await wait_api_start( trunc_process_data.name, From 04878b42187d0aaa317f43b9da9fc451c4053eec Mon Sep 17 00:00:00 2001 From: tino097 Date: Sun, 12 Oct 2025 22:29:05 +0200 Subject: [PATCH 002/125] Add tests for the main module and cleanup of temp dir --- tests/unit/various/test_main.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/unit/various/test_main.py diff --git a/tests/unit/various/test_main.py b/tests/unit/various/test_main.py new file mode 100644 index 00000000000..06ebf2e3a25 --- /dev/null +++ b/tests/unit/various/test_main.py @@ -0,0 +1,26 @@ +import pytest +from unittest.mock import patch + + +@pytest.fixture +def patch_main_config(tmp_path, monkeypatch): + import mindsdb.__main__ as main_mod + + monkeypatch.setattr(main_mod, "config", {"paths": {"tmp": tmp_path}}) + return tmp_path, main_mod + + +class TestMainCleanup: + + def test_cleans_files_and_dirs_but_keeps_tmp_path(self, patch_main_config): + """Test that all content is cleaned but tmp_path itself remains""" + tmp_path, main_mod = patch_main_config + + (tmp_path / "a.txt").write_text("hello") + (tmp_path / "sub").mkdir() + (tmp_path / "sub" / "b.txt").write_text("world") + + main_mod.clean_mindsdb_tmp_dir() + + assert tmp_path.exists(), "tmp_path itself should not be deleted" + assert list(tmp_path.iterdir()) == [], "All content should be removed" From f0f5db52b38ea65c611fe7a4ceaa061763672876 Mon Sep 17 00:00:00 2001 From: tino097 Date: Sun, 12 Oct 2025 22:33:07 +0200 Subject: [PATCH 003/125] Test empty dir --- tests/unit/various/test_main.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/unit/various/test_main.py b/tests/unit/various/test_main.py index 06ebf2e3a25..8067a891058 100644 --- a/tests/unit/various/test_main.py +++ b/tests/unit/various/test_main.py @@ -22,5 +22,16 @@ def test_cleans_files_and_dirs_but_keeps_tmp_path(self, patch_main_config): main_mod.clean_mindsdb_tmp_dir() - assert tmp_path.exists(), "tmp_path itself should not be deleted" - assert list(tmp_path.iterdir()) == [], "All content should be removed" + assert tmp_path.exists() + assert list(tmp_path.iterdir()) == [] + + def test_empty_tmp_path(self, patch_main_config): + """Test that cleaning an already empty tmp_path works without errors""" + tmp_path, main_mod = patch_main_config + + assert list(tmp_path.iterdir()) == [] + main_mod.clean_mindsdb_tmp_dir() + assert ( + list(tmp_path.iterdir()) == [] + ) + assert tmp_path.exists() From b977100a22272aae90712c3cf4a39625ec0281ce Mon Sep 17 00:00:00 2001 From: tino097 Date: Sun, 12 Oct 2025 22:36:28 +0200 Subject: [PATCH 004/125] Test mixed content and nested dirs --- tests/unit/various/test_main.py | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/unit/various/test_main.py b/tests/unit/various/test_main.py index 8067a891058..1d4e5e8aa24 100644 --- a/tests/unit/various/test_main.py +++ b/tests/unit/various/test_main.py @@ -35,3 +35,44 @@ def test_empty_tmp_path(self, patch_main_config): list(tmp_path.iterdir()) == [] ) assert tmp_path.exists() + + def test_nonexistent_tmp_path(self, monkeypatch): + """Test that cleaning a non-existent tmp_path does not raise errors""" + import mindsdb.__main__ as main_mod + + monkeypatch.setattr(main_mod, "config", {"paths": {"tmp": "/nonexistent/path"}}) + + try: + main_mod.clean_mindsdb_tmp_dir() + except Exception as e: + pytest.fail(f"clean_mindsdb_tmp_dir raised an exception: {e}") + + def test_mixed_content_cleanup(self, patch_main_config): + """Test that a mix of files and directories are cleaned properly""" + tmp_path, main_mod = patch_main_config + + (tmp_path / "file1.txt").write_text("file1") + (tmp_path / "dir1").mkdir() + (tmp_path / "dir1" / "file2.txt").write_text("file2") + (tmp_path / "dir2").mkdir() + + main_mod.clean_mindsdb_tmp_dir() + + assert tmp_path.exists() + assert list(tmp_path.iterdir()) == [] + + def test_nested_directories_cleanup(self, patch_main_config): + """Test that nested directories are cleaned properly""" + tmp_path, main_mod = patch_main_config + + (tmp_path / "dir1").mkdir() + (tmp_path / "dir1" / "subdir1").mkdir() + (tmp_path / "dir1" / "subdir1" / "file1.txt").write_text("file1") + (tmp_path / "dir2").mkdir() + (tmp_path / "dir2" / "file2.txt").write_text("file2") + + main_mod.clean_mindsdb_tmp_dir() + + assert tmp_path.exists() + assert list(tmp_path.iterdir()) == [] + \ No newline at end of file From f63df1b1a5765979bff067d0407747b574288f64 Mon Sep 17 00:00:00 2001 From: tino097 Date: Sun, 12 Oct 2025 22:37:27 +0200 Subject: [PATCH 005/125] Test rmtree error handling --- tests/unit/various/test_main.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/unit/various/test_main.py b/tests/unit/various/test_main.py index 1d4e5e8aa24..17dffa21f0a 100644 --- a/tests/unit/various/test_main.py +++ b/tests/unit/various/test_main.py @@ -31,9 +31,7 @@ def test_empty_tmp_path(self, patch_main_config): assert list(tmp_path.iterdir()) == [] main_mod.clean_mindsdb_tmp_dir() - assert ( - list(tmp_path.iterdir()) == [] - ) + assert list(tmp_path.iterdir()) == [] assert tmp_path.exists() def test_nonexistent_tmp_path(self, monkeypatch): @@ -60,7 +58,7 @@ def test_mixed_content_cleanup(self, patch_main_config): assert tmp_path.exists() assert list(tmp_path.iterdir()) == [] - + def test_nested_directories_cleanup(self, patch_main_config): """Test that nested directories are cleaned properly""" tmp_path, main_mod = patch_main_config @@ -75,4 +73,20 @@ def test_nested_directories_cleanup(self, patch_main_config): assert tmp_path.exists() assert list(tmp_path.iterdir()) == [] - \ No newline at end of file + + def test_rmtree_failure_handling(self, patch_main_config, monkeypatch): + """Test that exceptions during rmtree are logged but do not stop cleanup""" + tmp_path, main_mod = patch_main_config + + (tmp_path / "dir1").mkdir() + (tmp_path / "dir1" / "file1.txt").write_text("file1") + + def mock_rmtree(path, ignore_errors): + raise Exception("Simulated rmtree failure") + + monkeypatch.setattr(main_mod.shutil, "rmtree", mock_rmtree) + + main_mod.clean_mindsdb_tmp_dir() + + assert (tmp_path / "dir1").exists() + assert (tmp_path / "dir1" / "file1.txt").exists() From 2adda67d4e3daee41713b209afc4399eec704ac5 Mon Sep 17 00:00:00 2001 From: tino097 Date: Sun, 12 Oct 2025 22:58:37 +0200 Subject: [PATCH 006/125] Add more tests and cleanup code --- tests/unit/various/test_main.py | 223 ++++++++++++++++++++++++++------ 1 file changed, 187 insertions(+), 36 deletions(-) diff --git a/tests/unit/various/test_main.py b/tests/unit/various/test_main.py index 17dffa21f0a..5816176575e 100644 --- a/tests/unit/various/test_main.py +++ b/tests/unit/various/test_main.py @@ -1,16 +1,30 @@ import pytest from unittest.mock import patch +import pathlib +import shutil @pytest.fixture -def patch_main_config(tmp_path, monkeypatch): - import mindsdb.__main__ as main_mod +def errors(caplog): + """Module-level fixture to capture ERROR logs and expose `.text`.""" + caplog.clear() + caplog.set_level("ERROR") - monkeypatch.setattr(main_mod, "config", {"paths": {"tmp": tmp_path}}) - return tmp_path, main_mod + class E: + @property + def text(self): + return "\n".join(r.getMessage() for r in caplog.records) + + return E() class TestMainCleanup: + @pytest.fixture + def patch_main_config(self, tmp_path, monkeypatch): + import mindsdb.__main__ as main_mod + + monkeypatch.setattr(main_mod, "config", {"paths": {"tmp": tmp_path}}) + return tmp_path, main_mod def test_cleans_files_and_dirs_but_keeps_tmp_path(self, patch_main_config): """Test that all content is cleaned but tmp_path itself remains""" @@ -22,71 +36,208 @@ def test_cleans_files_and_dirs_but_keeps_tmp_path(self, patch_main_config): main_mod.clean_mindsdb_tmp_dir() + assert tmp_path.exists(), "tmp_path itself should not be deleted" + assert list(tmp_path.iterdir()) == [], "All content should be removed" + + def test_empty_directory(self, patch_main_config): + """Test cleaning an already empty directory""" + tmp_path, main_mod = patch_main_config + + main_mod.clean_mindsdb_tmp_dir() + assert tmp_path.exists() assert list(tmp_path.iterdir()) == [] - def test_empty_tmp_path(self, patch_main_config): - """Test that cleaning an already empty tmp_path works without errors""" + def test_mixed_files_and_directories(self, patch_main_config): + """Test cleaning mixed content types""" tmp_path, main_mod = patch_main_config - assert list(tmp_path.iterdir()) == [] + (tmp_path / "file1.txt").write_text("a") + (tmp_path / "dir1").mkdir() + (tmp_path / "dir1" / "nested.txt").write_text("b") + (tmp_path / "file2.log").write_text("c") + (tmp_path / "dir2").mkdir() + main_mod.clean_mindsdb_tmp_dir() + + assert tmp_path.exists() assert list(tmp_path.iterdir()) == [] + + def test_deeply_nested_directories(self, patch_main_config): + """Test that deeply nested directories are fully removed""" + tmp_path, main_mod = patch_main_config + + deep = tmp_path / "a" / "b" / "c" / "d" + deep.mkdir(parents=True) + (deep / "file.txt").write_text("deep") + + main_mod.clean_mindsdb_tmp_dir() + assert tmp_path.exists() + assert not (tmp_path / "a").exists() - def test_nonexistent_tmp_path(self, monkeypatch): - """Test that cleaning a non-existent tmp_path does not raise errors""" - import mindsdb.__main__ as main_mod + def test_rmtree_failure_continues_and_logs(self, patch_main_config, errors): + """Test that rmtree failure is logged and cleanup continues""" + tmp_path, main_mod = patch_main_config + + (tmp_path / "file.txt").write_text("content") + (tmp_path / "failing_dir").mkdir() + (tmp_path / "another_file.txt").write_text("more content") + (tmp_path / "good_dir").mkdir() - monkeypatch.setattr(main_mod, "config", {"paths": {"tmp": "/nonexistent/path"}}) + original_rmtree = shutil.rmtree - try: + def mock_rmtree(path, *args, **kwargs): + if "failing_dir" in str(path): + raise PermissionError("Cannot delete directory") + return original_rmtree(path, *args, **kwargs) + + with patch("shutil.rmtree", mock_rmtree): main_mod.clean_mindsdb_tmp_dir() - except Exception as e: - pytest.fail(f"clean_mindsdb_tmp_dir raised an exception: {e}") - def test_mixed_content_cleanup(self, patch_main_config): - """Test that a mix of files and directories are cleaned properly""" + assert "Failed to clean" in errors.text + assert "Cannot delete directory" in errors.text + + assert not (tmp_path / "file.txt").exists() + assert not (tmp_path / "another_file.txt").exists() + assert not (tmp_path / "good_dir").exists() + assert (tmp_path / "failing_dir").exists() + + def test_unlink_failure_continues_and_logs(self, patch_main_config, errors): + """Test that unlink failure is logged and cleanup continues""" tmp_path, main_mod = patch_main_config - (tmp_path / "file1.txt").write_text("file1") - (tmp_path / "dir1").mkdir() - (tmp_path / "dir1" / "file2.txt").write_text("file2") - (tmp_path / "dir2").mkdir() + (tmp_path / "file1.txt").write_text("a") + (tmp_path / "failing_file.txt").write_text("b") + (tmp_path / "file2.txt").write_text("c") + + original_unlink = pathlib.Path.unlink + + def mock_unlink(self, *args, **kwargs): + if self.name == "failing_file.txt": + raise PermissionError("Cannot delete file") + return original_unlink(self, *args, **kwargs) + + with patch.object(pathlib.Path, "unlink", mock_unlink): + main_mod.clean_mindsdb_tmp_dir() + + assert "Failed to clean" in errors.text + assert "Cannot delete file" in errors.text + + assert not (tmp_path / "file1.txt").exists() + assert (tmp_path / "failing_file.txt").exists() + assert not (tmp_path / "file2.txt").exists() + + def test_special_files_are_removed(self, patch_main_config): + """Test that hidden files and special names are cleaned""" + tmp_path, main_mod = patch_main_config + + (tmp_path / ".hidden").write_text("hidden") + (tmp_path / ".config").mkdir() + (tmp_path / "__pycache__").mkdir() main_mod.clean_mindsdb_tmp_dir() assert tmp_path.exists() assert list(tmp_path.iterdir()) == [] - def test_nested_directories_cleanup(self, patch_main_config): - """Test that nested directories are cleaned properly""" + def test_large_directory_tree(self, patch_main_config): + """Test cleaning a large directory structure""" tmp_path, main_mod = patch_main_config - (tmp_path / "dir1").mkdir() - (tmp_path / "dir1" / "subdir1").mkdir() - (tmp_path / "dir1" / "subdir1" / "file1.txt").write_text("file1") - (tmp_path / "dir2").mkdir() - (tmp_path / "dir2" / "file2.txt").write_text("file2") + for i in range(10): + dir_path = tmp_path / f"dir{i}" + dir_path.mkdir() + for j in range(10): + (dir_path / f"file{j}.txt").write_text(f"content{i}{j}") main_mod.clean_mindsdb_tmp_dir() assert tmp_path.exists() assert list(tmp_path.iterdir()) == [] - def test_rmtree_failure_handling(self, patch_main_config, monkeypatch): - """Test that exceptions during rmtree are logged but do not stop cleanup""" + def test_mixed_failures_continue_cleanup(self, patch_main_config, errors): + """Test that multiple failures don't stop the cleanup process""" tmp_path, main_mod = patch_main_config - (tmp_path / "dir1").mkdir() - (tmp_path / "dir1" / "file1.txt").write_text("file1") + (tmp_path / "good_file1.txt").write_text("a") + (tmp_path / "failing_file.txt").write_text("b") + (tmp_path / "good_file2.txt").write_text("c") + (tmp_path / "failing_dir").mkdir() + (tmp_path / "good_dir").mkdir() + + original_unlink = pathlib.Path.unlink + original_rmtree = shutil.rmtree + + def mock_unlink(self, *args, **kwargs): + if self.name == "failing_file.txt": + raise PermissionError("Cannot delete file") + return original_unlink(self, *args, **kwargs) + + def mock_rmtree(path, *args, **kwargs): + if "failing_dir" in str(path): + raise PermissionError("Cannot delete directory") + return original_rmtree(path, *args, **kwargs) + + with patch.object(pathlib.Path, "unlink", mock_unlink): + with patch("shutil.rmtree", mock_rmtree): + main_mod.clean_mindsdb_tmp_dir() + + log_text = errors.text + assert log_text.count("Failed to clean") >= 2 + + assert not (tmp_path / "good_file1.txt").exists() + assert not (tmp_path / "good_file2.txt").exists() + assert not (tmp_path / "good_dir").exists() + + assert (tmp_path / "failing_file.txt").exists() + assert (tmp_path / "failing_dir").exists() + + def test_logger_called_with_correct_level(self, patch_main_config): + """Test that errors are logged at ERROR level""" + tmp_path, main_mod = patch_main_config + + (tmp_path / "failing_file.txt").write_text("content") - def mock_rmtree(path, ignore_errors): - raise Exception("Simulated rmtree failure") + original_unlink = pathlib.Path.unlink - monkeypatch.setattr(main_mod.shutil, "rmtree", mock_rmtree) + def mock_unlink(self, *args, **kwargs): + if self.name == "failing_file.txt": + raise PermissionError("Test error") + return original_unlink(self, *args, **kwargs) + + with patch.object(pathlib.Path, "unlink", mock_unlink): + with patch("mindsdb.__main__.logger") as mock_logger: + main_mod.clean_mindsdb_tmp_dir() + + assert mock_logger.error.called or mock_logger.exception.called + + def test_nonexistent_tmp_path(self, monkeypatch): + """Test handling when tmp path doesn't exist""" + import mindsdb.__main__ as main_mod + from pathlib import Path + + nonexistent = Path("/tmp/nonexistent_mindsdb_test_dir_12345") + assert not nonexistent.exists() + + monkeypatch.setattr(main_mod, "config", {"paths": {"tmp": nonexistent}}) + + main_mod.clean_mindsdb_tmp_dir() + assert not nonexistent.exists() + + def test_symlinks_are_handled(self, patch_main_config): + """Test that symlinks are removed without following them""" + tmp_path, main_mod = patch_main_config + + external_file = tmp_path.parent / "external.txt" + external_file.write_text("external") + + (tmp_path / "link_to_external").symlink_to(external_file) main_mod.clean_mindsdb_tmp_dir() - assert (tmp_path / "dir1").exists() - assert (tmp_path / "dir1" / "file1.txt").exists() + assert tmp_path.exists() + assert list(tmp_path.iterdir()) == [] + assert external_file.exists() + + external_file.unlink() \ No newline at end of file From cb2bde946189e48321fdb72566d5ac7bbfb7f1f1 Mon Sep 17 00:00:00 2001 From: tino097 Date: Sun, 12 Oct 2025 22:59:25 +0200 Subject: [PATCH 007/125] cleanup code --- tests/unit/various/test_main.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/unit/various/test_main.py b/tests/unit/various/test_main.py index 5816176575e..bfce7109ddf 100644 --- a/tests/unit/various/test_main.py +++ b/tests/unit/various/test_main.py @@ -4,20 +4,6 @@ import shutil -@pytest.fixture -def errors(caplog): - """Module-level fixture to capture ERROR logs and expose `.text`.""" - caplog.clear() - caplog.set_level("ERROR") - - class E: - @property - def text(self): - return "\n".join(r.getMessage() for r in caplog.records) - - return E() - - class TestMainCleanup: @pytest.fixture def patch_main_config(self, tmp_path, monkeypatch): @@ -26,6 +12,19 @@ def patch_main_config(self, tmp_path, monkeypatch): monkeypatch.setattr(main_mod, "config", {"paths": {"tmp": tmp_path}}) return tmp_path, main_mod + @pytest.fixture + def errors(self, caplog): + """Fixture to capture error log messages.""" + + class ErrorCapture: + @property + def text(self): + return "\n".join([r.getMessage() for r in caplog.records]) + + caplog.clear() + caplog.set_level("ERROR") + return ErrorCapture() + def test_cleans_files_and_dirs_but_keeps_tmp_path(self, patch_main_config): """Test that all content is cleaned but tmp_path itself remains""" tmp_path, main_mod = patch_main_config @@ -240,4 +239,4 @@ def test_symlinks_are_handled(self, patch_main_config): assert list(tmp_path.iterdir()) == [] assert external_file.exists() - external_file.unlink() \ No newline at end of file + external_file.unlink() From fe821bb4e6eaa8a4ae425112ca33a340545559b3 Mon Sep 17 00:00:00 2001 From: tino097 Date: Sun, 12 Oct 2025 23:05:26 +0200 Subject: [PATCH 008/125] Ruff format --- mindsdb/__main__.py | 82 +++++++++++---------------------------------- 1 file changed, 20 insertions(+), 62 deletions(-) diff --git a/mindsdb/__main__.py b/mindsdb/__main__.py index d7b9e17f722..d89298dcf5a 100644 --- a/mindsdb/__main__.py +++ b/mindsdb/__main__.py @@ -105,9 +105,7 @@ def request_restart_attempt(self) -> bool: self._restarts_time.append(current_time_seconds) if self.max_restart_interval_seconds > 0: self._restarts_time = [ - x - for x in self._restarts_time - if x >= (current_time_seconds - self.max_restart_interval_seconds) + x for x in self._restarts_time if x >= (current_time_seconds - self.max_restart_interval_seconds) ] if len(self._restarts_time) > self.max_restart_count: return False @@ -124,16 +122,11 @@ def should_restart(self) -> bool: if config.is_cloud: return False if sys.platform in ("linux", "darwin"): - return ( - self.restart_on_failure - and self.process.exitcode == -signal.SIGKILL.value - ) + return self.restart_on_failure and self.process.exitcode == -signal.SIGKILL.value else: if self.max_restart_count == 0: # to prevent infinity restarts, max_restart_count should be > 0 - logger.warning( - "In the current OS, it is not possible to use `max_restart_count=0`" - ) + logger.warning("In the current OS, it is not possible to use `max_restart_count=0`") return False return self.restart_on_failure @@ -198,24 +191,18 @@ def set_error_model_status_by_pids(unexisting_pids: List[int]): db.session.query(db.Predictor) .filter( db.Predictor.deleted_at.is_(None), - db.Predictor.status.not_in( - [db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR] - ), + db.Predictor.status.not_in([db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR]), ) .all() ) for predictor_record in predictor_records: - predictor_process_id = (predictor_record.training_metadata or {}).get( - "process_id" - ) + predictor_process_id = (predictor_record.training_metadata or {}).get("process_id") if predictor_process_id in unexisting_pids: predictor_record.status = db.PREDICTOR_STATUS.ERROR if isinstance(predictor_record.data, dict) is False: predictor_record.data = {} if "error" not in predictor_record.data: - predictor_record.data["error"] = ( - "The training process was terminated for unknown reasons" - ) + predictor_record.data["error"] = "The training process was terminated for unknown reasons" flag_modified(predictor_record, "data") db.session.commit() @@ -228,9 +215,7 @@ def set_error_model_status_for_unfinished(): db.session.query(db.Predictor) .filter( db.Predictor.deleted_at.is_(None), - db.Predictor.status.not_in( - [db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR] - ), + db.Predictor.status.not_in([db.PREDICTOR_STATUS.COMPLETE, db.PREDICTOR_STATUS.ERROR]), ) .all() ) @@ -258,11 +243,7 @@ def create_permanent_integrations(): NOTE: this is intentional to avoid importing integration_controller """ integration_name = "files" - existing = ( - db.session.query(db.Integration) - .filter_by(name=integration_name, company_id=None) - .first() - ) + existing = db.session.query(db.Integration).filter_by(name=integration_name, company_id=None).first() if existing is not None: return integration_record = db.Integration( @@ -275,9 +256,7 @@ def create_permanent_integrations(): try: db.session.commit() except Exception: - logger.exception( - f"Failed to create permanent integration '{integration_name}' in the internal database." - ) + logger.exception(f"Failed to create permanent integration '{integration_name}' in the internal database.") db.session.rollback() @@ -303,9 +282,7 @@ def validate_default_project() -> None: func.lower(db.Project.name) == func.lower(new_default_project_name), ).first() if existing_project is None: - logger.critical( - f"A project with the name '{new_default_project_name}' does not exist" - ) + logger.critical(f"A project with the name '{new_default_project_name}' does not exist") sys.exit(1) existing_project.metadata_ = {"is_default": True} @@ -318,9 +295,7 @@ def validate_default_project() -> None: func.lower(db.Project.name) == func.lower(new_default_project_name), ).first() if existing_project is not None: - logger.critical( - f"A project with the name '{new_default_project_name}' already exists" - ) + logger.critical(f"A project with the name '{new_default_project_name}' already exists") sys.exit(1) current_default_project.name = new_default_project_name db.session.commit() @@ -342,9 +317,7 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: ) trunc_process_data.process.start() except Exception as e: - logger.exception( - f"Failed to start '{trunc_process_data.name}' API process due to unexpected error:" - ) + logger.exception(f"Failed to start '{trunc_process_data.name}' API process due to unexpected error:") close_api_gracefully(trunc_processes_struct) raise e @@ -354,8 +327,7 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: # warn if less than 1Gb of free RAM if psutil.virtual_memory().available < (1 << 30): logger.warning( - "The system is running low on memory. " - + "This may impact the stability and performance of the program." + "The system is running low on memory. " + "This may impact the stability and performance of the program." ) ctx.set_default() @@ -453,9 +425,7 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: migrate.migrate_to_head() except Exception: - logger.exception( - "Failed to apply database migrations. This may prevent MindsDB from operating correctly:" - ) + logger.exception("Failed to apply database migrations. This may prevent MindsDB from operating correctly:") validate_default_project() @@ -477,9 +447,7 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: port=http_api_config["port"], args=(config.cmd_args.verbose,), restart_on_failure=http_api_config.get("restart_on_failure", False), - max_restart_count=http_api_config.get( - "max_restart_count", TrunkProcessData.max_restart_count - ), + max_restart_count=http_api_config.get("max_restart_count", TrunkProcessData.max_restart_count), max_restart_interval_seconds=http_api_config.get( "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds, @@ -491,9 +459,7 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: port=mysql_api_config["port"], args=(config.cmd_args.verbose,), restart_on_failure=mysql_api_config.get("restart_on_failure", False), - max_restart_count=mysql_api_config.get( - "max_restart_count", TrunkProcessData.max_restart_count - ), + max_restart_count=mysql_api_config.get("max_restart_count", TrunkProcessData.max_restart_count), max_restart_interval_seconds=mysql_api_config.get( "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds, @@ -526,9 +492,7 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: port=litellm_api_config.get("port", 8000), args=(config.cmd_args.verbose,), restart_on_failure=litellm_api_config.get("restart_on_failure", False), - max_restart_count=litellm_api_config.get( - "max_restart_count", TrunkProcessData.max_restart_count - ), + max_restart_count=litellm_api_config.get("max_restart_count", TrunkProcessData.max_restart_count), max_restart_interval_seconds=litellm_api_config.get( "max_restart_interval_seconds", TrunkProcessData.max_restart_interval_seconds, @@ -554,10 +518,7 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: create_pid_file() for trunc_process_data in trunc_processes_struct.values(): - if ( - trunc_process_data.started is True - or trunc_process_data.need_to_run is False - ): + if trunc_process_data.started is True or trunc_process_data.need_to_run is False: continue start_process(trunc_process_data) # Set status for APIs without ports (they don't go through wait_api_start) @@ -587,8 +548,7 @@ async def wait_apis_start(): trunc_process_data.port, ) for trunc_process_data in trunc_processes_struct.values() - if trunc_process_data.port is not None - and trunc_process_data.need_to_run is True + if trunc_process_data.port is not None and trunc_process_data.need_to_run is True ] for future in asyncio.as_completed(futures): api_name, port, started = await future @@ -611,9 +571,7 @@ async def join_process(trunc_process_data: TrunkProcessData): finally: if trunc_process_data.should_restart: if trunc_process_data.request_restart_attempt(): - logger.warning( - f"{trunc_process_data.name} API: stopped unexpectedly, restarting" - ) + logger.warning(f"{trunc_process_data.name} API: stopped unexpectedly, restarting") trunc_process_data.process = None if trunc_process_data.name == TrunkProcessEnum.HTTP.value: # do not open GUI on HTTP API restart From 208e49bc6a3d2e98a97f6c9e391bf726b54f7986 Mon Sep 17 00:00:00 2001 From: tino097 Date: Wed, 22 Oct 2025 19:44:48 +0200 Subject: [PATCH 009/125] Refactor tests and improve try-catch block --- mindsdb/__main__.py | 8 +- tests/unit/various/test_main.py | 178 +++++++++++--------------------- 2 files changed, 68 insertions(+), 118 deletions(-) diff --git a/mindsdb/__main__.py b/mindsdb/__main__.py index 2fe101826ba..e8cc4a8e5d1 100644 --- a/mindsdb/__main__.py +++ b/mindsdb/__main__.py @@ -167,12 +167,14 @@ def clean_mindsdb_tmp_dir(): try: if file.is_dir(): # https://docs.python.org/3/library/shutil.html#shutil.rmtree - shutil.rmtree(file, ignore_errors=True) + shutil.rmtree(file) else: # https://docs.python.org/3/library/pathlib.html#pathlib.Path.unlink file.unlink(missing_ok=True) - except Exception as e: - logger.error(f"Failed to clean {file}: {e}") + except PermissionError as e: + logger.error(f"Failed to clean %s: %s{file}: {e}") + except FileNotFoundError: + logger.error(f"File not found during cleanup: {file}") except Exception as e: logger.error(f"Failed to clean MindsDB tmp dir: {e}") diff --git a/tests/unit/various/test_main.py b/tests/unit/various/test_main.py index bfce7109ddf..b08ba77146d 100644 --- a/tests/unit/various/test_main.py +++ b/tests/unit/various/test_main.py @@ -1,7 +1,7 @@ -import pytest -from unittest.mock import patch import pathlib import shutil +from unittest.mock import patch +import pytest class TestMainCleanup: @@ -14,24 +14,25 @@ def patch_main_config(self, tmp_path, monkeypatch): @pytest.fixture def errors(self, caplog): - """Fixture to capture error log messages.""" + """Capture only ERROR logs as concatenated text""" class ErrorCapture: @property def text(self): - return "\n".join([r.getMessage() for r in caplog.records]) + return "\n".join( + r.getMessage() for r in caplog.records if r.levelname == "ERROR" + ) caplog.clear() caplog.set_level("ERROR") return ErrorCapture() def test_cleans_files_and_dirs_but_keeps_tmp_path(self, patch_main_config): - """Test that all content is cleaned but tmp_path itself remains""" tmp_path, main_mod = patch_main_config - (tmp_path / "a.txt").write_text("hello") - (tmp_path / "sub").mkdir() - (tmp_path / "sub" / "b.txt").write_text("world") + sub = tmp_path / "sub" + sub.mkdir() + (sub / "b.txt").write_text("world") main_mod.clean_mindsdb_tmp_dir() @@ -39,33 +40,13 @@ def test_cleans_files_and_dirs_but_keeps_tmp_path(self, patch_main_config): assert list(tmp_path.iterdir()) == [], "All content should be removed" def test_empty_directory(self, patch_main_config): - """Test cleaning an already empty directory""" tmp_path, main_mod = patch_main_config - - main_mod.clean_mindsdb_tmp_dir() - - assert tmp_path.exists() - assert list(tmp_path.iterdir()) == [] - - def test_mixed_files_and_directories(self, patch_main_config): - """Test cleaning mixed content types""" - tmp_path, main_mod = patch_main_config - - (tmp_path / "file1.txt").write_text("a") - (tmp_path / "dir1").mkdir() - (tmp_path / "dir1" / "nested.txt").write_text("b") - (tmp_path / "file2.log").write_text("c") - (tmp_path / "dir2").mkdir() - main_mod.clean_mindsdb_tmp_dir() - assert tmp_path.exists() assert list(tmp_path.iterdir()) == [] def test_deeply_nested_directories(self, patch_main_config): - """Test that deeply nested directories are fully removed""" tmp_path, main_mod = patch_main_config - deep = tmp_path / "a" / "b" / "c" / "d" deep.mkdir(parents=True) (deep / "file.txt").write_text("deep") @@ -75,40 +56,27 @@ def test_deeply_nested_directories(self, patch_main_config): assert tmp_path.exists() assert not (tmp_path / "a").exists() - def test_rmtree_failure_continues_and_logs(self, patch_main_config, errors): - """Test that rmtree failure is logged and cleanup continues""" + def test_symlinks_are_handled(self, patch_main_config): tmp_path, main_mod = patch_main_config - (tmp_path / "file.txt").write_text("content") - (tmp_path / "failing_dir").mkdir() - (tmp_path / "another_file.txt").write_text("more content") - (tmp_path / "good_dir").mkdir() - - original_rmtree = shutil.rmtree + external_file = tmp_path.parent / "external.txt" + external_file.write_text("external") - def mock_rmtree(path, *args, **kwargs): - if "failing_dir" in str(path): - raise PermissionError("Cannot delete directory") - return original_rmtree(path, *args, **kwargs) + (tmp_path / "link_to_external").symlink_to(external_file) - with patch("shutil.rmtree", mock_rmtree): - main_mod.clean_mindsdb_tmp_dir() + main_mod.clean_mindsdb_tmp_dir() - assert "Failed to clean" in errors.text - assert "Cannot delete directory" in errors.text + assert tmp_path.exists() + assert list(tmp_path.iterdir()) == [] + assert external_file.exists() - assert not (tmp_path / "file.txt").exists() - assert not (tmp_path / "another_file.txt").exists() - assert not (tmp_path / "good_dir").exists() - assert (tmp_path / "failing_dir").exists() + external_file.unlink() def test_unlink_failure_continues_and_logs(self, patch_main_config, errors): - """Test that unlink failure is logged and cleanup continues""" tmp_path, main_mod = patch_main_config - - (tmp_path / "file1.txt").write_text("a") + (tmp_path / "ok1.txt").write_text("a") (tmp_path / "failing_file.txt").write_text("b") - (tmp_path / "file2.txt").write_text("c") + (tmp_path / "ok2.txt").write_text("c") original_unlink = pathlib.Path.unlink @@ -120,43 +88,42 @@ def mock_unlink(self, *args, **kwargs): with patch.object(pathlib.Path, "unlink", mock_unlink): main_mod.clean_mindsdb_tmp_dir() - assert "Failed to clean" in errors.text - assert "Cannot delete file" in errors.text + txt = errors.text + assert "Failed to clean" in txt + assert "Cannot delete file" in txt - assert not (tmp_path / "file1.txt").exists() + assert not (tmp_path / "ok1.txt").exists() + assert not (tmp_path / "ok2.txt").exists() assert (tmp_path / "failing_file.txt").exists() - assert not (tmp_path / "file2.txt").exists() - def test_special_files_are_removed(self, patch_main_config): - """Test that hidden files and special names are cleaned""" + def test_rmtree_failure_continues_and_logs(self, patch_main_config, errors): tmp_path, main_mod = patch_main_config - (tmp_path / ".hidden").write_text("hidden") - (tmp_path / ".config").mkdir() - (tmp_path / "__pycache__").mkdir() - - main_mod.clean_mindsdb_tmp_dir() + (tmp_path / "file.txt").write_text("content") + (tmp_path / "failing_dir").mkdir() + (tmp_path / "another_file.txt").write_text("more content") + (tmp_path / "good_dir").mkdir() - assert tmp_path.exists() - assert list(tmp_path.iterdir()) == [] + original_rmtree = shutil.rmtree - def test_large_directory_tree(self, patch_main_config): - """Test cleaning a large directory structure""" - tmp_path, main_mod = patch_main_config + def mock_rmtree(path, *args, **kwargs): + if "failing_dir" in str(path): + raise PermissionError("Cannot delete directory") + return original_rmtree(path, *args, **kwargs) - for i in range(10): - dir_path = tmp_path / f"dir{i}" - dir_path.mkdir() - for j in range(10): - (dir_path / f"file{j}.txt").write_text(f"content{i}{j}") + with patch("shutil.rmtree", mock_rmtree): + main_mod.clean_mindsdb_tmp_dir() - main_mod.clean_mindsdb_tmp_dir() + txt = errors.text + assert "Failed to clean" in txt + assert "Cannot delete directory" in txt - assert tmp_path.exists() - assert list(tmp_path.iterdir()) == [] + assert not (tmp_path / "file.txt").exists() + assert not (tmp_path / "another_file.txt").exists() + assert not (tmp_path / "good_dir").exists() + assert (tmp_path / "failing_dir").exists() def test_mixed_failures_continue_cleanup(self, patch_main_config, errors): - """Test that multiple failures don't stop the cleanup process""" tmp_path, main_mod = patch_main_config (tmp_path / "good_file1.txt").write_text("a") @@ -178,41 +145,22 @@ def mock_rmtree(path, *args, **kwargs): raise PermissionError("Cannot delete directory") return original_rmtree(path, *args, **kwargs) - with patch.object(pathlib.Path, "unlink", mock_unlink): - with patch("shutil.rmtree", mock_rmtree): - main_mod.clean_mindsdb_tmp_dir() + with patch.object(pathlib.Path, "unlink", mock_unlink), patch( + "shutil.rmtree", mock_rmtree + ): + main_mod.clean_mindsdb_tmp_dir() - log_text = errors.text - assert log_text.count("Failed to clean") >= 2 + txt = errors.text + # We should have at least two "Failed to clean" lines (file + dir) + assert txt.count("Failed to clean") >= 2 assert not (tmp_path / "good_file1.txt").exists() assert not (tmp_path / "good_file2.txt").exists() assert not (tmp_path / "good_dir").exists() - assert (tmp_path / "failing_file.txt").exists() assert (tmp_path / "failing_dir").exists() - def test_logger_called_with_correct_level(self, patch_main_config): - """Test that errors are logged at ERROR level""" - tmp_path, main_mod = patch_main_config - - (tmp_path / "failing_file.txt").write_text("content") - - original_unlink = pathlib.Path.unlink - - def mock_unlink(self, *args, **kwargs): - if self.name == "failing_file.txt": - raise PermissionError("Test error") - return original_unlink(self, *args, **kwargs) - - with patch.object(pathlib.Path, "unlink", mock_unlink): - with patch("mindsdb.__main__.logger") as mock_logger: - main_mod.clean_mindsdb_tmp_dir() - - assert mock_logger.error.called or mock_logger.exception.called - def test_nonexistent_tmp_path(self, monkeypatch): - """Test handling when tmp path doesn't exist""" import mindsdb.__main__ as main_mod from pathlib import Path @@ -220,23 +168,23 @@ def test_nonexistent_tmp_path(self, monkeypatch): assert not nonexistent.exists() monkeypatch.setattr(main_mod, "config", {"paths": {"tmp": nonexistent}}) - main_mod.clean_mindsdb_tmp_dir() assert not nonexistent.exists() - def test_symlinks_are_handled(self, patch_main_config): - """Test that symlinks are removed without following them""" + def test_logger_called_with_correct_level(self, patch_main_config): tmp_path, main_mod = patch_main_config + (tmp_path / "failing_file.txt").write_text("content") - external_file = tmp_path.parent / "external.txt" - external_file.write_text("external") - - (tmp_path / "link_to_external").symlink_to(external_file) + original_unlink = pathlib.Path.unlink - main_mod.clean_mindsdb_tmp_dir() + def mock_unlink(self, *args, **kwargs): + if self.name == "failing_file.txt": + raise PermissionError("Test error") + return original_unlink(self, *args, **kwargs) - assert tmp_path.exists() - assert list(tmp_path.iterdir()) == [] - assert external_file.exists() + with patch.object(pathlib.Path, "unlink", mock_unlink), patch( + "mindsdb.__main__.logger" + ) as mock_logger: + main_mod.clean_mindsdb_tmp_dir() + assert mock_logger.error.called or mock_logger.exception.called - external_file.unlink() From aa23656f3eab97cf7af1a6e6faa070df594bd054 Mon Sep 17 00:00:00 2001 From: tino097 Date: Thu, 6 Nov 2025 15:32:08 +0100 Subject: [PATCH 010/125] Ruff format --- tests/unit/various/test_main.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/unit/various/test_main.py b/tests/unit/various/test_main.py index b08ba77146d..689a9eb93af 100644 --- a/tests/unit/various/test_main.py +++ b/tests/unit/various/test_main.py @@ -19,9 +19,7 @@ def errors(self, caplog): class ErrorCapture: @property def text(self): - return "\n".join( - r.getMessage() for r in caplog.records if r.levelname == "ERROR" - ) + return "\n".join(r.getMessage() for r in caplog.records if r.levelname == "ERROR") caplog.clear() caplog.set_level("ERROR") @@ -145,9 +143,7 @@ def mock_rmtree(path, *args, **kwargs): raise PermissionError("Cannot delete directory") return original_rmtree(path, *args, **kwargs) - with patch.object(pathlib.Path, "unlink", mock_unlink), patch( - "shutil.rmtree", mock_rmtree - ): + with patch.object(pathlib.Path, "unlink", mock_unlink), patch("shutil.rmtree", mock_rmtree): main_mod.clean_mindsdb_tmp_dir() txt = errors.text @@ -182,9 +178,6 @@ def mock_unlink(self, *args, **kwargs): raise PermissionError("Test error") return original_unlink(self, *args, **kwargs) - with patch.object(pathlib.Path, "unlink", mock_unlink), patch( - "mindsdb.__main__.logger" - ) as mock_logger: + with patch.object(pathlib.Path, "unlink", mock_unlink), patch("mindsdb.__main__.logger") as mock_logger: main_mod.clean_mindsdb_tmp_dir() assert mock_logger.error.called or mock_logger.exception.called - From bb63ee2175294ec0d7ef36202a385d1efb6a6f19 Mon Sep 17 00:00:00 2001 From: Murat Aslan Date: Fri, 12 Dec 2025 12:09:05 +0300 Subject: [PATCH 011/125] fix(s3): handle us-east-1 region returning None from get_bucket_location AWS S3 API returns None for LocationConstraint when bucket is in us-east-1 (the default/classic region). This caused DuckDB to fail with HTTP 400 error because it received 'None' as the region string. Fixes #11886 --- mindsdb/integrations/handlers/s3_handler/s3_handler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/s3_handler/s3_handler.py b/mindsdb/integrations/handlers/s3_handler/s3_handler.py index 108158769c6..a0c616f9260 100644 --- a/mindsdb/integrations/handlers/s3_handler/s3_handler.py +++ b/mindsdb/integrations/handlers/s3_handler/s3_handler.py @@ -157,7 +157,11 @@ def _connect_duckdb(self, bucket): # detect region for bucket if bucket not in self._regions: client = self.connect() - self._regions[bucket] = client.get_bucket_location(Bucket=bucket)["LocationConstraint"] + location = client.get_bucket_location(Bucket=bucket)["LocationConstraint"] + # AWS returns None for us-east-1 region (default/classic region) + if location is None: + location = "us-east-1" + self._regions[bucket] = location region = self._regions[bucket] duckdb_conn.execute(f"SET s3_region='{region}'") From 5a9425cfdde4ed2db840fd7837440f5a8d8dab04 Mon Sep 17 00:00:00 2001 From: tino097 Date: Mon, 15 Dec 2025 12:55:10 +0100 Subject: [PATCH 012/125] Add safe extract for zip files --- mindsdb/utilities/fs.py | 42 ++++++++++++++++++++++--------- tests/unit/api/http/files_test.py | 30 +++++++++++++++++++++- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/mindsdb/utilities/fs.py b/mindsdb/utilities/fs.py index a34ceb92f8c..ae3b82c47be 100644 --- a/mindsdb/utilities/fs.py +++ b/mindsdb/utilities/fs.py @@ -5,6 +5,9 @@ from pathlib import Path from typing import Optional, List, Tuple +import tarfile +import zipfile + import psutil from mindsdb.utilities import log @@ -112,7 +115,9 @@ def clean_unlinked_process_marks() -> List[int]: try: next(t for t in threads if t.id == thread_id) except StopIteration: - logger.warning(f"We have mark for process/thread {process_id}/{thread_id} but it does not exists") + logger.warning( + f"We have mark for process/thread {process_id}/{thread_id} but it does not exists" + ) deleted_pids.append(process_id) file.unlink() @@ -121,7 +126,9 @@ def clean_unlinked_process_marks() -> List[int]: continue except psutil.NoSuchProcess: - logger.warning(f"We have mark for process/thread {process_id}/{thread_id} but it does not exists") + logger.warning( + f"We have mark for process/thread {process_id}/{thread_id} but it does not exists" + ) deleted_pids.append(process_id) file.unlink() return deleted_pids @@ -181,15 +188,26 @@ def __is_within_directory(directory, target): return prefix == abs_directory -def safe_extract(tarfile, path=".", members=None, *, numeric_owner=False): - # for py >= 3.12 - if hasattr(tarfile, "data_filter"): - tarfile.extractall(path, members=members, numeric_owner=numeric_owner, filter="data") +def safe_extract(archivefile, path=".", members=None, *, numeric_owner=False): + if isinstance(archivefile, zipfile.ZipFile): + for member in archivefile.namelist(): + member_path = os.path.join(path, member) + if not __is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Zip File") + archivefile.extractall(path, members) return - # for py < 3.12 - for member in tarfile.getmembers(): - member_path = os.path.join(path, member.name) - if not __is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") - tarfile.extractall(path, members=members, numeric_owner=numeric_owner) + if isinstance(archivefile, tarfile.TarFile): + # for py >= 3.12 + if hasattr(archivefile, "data_filter"): + archivefile.extractall( + path, members=members, numeric_owner=numeric_owner, filter="data" + ) + return + + # for py < 3.12 + for member in archivefile.getmembers(): + member_path = os.path.join(path, member.name) + if not __is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + archivefile.extractall(path, members=members, numeric_owner=numeric_owner) diff --git a/tests/unit/api/http/files_test.py b/tests/unit/api/http/files_test.py index d65a7bda19e..49f29938ffd 100644 --- a/tests/unit/api/http/files_test.py +++ b/tests/unit/api/http/files_test.py @@ -1,6 +1,9 @@ import io import os.path +import os from http import HTTPStatus +from unittest.mock import patch +import tempfile def test_get_files_list(client): @@ -52,7 +55,10 @@ def test_delete_nonexistent_file(client): assert response.status_code == HTTPStatus.BAD_REQUEST data = response.get_json() assert "Error deleting file" in data["title"] - assert "There was an error while trying to delete file with name 'nonexistent.txt'" in data["detail"] + assert ( + "There was an error while trying to delete file with name 'nonexistent.txt'" + in data["detail"] + ) def test_put_file_invalid_url(client): @@ -129,3 +135,25 @@ def test_archive_file_with_extension_upload(client): assert response.status_code == 400 data = response.get_json() assert "File name cannot contain extension." in data["detail"] + + +def test_zipfile_traversal(client): + """Test uploading a zip archive with path traversal filenames""" + import zipfile + import io + + # Create a zip file in memory with a symlink + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zf: + zf.writestr("../../../../etc/passwd", "malicious content") + zip_buffer.seek(0) + data = {"file": (zip_buffer, "archive.zip")} + response = client.put( + "/api/files/archive", + data=data, + content_type="multipart/form-data", + follow_redirects=True, + ) + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + data = response.get_json() + assert "Attempted Path Traversal in Zip File" in data["detail"] From c7ea357fda957b9b25597dc01d15aebb2cd09be5 Mon Sep 17 00:00:00 2001 From: tino097 Date: Mon, 15 Dec 2025 12:58:36 +0100 Subject: [PATCH 013/125] Add functionallity, delete temp files correctly --- mindsdb/api/http/namespaces/file.py | 48 ++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/mindsdb/api/http/namespaces/file.py b/mindsdb/api/http/namespaces/file.py index 9cac6a826c0..c9a1045690d 100644 --- a/mindsdb/api/http/namespaces/file.py +++ b/mindsdb/api/http/namespaces/file.py @@ -105,9 +105,13 @@ def on_file(file): try: file_object.flush() except (AttributeError, ValueError, OSError): - logger.debug("Failed to flush file_object before closing.", exc_info=True) + logger.debug( + "Failed to flush file_object before closing.", exc_info=True + ) file_object.close() - Path(file_object.name).rename(Path(file_object.name).parent / data["file"]) + Path(file_object.name).rename( + Path(file_object.name).parent / data["file"] + ) file_object = None else: data = request.json @@ -131,7 +135,9 @@ def on_file(file): if source_type == "url": url_file_upload_enabled = config["url_file_upload"]["enabled"] if url_file_upload_enabled is False: - return http_error(400, "URL file upload is disabled.", "URL file upload is disabled.") + return http_error( + 400, "URL file upload is disabled.", "URL file upload is disabled." + ) if "file" in data: return http_error( @@ -196,7 +202,9 @@ def on_file(file): ) with requests.get(url, stream=True) as r: if r.status_code != 200: - return http_error(400, "Error getting file", f"Got status code: {r.status_code}") + return http_error( + 400, "Error getting file", f"Got status code: {r.status_code}" + ) file_path = os.path.join(temp_dir_path, data["file"]) with open(file_path, "wb") as f: for chunk in r.iter_content(chunk_size=8192): @@ -214,27 +222,37 @@ def on_file(file): file_path = os.path.join(temp_dir_path, data["file"]) lp = file_path.lower() if lp.endswith((".zip", ".tar.gz")): - if lp.endswith(".zip"): - with zipfile.ZipFile(file_path) as f: - f.extractall(temp_dir_path) - elif lp.endswith(".tar.gz"): - with tarfile.open(file_path) as f: - safe_extract(f, temp_dir_path) + try: + if lp.endswith(".zip"): + with zipfile.ZipFile(file_path) as f: + safe_extract(f, temp_dir_path) + elif lp.endswith(".tar.gz"): + with tarfile.open(file_path) as f: + safe_extract(f, temp_dir_path) + except Exception as e: + shutil.rmtree(temp_dir_path, ignore_errors=True) + return http_error(500, "Error", str(e)) os.remove(file_path) files = os.listdir(temp_dir_path) if len(files) != 1: - os.rmdir(temp_dir_path) - return http_error(400, "Wrong content.", "Archive must contain only one data file.") + shutil.rmtree(temp_dir_path, ignore_errors=True) + return http_error( + 400, "Wrong content.", "Archive must contain only one data file." + ) file_path = os.path.join(temp_dir_path, files[0]) mindsdb_file_name = files[0] if not os.path.isfile(file_path): - os.rmdir(temp_dir_path) - return http_error(400, "Wrong content.", "Archive must contain data file in root.") + shutil.rmtree(temp_dir_path, ignore_errors=True) + return http_error( + 400, "Wrong content.", "Archive must contain data file in root." + ) try: if not Path(mindsdb_file_name).suffix == "": return http_error(400, "Error", "File name cannot contain extension.") - ca.file_controller.save_file(mindsdb_file_name, file_path, file_name=original_file_name) + ca.file_controller.save_file( + mindsdb_file_name, file_path, file_name=original_file_name + ) except FileProcessingError as e: return http_error(400, "Error", str(e)) except Exception as e: From 1e6d1e13ef6fc3b619d1e210fc3af633f34a76b3 Mon Sep 17 00:00:00 2001 From: tino097 Date: Mon, 15 Dec 2025 13:00:18 +0100 Subject: [PATCH 014/125] Ruff format --- mindsdb/api/http/namespaces/file.py | 28 +++++++--------------------- mindsdb/utilities/fs.py | 12 +++--------- tests/unit/api/http/files_test.py | 5 +---- 3 files changed, 11 insertions(+), 34 deletions(-) diff --git a/mindsdb/api/http/namespaces/file.py b/mindsdb/api/http/namespaces/file.py index c9a1045690d..a2e418f721a 100644 --- a/mindsdb/api/http/namespaces/file.py +++ b/mindsdb/api/http/namespaces/file.py @@ -105,13 +105,9 @@ def on_file(file): try: file_object.flush() except (AttributeError, ValueError, OSError): - logger.debug( - "Failed to flush file_object before closing.", exc_info=True - ) + logger.debug("Failed to flush file_object before closing.", exc_info=True) file_object.close() - Path(file_object.name).rename( - Path(file_object.name).parent / data["file"] - ) + Path(file_object.name).rename(Path(file_object.name).parent / data["file"]) file_object = None else: data = request.json @@ -135,9 +131,7 @@ def on_file(file): if source_type == "url": url_file_upload_enabled = config["url_file_upload"]["enabled"] if url_file_upload_enabled is False: - return http_error( - 400, "URL file upload is disabled.", "URL file upload is disabled." - ) + return http_error(400, "URL file upload is disabled.", "URL file upload is disabled.") if "file" in data: return http_error( @@ -202,9 +196,7 @@ def on_file(file): ) with requests.get(url, stream=True) as r: if r.status_code != 200: - return http_error( - 400, "Error getting file", f"Got status code: {r.status_code}" - ) + return http_error(400, "Error getting file", f"Got status code: {r.status_code}") file_path = os.path.join(temp_dir_path, data["file"]) with open(file_path, "wb") as f: for chunk in r.iter_content(chunk_size=8192): @@ -236,23 +228,17 @@ def on_file(file): files = os.listdir(temp_dir_path) if len(files) != 1: shutil.rmtree(temp_dir_path, ignore_errors=True) - return http_error( - 400, "Wrong content.", "Archive must contain only one data file." - ) + return http_error(400, "Wrong content.", "Archive must contain only one data file.") file_path = os.path.join(temp_dir_path, files[0]) mindsdb_file_name = files[0] if not os.path.isfile(file_path): shutil.rmtree(temp_dir_path, ignore_errors=True) - return http_error( - 400, "Wrong content.", "Archive must contain data file in root." - ) + return http_error(400, "Wrong content.", "Archive must contain data file in root.") try: if not Path(mindsdb_file_name).suffix == "": return http_error(400, "Error", "File name cannot contain extension.") - ca.file_controller.save_file( - mindsdb_file_name, file_path, file_name=original_file_name - ) + ca.file_controller.save_file(mindsdb_file_name, file_path, file_name=original_file_name) except FileProcessingError as e: return http_error(400, "Error", str(e)) except Exception as e: diff --git a/mindsdb/utilities/fs.py b/mindsdb/utilities/fs.py index ae3b82c47be..bc729a53f91 100644 --- a/mindsdb/utilities/fs.py +++ b/mindsdb/utilities/fs.py @@ -115,9 +115,7 @@ def clean_unlinked_process_marks() -> List[int]: try: next(t for t in threads if t.id == thread_id) except StopIteration: - logger.warning( - f"We have mark for process/thread {process_id}/{thread_id} but it does not exists" - ) + logger.warning(f"We have mark for process/thread {process_id}/{thread_id} but it does not exists") deleted_pids.append(process_id) file.unlink() @@ -126,9 +124,7 @@ def clean_unlinked_process_marks() -> List[int]: continue except psutil.NoSuchProcess: - logger.warning( - f"We have mark for process/thread {process_id}/{thread_id} but it does not exists" - ) + logger.warning(f"We have mark for process/thread {process_id}/{thread_id} but it does not exists") deleted_pids.append(process_id) file.unlink() return deleted_pids @@ -200,9 +196,7 @@ def safe_extract(archivefile, path=".", members=None, *, numeric_owner=False): if isinstance(archivefile, tarfile.TarFile): # for py >= 3.12 if hasattr(archivefile, "data_filter"): - archivefile.extractall( - path, members=members, numeric_owner=numeric_owner, filter="data" - ) + archivefile.extractall(path, members=members, numeric_owner=numeric_owner, filter="data") return # for py < 3.12 diff --git a/tests/unit/api/http/files_test.py b/tests/unit/api/http/files_test.py index 49f29938ffd..52833ec540a 100644 --- a/tests/unit/api/http/files_test.py +++ b/tests/unit/api/http/files_test.py @@ -55,10 +55,7 @@ def test_delete_nonexistent_file(client): assert response.status_code == HTTPStatus.BAD_REQUEST data = response.get_json() assert "Error deleting file" in data["title"] - assert ( - "There was an error while trying to delete file with name 'nonexistent.txt'" - in data["detail"] - ) + assert "There was an error while trying to delete file with name 'nonexistent.txt'" in data["detail"] def test_put_file_invalid_url(client): From 6c6dbcc7cb1610a9f76ad604d5bb0108f91202f8 Mon Sep 17 00:00:00 2001 From: tino097 Date: Mon, 15 Dec 2025 14:29:06 +0100 Subject: [PATCH 015/125] Remove unused imports --- tests/unit/api/http/files_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/api/http/files_test.py b/tests/unit/api/http/files_test.py index 52833ec540a..36337170a59 100644 --- a/tests/unit/api/http/files_test.py +++ b/tests/unit/api/http/files_test.py @@ -2,8 +2,6 @@ import os.path import os from http import HTTPStatus -from unittest.mock import patch -import tempfile def test_get_files_list(client): From ea7724a68bd6a484fd56a2a66b67b6555fa15326 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Feb 2026 14:24:27 +0100 Subject: [PATCH 016/125] Add the leads mappings --- .../hubspot_handler/hubspot_tables.py | 251 ++++++++++++++++++ 1 file changed, 251 insertions(+) diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py index ad381432023..ef3a11466f5 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py @@ -601,6 +601,17 @@ def _extract(node: ASTNode, **kwargs): ("stage_probability", "DECIMAL", "Stage probability"), ("stage_archived", "BOOLEAN", "Stage archived"), ], + "leads": [ + ("hs_lead_name", "VARCHAR", "Lead name"), + ("hs_lead_type", "VARCHAR", "Lead type"), + ("hs_lead_label", "VARCHAR", "Lead label/status"), + ("hubspot_owner_id", "VARCHAR", "Owner ID"), + ("hs_timestamp", "TIMESTAMP", "Lead timestamp"), + ("primary_contact_id", "VARCHAR", "Primary associated contact ID"), + ("primary_company_id", "VARCHAR", "Primary associated company ID"), + ("createdate", "TIMESTAMP", "Creation date"), + ("lastmodifieddate", "TIMESTAMP", "Last modification date"), + ], } @@ -3818,3 +3829,243 @@ def delete_notes(self, note_ids: List[Text]) -> None: logger.info("Notes deleted") except Exception as e: raise Exception(f"Notes deletion failed {e}") + + +class LeadsTable(HubSpotAPIResource): + """HubSpot Leads table for prospective customer records.""" + + # Reference: https://developers.hubspot.com/docs/api-reference/crm-leads-v3/guide + SEARCHABLE_COLUMNS: Set[str] = {"hs_lead_name", "hs_lead_type", "hs_lead_label", "id"} + ASSOCIATION_COLUMNS = {"primary_contact_id", "primary_company_id"} + + def meta_get_tables(self, table_name: str) -> Dict[str, Any]: + row_count = None + try: + self.handler.connect() + row_count = self.handler._estimate_table_rows("leads") + except Exception as e: + logger.warning(f"Could not estimate HubSpot leads row count: {e}") + return { + "TABLE_NAME": "leads", + "TABLE_TYPE": "BASE TABLE", + "TABLE_DESCRIPTION": "HubSpot leads representing prospective customer records", + "ROW_COUNT": row_count, + } + + def meta_get_columns(self, table_name: str) -> List[Dict[str, Any]]: + return self.handler._get_default_meta_columns("leads") + + def list( + self, + conditions: List[FilterCondition] = None, + limit: int = None, + sort: List[SortColumn] = None, + targets: List[str] = None, + search_filters: Optional[List[Dict[str, Any]]] = None, + search_sorts: Optional[List[Dict[str, Any]]] = None, + allow_search: bool = True, + ) -> pd.DataFrame: + leads_df = pd.json_normalize( + self.get_leads( + limit=limit, + where_conditions=conditions, + properties=targets, + search_filters=search_filters, + search_sorts=search_sorts, + allow_search=allow_search, + ) + ) + if leads_df.empty: + leads_df = pd.DataFrame(columns=targets or self._get_default_lead_columns()) + return leads_df + + def add(self, lead_data: List[dict]): + self.create_leads(lead_data) + + def modify(self, conditions: List[FilterCondition], values: Dict) -> None: + normalized_conditions = _normalize_filter_conditions(conditions) + leads_df = pd.json_normalize(self.get_leads(limit=200, where_conditions=normalized_conditions)) + + if leads_df.empty: + raise ValueError("No leads retrieved from HubSpot to evaluate update conditions.") + + executor_conditions = _normalize_conditions_for_executor(normalized_conditions) + update_query_executor = UPDATEQueryExecutor(leads_df, executor_conditions) + filtered_df = update_query_executor.execute_query() + + if filtered_df.empty: + raise ValueError(f"No leads found matching WHERE conditions: {conditions}.") + + lead_ids = filtered_df["id"].astype(str).tolist() + logger.info(f"Updating {len(lead_ids)} lead(s) matching WHERE conditions") + self.update_leads(lead_ids, values) + + def remove(self, conditions: List[FilterCondition]) -> None: + normalized_conditions = _normalize_filter_conditions(conditions) + leads_df = pd.json_normalize(self.get_leads(limit=200, where_conditions=normalized_conditions)) + + if leads_df.empty: + raise ValueError("No leads retrieved from HubSpot to evaluate delete conditions.") + + executor_conditions = _normalize_conditions_for_executor(normalized_conditions) + delete_query_executor = DELETEQueryExecutor(leads_df, executor_conditions) + filtered_df = delete_query_executor.execute_query() + + if filtered_df.empty: + raise ValueError(f"No leads found matching WHERE conditions: {conditions}.") + + lead_ids = filtered_df["id"].astype(str).tolist() + logger.info(f"Deleting {len(lead_ids)} lead(s) matching WHERE conditions") + self.delete_leads(lead_ids) + + def get_columns(self) -> List[Text]: + return self._get_default_lead_columns() + + @staticmethod + def _get_default_lead_columns() -> List[str]: + return [ + "id", + "hs_lead_name", + "hs_lead_type", + "hs_lead_label", + "hubspot_owner_id", + "hs_timestamp", + "primary_contact_id", + "primary_company_id", + "createdate", + "lastmodifieddate", + ] + + def get_leads( + self, + limit: Optional[int] = None, + where_conditions: Optional[List] = None, + properties: Optional[List[str]] = None, + search_filters: Optional[List[Dict[str, Any]]] = None, + search_sorts: Optional[List[Dict[str, Any]]] = None, + allow_search: bool = True, + **kwargs, + ) -> List[Dict]: + normalized_conditions = _normalize_filter_conditions(where_conditions) + hubspot = self.handler.connect() + + requested_properties = properties or [] + default_properties = self._get_default_lead_columns() + columns = requested_properties or default_properties + association_targets, hubspot_columns = _prepare_association_request("leads", columns) + hubspot_properties = _build_hubspot_properties(hubspot_columns) + + api_kwargs = {**kwargs, "properties": hubspot_properties} + if limit is not None: + api_kwargs["limit"] = limit + else: + api_kwargs.pop("limit", None) + if association_targets: + api_kwargs["associations"] = association_targets + + if allow_search and (search_filters or search_sorts or normalized_conditions): + filters = search_filters + if filters is None and normalized_conditions: + filters = _build_hubspot_search_filters(normalized_conditions, self.SEARCHABLE_COLUMNS) + if filters is not None or search_sorts is not None: + if association_targets: + logger.debug("HubSpot search API does not include associations for leads.") + search_results = self._search_leads_by_conditions( + hubspot, + filters, + hubspot_properties, + limit, + search_sorts, + hubspot_columns, + association_targets, + ) + logger.info(f"Retrieved {len(search_results)} leads from HubSpot via search API") + return search_results + + leads = self.handler._get_objects_all("leads", **api_kwargs) + leads_dict = [] + for lead in leads: + try: + row = self._lead_to_dict(lead, hubspot_columns, association_targets) + leads_dict.append(row) + except Exception as e: + logger.warning(f"Error processing lead {getattr(lead, 'id', 'unknown')}: {str(e)}") + continue + + logger.info(f"Retrieved {len(leads_dict)} leads from HubSpot") + return leads_dict + + def _search_leads_by_conditions( + self, + hubspot: HubSpot, + filters: Optional[List[Dict[str, Any]]], + properties: List[str], + limit: Optional[int], + sorts: Optional[List[Dict[str, Any]]], + columns: List[str], + association_targets: List[str], + ) -> List[Dict[str, Any]]: + return _execute_hubspot_search( + hubspot.crm.objects.search_api, + filters or [], + properties, + limit, + lambda obj: self._lead_to_dict(obj, columns, association_targets), + sorts=sorts, + object_type="leads", + ) + + def _lead_to_dict( + self, + lead: Any, + columns: Optional[List[str]] = None, + association_targets: Optional[List[str]] = None, + ) -> Dict[str, Any]: + columns = columns or self._get_default_lead_columns() + row = self._object_to_dict(lead, columns) + if association_targets: + row = enrich_object_with_associations(lead, "leads", row) + return row + + def create_leads(self, leads_data: List[Dict[Text, Any]]) -> None: + if not leads_data: + raise ValueError("No lead data provided for creation") + + logger.info(f"Attempting to create {len(leads_data)} lead(s)") + hubspot = self.handler.connect() + leads_to_create = [HubSpotObjectInputCreate(properties=lead) for lead in leads_data] + batch_input = BatchInputSimplePublicObjectBatchInputForCreate(inputs=leads_to_create) + + try: + created_leads = hubspot.crm.objects.leads.batch_api.create( + batch_input_simple_public_object_batch_input_for_create=batch_input + ) + if not created_leads or not hasattr(created_leads, "results") or not created_leads.results: + raise Exception("Lead creation returned no results") + created_ids = [l.id for l in created_leads.results] + logger.info(f"Successfully created {len(created_ids)} lead(s) with IDs: {created_ids}") + except Exception as e: + logger.error(f"Leads creation failed: {str(e)}") + raise Exception(f"Leads creation failed {e}") + + def update_leads(self, lead_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + hubspot = self.handler.connect() + leads_to_update = [HubSpotObjectBatchInput(id=lid, properties=values_to_update) for lid in lead_ids] + batch_input = BatchInputSimplePublicObjectBatchInput(inputs=leads_to_update) + try: + updated = hubspot.crm.objects.leads.batch_api.update( + batch_input_simple_public_object_batch_input=batch_input + ) + logger.info(f"Leads with ID {[l.id for l in updated.results]} updated") + except Exception as e: + raise Exception(f"Leads update failed {e}") + + def delete_leads(self, lead_ids: List[Text]) -> None: + hubspot = self.handler.connect() + leads_to_delete = [HubSpotObjectId(id=lid) for lid in lead_ids] + batch_input = BatchInputSimplePublicObjectId(inputs=leads_to_delete) + try: + hubspot.crm.objects.leads.batch_api.archive(batch_input_simple_public_object_id=batch_input) + logger.info("Leads deleted") + except Exception as e: + raise Exception(f"Leads deletion failed {e}") From e832337e3a39a42e91b0302be124a41ad3091c87 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Feb 2026 14:25:26 +0100 Subject: [PATCH 017/125] Update the tables registration --- .../integrations/handlers/hubspot_handler/hubspot_tables.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py index ef3a11466f5..3ada1cf1541 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py @@ -4042,7 +4042,7 @@ def create_leads(self, leads_data: List[Dict[Text, Any]]) -> None: ) if not created_leads or not hasattr(created_leads, "results") or not created_leads.results: raise Exception("Lead creation returned no results") - created_ids = [l.id for l in created_leads.results] + created_ids = [lead.id for lead in created_leads.results] logger.info(f"Successfully created {len(created_ids)} lead(s) with IDs: {created_ids}") except Exception as e: logger.error(f"Leads creation failed: {str(e)}") @@ -4056,7 +4056,7 @@ def update_leads(self, lead_ids: List[Text], values_to_update: Dict[Text, Any]) updated = hubspot.crm.objects.leads.batch_api.update( batch_input_simple_public_object_batch_input=batch_input ) - logger.info(f"Leads with ID {[l.id for l in updated.results]} updated") + logger.info(f"Leads with ID {[lead.id for lead in updated.results]} updated") except Exception as e: raise Exception(f"Leads update failed {e}") From f39ef04ffebe883e51e591fc3df00fd19fdd1948 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Feb 2026 14:26:02 +0100 Subject: [PATCH 018/125] Update the associations for leads --- .../handlers/hubspot_handler/hubspot_association_utils.py | 4 ++++ .../integrations/handlers/hubspot_handler/hubspot_handler.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_association_utils.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_association_utils.py index d8b7de1de24..beb9369c515 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_association_utils.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_association_utils.py @@ -42,6 +42,10 @@ ("companies", "primary_company_id"), ("deals", "primary_deal_id"), ], + "leads": [ + ("contacts", "primary_contact_id"), + ("companies", "primary_company_id"), + ], } diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py index e70c4dd0139..d52ba89847f 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py @@ -13,6 +13,7 @@ EmailsTable, MeetingsTable, NotesTable, + LeadsTable, OwnersTable, DealStagesTable, to_hubspot_property, @@ -135,6 +136,7 @@ def __init__(self, name: str, **kwargs: Any) -> None: self._register_table("emails", EmailsTable(self)) self._register_table("meetings", MeetingsTable(self)) self._register_table("notes", NotesTable(self)) + self._register_table("leads", LeadsTable(self)) self._register_table("owners", OwnersTable(self)) self._register_table("deal_stages", DealStagesTable(self)) From b1dc3d8f8722a8e5ac4d472ccec32c5052cee0e2 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Feb 2026 16:57:23 +0100 Subject: [PATCH 019/125] Remove mention of oauth from readme --- .../integrations/handlers/hubspot_handler/README.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/README.md b/mindsdb/integrations/handlers/hubspot_handler/README.md index 67bc8be809c..0cb61d4f147 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/README.md +++ b/mindsdb/integrations/handlers/hubspot_handler/README.md @@ -12,7 +12,6 @@ HubSpot handler for MindsDB provides interfaces to connect to HubSpot via APIs a - [Installation](#installation) - [Authentication](#authentication) - [Personal Access Token Authentication](#personal-access-token-authentication) - - [OAuth Authentication](#oauth-authentication) - [Supported Tables](#supported-tables) - [Core CRM and Engagement Tables](#core-crm-and-engagement-tables) - [Metadata Tables](#metadata-tables) @@ -59,15 +58,6 @@ Recommended for server-to-server integrations and production environments. 4. Configure required scopes for the tables you plan to access 5. Copy the generated access token -### OAuth Authentication - -Recommended for applications requiring user consent and dynamic scope management. - -**Required OAuth Parameters:** -- `client_id`: Your app's client identifier -- `client_secret`: Your app's client secret (store securely) - -OAuth token exchange and refresh are handled externally. ## Supported Tables From 42ab4d0077ac8a671cf2ad11e69e27df10e11a79 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Tue, 24 Feb 2026 19:21:38 +0100 Subject: [PATCH 020/125] Add entry for leads in README --- mindsdb/integrations/handlers/hubspot_handler/README.md | 1 + .../integrations/handlers/hubspot_handler/hubspot_handler.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/README.md b/mindsdb/integrations/handlers/hubspot_handler/README.md index 0cb61d4f147..032024df64e 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/README.md +++ b/mindsdb/integrations/handlers/hubspot_handler/README.md @@ -76,6 +76,7 @@ These tables support `SELECT`, `INSERT`, `UPDATE`, and `DELETE` operations. | `emails` | Email log records | https://developers.hubspot.com/docs/api-reference/crm-emails-v3/guide | | `meetings` | Meeting records | https://developers.hubspot.com/docs/api-reference/crm-meetings-v3/guide | | `notes` | Timeline notes | https://developers.hubspot.com/docs/api-reference/crm-notes-v3/guide | +| `leads` | Lead records including lead status and source | https://developers.hubspot.com/docs/api-reference/crm-leads-v3/guide | ### Metadata Tables diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py index d52ba89847f..72af5db2a40 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py @@ -603,13 +603,14 @@ def _get_table_description(self, table_name: str) -> str: "ticket_deals": "HubSpot ticket to deal associations", "owners": "HubSpot owners with names and emails", "deal_stages": "HubSpot deal pipeline stages with labels", + "leads": "HubSpot leads data including lead status, source and other lead properties", } return descriptions.get(table_name, f"HubSpot {table_name} data") def _estimate_table_rows(self, table_name: str) -> Optional[int]: """Get actual count of rows in a table using HubSpot Search API.""" try: - if table_name in ["companies", "contacts", "deals", "tickets"]: + if table_name in ["companies", "contacts", "deals", "tickets", "leads"]: result = getattr(self.connection.crm, table_name).search_api.do_search( public_object_search_request={"limit": 1} ) From 8c8af0af3bf1d60ccda46f21e57584dd65502b27 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Thu, 26 Feb 2026 18:29:48 +0300 Subject: [PATCH 021/125] Del `--load-tokenizer` option (#12237) --- docker/mindsdb.Dockerfile | 4 ++-- mindsdb/utilities/config.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docker/mindsdb.Dockerfile b/docker/mindsdb.Dockerfile index 2712c854daf..1da9a8250ee 100644 --- a/docker/mindsdb.Dockerfile +++ b/docker/mindsdb.Dockerfile @@ -93,8 +93,8 @@ ENV PATH=/venv/bin:$PATH EXPOSE 47334/tcp EXPOSE 47335/tcp -# Pre-load tokenizer from Huggingface, and UI -RUN python -m mindsdb --config=/root/mindsdb_config.json --load-tokenizer --update-gui +# Pre-load web GUI +RUN python -m mindsdb --config=/root/mindsdb_config.json --update-gui # Same as extras image, but with dev dependencies installed. # This image is used in our docker-compose diff --git a/mindsdb/utilities/config.py b/mindsdb/utilities/config.py index 82a857b00c7..b660b31cb79 100644 --- a/mindsdb/utilities/config.py +++ b/mindsdb/utilities/config.py @@ -615,7 +615,6 @@ def parse_cmd_args(self) -> None: parser.add_argument("--project-name", type=str, default=None, help="MindsDB project name") parser.add_argument("--update-gui", action="store_true", default=False, help="Update GUI and exit") - parser.add_argument("--load-tokenizer", action="store_true", default=False, help="Preload tokenizer and exit") self._cmd_args = parser.parse_args() From e6c54989e9dcc6b807a9e2c6622e31f23334253e Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Thu, 26 Feb 2026 19:12:41 +0300 Subject: [PATCH 022/125] Lock the PID file when creating or deleting it (#12232) --- mindsdb/utilities/fs.py | 180 +++++++++++++++++++++++++++------------- 1 file changed, 124 insertions(+), 56 deletions(-) diff --git a/mindsdb/utilities/fs.py b/mindsdb/utilities/fs.py index 2462960acca..99dd49acaef 100644 --- a/mindsdb/utilities/fs.py +++ b/mindsdb/utilities/fs.py @@ -1,4 +1,5 @@ import os +import sys import json import time import tempfile @@ -127,6 +128,70 @@ def clean_unlinked_process_marks() -> list[int]: return deleted_pids +class PidFileLock: + """Cross-platform exclusive file lock context manager. + Uses fcntl.flock on Unix and msvcrt.locking on Windows. + + Attributes: + _lock_file_path (Path): path to lock file + _blocking (bool): if True, waits until the lock becomes available, otherwise raises OSError immediately if lock is held + _fh (int): lock file descriptor + """ + + def __init__(self, lock_file_path: Path, blocking: bool = True): + self._lock_file_path = lock_file_path + self._blocking = blocking + self._fh = None + + def __enter__(self): + self._lock_file_path.parent.mkdir(parents=True, exist_ok=True) + self._fh = open(self._lock_file_path, "a+") + try: + if sys.platform == "win32": + import msvcrt + + # NOTE if file is locked, LK_LOCK will raise OSError after 10 seconds, LK_NBLCK immediately + mode = msvcrt.LK_LOCK if self._blocking else msvcrt.LK_NBLCK + self._fh.seek(0) + msvcrt.locking(self._fh.fileno(), mode, 1) + else: + import fcntl + + flags = fcntl.LOCK_EX + if not self._blocking: + flags |= fcntl.LOCK_NB + fcntl.flock(self._fh.fileno(), flags) + except (OSError, IOError): + self._fh.close() + self._fh = None + logger.error(f"Failed to acquire lock on {self._lock_file_path}") + raise + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._fh is None: + return False + try: + if sys.platform == "win32": + import msvcrt + + self._fh.seek(0) + msvcrt.locking(self._fh.fileno(), msvcrt.LK_UNLCK, 1) + else: + import fcntl + + fcntl.flock(self._fh.fileno(), fcntl.LOCK_UN) + except (OSError, IOError): + pass + finally: + try: + self._fh.close() + except (OSError, IOError): + pass + self._fh = None + return False + + def create_pid_file(config): """ Create mindsdb process pid file. Check if previous process exists and is running @@ -140,48 +205,49 @@ def create_pid_file(config): p = get_tmp_dir() p.mkdir(parents=True, exist_ok=True) pid_file = p.joinpath("pid") - if pid_file.exists(): - # if process exists raise exception - pid_file_data_str = pid_file.read_text().strip() - pid = None - try: - pid_file_data = json.loads(pid_file_data_str) - if isinstance(pid_file_data, dict): - pid = pid_file_data.get("pid") - else: - pid = pid_file_data - except json.JSONDecodeError: - # is it just pid number (old approach)? - try: - pid = int(pid_file_data_str) - except Exception: - pass - logger.warning(f"Found existing PID file {pid_file} but it is not a valid JSON, removing") + lock_file = p.joinpath("pid.lock") - if pid is not None: + with PidFileLock(lock_file): + if pid_file.exists(): + pid_file_data_str = pid_file.read_text().strip() + pid = None try: - psutil.Process(int(pid)) - raise Exception(f"Found PID file with existing process: {pid} {pid_file}") - except (psutil.Error, ValueError): - pass - logger.warning(f"Found existing PID file {pid_file}({pid}), removing") - - pid_file.unlink(missing_ok=True) - - pid_file_content = config["pid_file_content"] - if pid_file_content is None or len(pid_file_content) == 0: - pid_file_data_str = str(os.getpid()) - else: - pid_file_data = {"pid": os.getpid()} - for key, value in pid_file_content.items(): - value_path = value.split(".") - value_obj = config - for path_part in value_path: - value_obj = value_obj.get(path_part) if value_obj else None - pid_file_data[key] = value_obj + pid_file_data = json.loads(pid_file_data_str) + if isinstance(pid_file_data, dict): + pid = pid_file_data.get("pid") + else: + pid = pid_file_data + except json.JSONDecodeError: + try: + pid = int(pid_file_data_str) + except Exception: + pass + logger.warning(f"Found existing PID file {pid_file} but it is not a valid JSON, removing") + + if pid is not None: + try: + psutil.Process(int(pid)) + raise Exception(f"Found PID file with existing process: {pid} {pid_file}") + except (psutil.Error, ValueError): + pass + logger.warning(f"Found existing PID file {pid_file}({pid}), removing") + + pid_file.unlink(missing_ok=True) + + pid_file_content = config["pid_file_content"] + if pid_file_content is None or len(pid_file_content) == 0: + pid_file_data_str = str(os.getpid()) + else: + pid_file_data = {"pid": os.getpid()} + for key, value in pid_file_content.items(): + value_path = value.split(".") + value_obj = config + for path_part in value_path: + value_obj = value_obj.get(path_part) if value_obj else None + pid_file_data[key] = value_obj - pid_file_data_str = json.dumps(pid_file_data) - pid_file.write_text(pid_file_data_str) + pid_file_data_str = json.dumps(pid_file_data) + pid_file.write_text(pid_file_data_str) def delete_pid_file(): @@ -193,27 +259,29 @@ def delete_pid_file(): return pid_file = get_tmp_dir().joinpath("pid") + lock_file = get_tmp_dir().joinpath("pid.lock") - if not pid_file.exists(): - return + with PidFileLock(lock_file): + if not pid_file.exists(): + return - pid_file_data_str = pid_file.read_text().strip() - pid = None - try: - pid_file_data = json.loads(pid_file_data_str) - if isinstance(pid_file_data, dict): - pid = pid_file_data.get("pid") - else: - # It's a simple number (old format or pid_file_content=None format) - pid = pid_file_data - except json.JSONDecodeError: - logger.warning(f"Found existing PID file {pid_file} but it is not a valid JSON") + pid_file_data_str = pid_file.read_text().strip() + pid = None + try: + pid_file_data = json.loads(pid_file_data_str) + if isinstance(pid_file_data, dict): + pid = pid_file_data.get("pid") + else: + # It's a simple number (old format or pid_file_content=None format) + pid = pid_file_data + except json.JSONDecodeError: + logger.warning(f"Found existing PID file {pid_file} but it is not a valid JSON") - if pid is not None and str(pid) != str(os.getpid()): - logger.warning(f"Process id in PID file ({pid_file}) doesn't match mindsdb pid") - return + if pid is not None and str(pid) != str(os.getpid()): + logger.warning(f"Process id in PID file ({pid_file}) doesn't match mindsdb pid") + return - pid_file.unlink(missing_ok=True) + pid_file.unlink(missing_ok=True) def __is_within_directory(directory, target): From 2498f855fe6efb438e847f8a916ded42f55660dc Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Fri, 27 Feb 2026 14:37:22 +0100 Subject: [PATCH 023/125] Remove deprecated Text and use builtin str --- .../hubspot_handler/hubspot_tables.py | 86 +++++++++---------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py index 3ada1cf1541..f2224b42ac7 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_tables.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Text, Any, Optional, Tuple, Set, Iterable +from typing import List, Dict, Any, Optional, Tuple, Set, Iterable import calendar import inspect import re @@ -1338,7 +1338,7 @@ def modify(self, conditions: List[FilterCondition], values: Dict) -> None: def remove(self, conditions: List[FilterCondition]) -> None: raise NotImplementedError("Deleting owners via DELETE is not supported.") - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_owner_columns() @staticmethod @@ -1404,7 +1404,7 @@ def modify(self, conditions: List[FilterCondition], values: Dict) -> None: def remove(self, conditions: List[FilterCondition]) -> None: raise NotImplementedError("Deleting deal stages via DELETE is not supported.") - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_deal_stage_columns() @staticmethod @@ -1543,7 +1543,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(company_ids)} compan(ies) matching WHERE conditions") self.delete_companies(company_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_company_columns() @staticmethod @@ -1652,7 +1652,7 @@ def _company_to_dict(self, company: Any, columns: Optional[List[str]] = None) -> columns = columns or self._get_default_company_columns() return self._object_to_dict(company, columns) - def create_companies(self, companies_data: List[Dict[Text, Any]]) -> None: + def create_companies(self, companies_data: List[Dict[str, Any]]) -> None: if not companies_data: raise ValueError("No company data provided for creation") @@ -1673,7 +1673,7 @@ def create_companies(self, companies_data: List[Dict[Text, Any]]) -> None: logger.error(f"Companies creation failed: {str(e)}") raise Exception(f"Companies creation failed {e}") - def update_companies(self, company_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_companies(self, company_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() companies_to_update = [HubSpotObjectBatchInput(id=cid, properties=values_to_update) for cid in company_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=companies_to_update) @@ -1683,7 +1683,7 @@ def update_companies(self, company_ids: List[Text], values_to_update: Dict[Text, except Exception as e: raise Exception(f"Companies update failed {e}") - def delete_companies(self, company_ids: List[Text]) -> None: + def delete_companies(self, company_ids: List[str]) -> None: hubspot = self.handler.connect() companies_to_delete = [HubSpotObjectId(id=cid) for cid in company_ids] batch_input = BatchInputSimplePublicObjectId(inputs=companies_to_delete) @@ -1811,7 +1811,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(contact_ids)} contact(s) matching WHERE conditions") self.delete_contacts(contact_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_contact_columns() @staticmethod @@ -1948,7 +1948,7 @@ def _contact_to_dict( **{col: None for col in assoc_columns}, } - def create_contacts(self, contacts_data: List[Dict[Text, Any]]) -> None: + def create_contacts(self, contacts_data: List[Dict[str, Any]]) -> None: if not contacts_data: raise ValueError("No contact data provided for creation") @@ -1969,7 +1969,7 @@ def create_contacts(self, contacts_data: List[Dict[Text, Any]]) -> None: logger.error(f"Contacts creation failed: {str(e)}") raise Exception(f"Contacts creation failed {e}") - def update_contacts(self, contact_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_contacts(self, contact_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() contacts_to_update = [HubSpotObjectBatchInput(id=cid, properties=values_to_update) for cid in contact_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=contacts_to_update) @@ -1979,7 +1979,7 @@ def update_contacts(self, contact_ids: List[Text], values_to_update: Dict[Text, except Exception as e: raise Exception(f"Contacts update failed {e}") - def delete_contacts(self, contact_ids: List[Text]) -> None: + def delete_contacts(self, contact_ids: List[str]) -> None: hubspot = self.handler.connect() contacts_to_delete = [HubSpotObjectId(id=cid) for cid in contact_ids] batch_input = BatchInputSimplePublicObjectId(inputs=contacts_to_delete) @@ -2102,7 +2102,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(deal_ids)} deal(s) matching WHERE conditions") self.delete_deals(deal_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_deal_columns() @staticmethod @@ -2334,7 +2334,7 @@ def _deal_to_dict( row = enrich_object_with_associations(deal, "deals", row) return row - def create_deals(self, deals_data: List[Dict[Text, Any]]) -> None: + def create_deals(self, deals_data: List[Dict[str, Any]]) -> None: if not deals_data: raise ValueError("No deal data provided for creation") @@ -2355,7 +2355,7 @@ def create_deals(self, deals_data: List[Dict[Text, Any]]) -> None: logger.error(f"Deals creation failed: {str(e)}") raise Exception(f"Deals creation failed {e}") - def update_deals(self, deal_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_deals(self, deal_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() deals_to_update = [HubSpotObjectBatchInput(id=did, properties=values_to_update) for did in deal_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=deals_to_update) @@ -2365,7 +2365,7 @@ def update_deals(self, deal_ids: List[Text], values_to_update: Dict[Text, Any]) except Exception as e: raise Exception(f"Deals update failed {e}") - def delete_deals(self, deal_ids: List[Text]) -> None: + def delete_deals(self, deal_ids: List[str]) -> None: hubspot = self.handler.connect() deals_to_delete = [HubSpotObjectId(id=did) for did in deal_ids] batch_input = BatchInputSimplePublicObjectId(inputs=deals_to_delete) @@ -2463,7 +2463,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(ticket_ids)} ticket(s) matching WHERE conditions") self.delete_tickets(ticket_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_ticket_columns() @staticmethod @@ -2574,7 +2574,7 @@ def _ticket_to_dict( row = enrich_object_with_associations(ticket, "tickets", row) return row - def create_tickets(self, tickets_data: List[Dict[Text, Any]]) -> None: + def create_tickets(self, tickets_data: List[Dict[str, Any]]) -> None: if not tickets_data: raise ValueError("No ticket data provided for creation") @@ -2595,7 +2595,7 @@ def create_tickets(self, tickets_data: List[Dict[Text, Any]]) -> None: logger.error(f"Tickets creation failed: {str(e)}") raise Exception(f"Tickets creation failed {e}") - def update_tickets(self, ticket_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_tickets(self, ticket_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() tickets_to_update = [HubSpotObjectBatchInput(id=tid, properties=values_to_update) for tid in ticket_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=tickets_to_update) @@ -2605,7 +2605,7 @@ def update_tickets(self, ticket_ids: List[Text], values_to_update: Dict[Text, An except Exception as e: raise Exception(f"Tickets update failed {e}") - def delete_tickets(self, ticket_ids: List[Text]) -> None: + def delete_tickets(self, ticket_ids: List[str]) -> None: hubspot = self.handler.connect() tickets_to_delete = [HubSpotObjectId(id=tid) for tid in ticket_ids] batch_input = BatchInputSimplePublicObjectId(inputs=tickets_to_delete) @@ -2703,7 +2703,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(task_ids)} task(s) matching WHERE conditions") self.delete_tasks(task_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_task_columns() @staticmethod @@ -2816,7 +2816,7 @@ def _task_to_dict( row = enrich_object_with_associations(task, "tasks", row) return row - def create_tasks(self, tasks_data: List[Dict[Text, Any]]) -> None: + def create_tasks(self, tasks_data: List[Dict[str, Any]]) -> None: if not tasks_data: raise ValueError("No task data provided for creation") @@ -2837,7 +2837,7 @@ def create_tasks(self, tasks_data: List[Dict[Text, Any]]) -> None: logger.error(f"Tasks creation failed: {str(e)}") raise Exception(f"Tasks creation failed {e}") - def update_tasks(self, task_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_tasks(self, task_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() tasks_to_update = [HubSpotObjectBatchInput(id=tid, properties=values_to_update) for tid in task_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=tasks_to_update) @@ -2849,7 +2849,7 @@ def update_tasks(self, task_ids: List[Text], values_to_update: Dict[Text, Any]) except Exception as e: raise Exception(f"Tasks update failed {e}") - def delete_tasks(self, task_ids: List[Text]) -> None: + def delete_tasks(self, task_ids: List[str]) -> None: hubspot = self.handler.connect() tasks_to_delete = [HubSpotObjectId(id=tid) for tid in task_ids] batch_input = BatchInputSimplePublicObjectId(inputs=tasks_to_delete) @@ -2947,7 +2947,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(call_ids)} call(s) matching WHERE conditions") self.delete_calls(call_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_call_columns() @staticmethod @@ -3060,7 +3060,7 @@ def _call_to_dict( row = enrich_object_with_associations(call, "calls", row) return row - def create_calls(self, calls_data: List[Dict[Text, Any]]) -> None: + def create_calls(self, calls_data: List[Dict[str, Any]]) -> None: if not calls_data: raise ValueError("No call data provided for creation") @@ -3081,7 +3081,7 @@ def create_calls(self, calls_data: List[Dict[Text, Any]]) -> None: logger.error(f"Calls creation failed: {str(e)}") raise Exception(f"Calls creation failed {e}") - def update_calls(self, call_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_calls(self, call_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() calls_to_update = [HubSpotObjectBatchInput(id=cid, properties=values_to_update) for cid in call_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=calls_to_update) @@ -3093,7 +3093,7 @@ def update_calls(self, call_ids: List[Text], values_to_update: Dict[Text, Any]) except Exception as e: raise Exception(f"Calls update failed {e}") - def delete_calls(self, call_ids: List[Text]) -> None: + def delete_calls(self, call_ids: List[str]) -> None: hubspot = self.handler.connect() calls_to_delete = [HubSpotObjectId(id=cid) for cid in call_ids] batch_input = BatchInputSimplePublicObjectId(inputs=calls_to_delete) @@ -3191,7 +3191,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(email_ids)} email(s) matching WHERE conditions") self.delete_emails(email_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_email_columns() @staticmethod @@ -3304,7 +3304,7 @@ def _email_to_dict( row = enrich_object_with_associations(email, "emails", row) return row - def create_emails(self, emails_data: List[Dict[Text, Any]]) -> None: + def create_emails(self, emails_data: List[Dict[str, Any]]) -> None: if not emails_data: raise ValueError("No email data provided for creation") @@ -3325,7 +3325,7 @@ def create_emails(self, emails_data: List[Dict[Text, Any]]) -> None: logger.error(f"Emails creation failed: {str(e)}") raise Exception(f"Emails creation failed {e}") - def update_emails(self, email_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_emails(self, email_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() emails_to_update = [HubSpotObjectBatchInput(id=eid, properties=values_to_update) for eid in email_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=emails_to_update) @@ -3337,7 +3337,7 @@ def update_emails(self, email_ids: List[Text], values_to_update: Dict[Text, Any] except Exception as e: raise Exception(f"Emails update failed {e}") - def delete_emails(self, email_ids: List[Text]) -> None: + def delete_emails(self, email_ids: List[str]) -> None: hubspot = self.handler.connect() emails_to_delete = [HubSpotObjectId(id=eid) for eid in email_ids] batch_input = BatchInputSimplePublicObjectId(inputs=emails_to_delete) @@ -3435,7 +3435,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(meeting_ids)} meeting(s) matching WHERE conditions") self.delete_meetings(meeting_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_meeting_columns() @staticmethod @@ -3548,7 +3548,7 @@ def _meeting_to_dict( row = enrich_object_with_associations(meeting, "meetings", row) return row - def create_meetings(self, meetings_data: List[Dict[Text, Any]]) -> None: + def create_meetings(self, meetings_data: List[Dict[str, Any]]) -> None: if not meetings_data: raise ValueError("No meeting data provided for creation") @@ -3569,7 +3569,7 @@ def create_meetings(self, meetings_data: List[Dict[Text, Any]]) -> None: logger.error(f"Meetings creation failed: {str(e)}") raise Exception(f"Meetings creation failed {e}") - def update_meetings(self, meeting_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_meetings(self, meeting_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() meetings_to_update = [HubSpotObjectBatchInput(id=mid, properties=values_to_update) for mid in meeting_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=meetings_to_update) @@ -3581,7 +3581,7 @@ def update_meetings(self, meeting_ids: List[Text], values_to_update: Dict[Text, except Exception as e: raise Exception(f"Meetings update failed {e}") - def delete_meetings(self, meeting_ids: List[Text]) -> None: + def delete_meetings(self, meeting_ids: List[str]) -> None: hubspot = self.handler.connect() meetings_to_delete = [HubSpotObjectId(id=mid) for mid in meeting_ids] batch_input = BatchInputSimplePublicObjectId(inputs=meetings_to_delete) @@ -3679,7 +3679,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(note_ids)} note(s) matching WHERE conditions") self.delete_notes(note_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_note_columns() @staticmethod @@ -3787,7 +3787,7 @@ def _note_to_dict( row = enrich_object_with_associations(note, "notes", row) return row - def create_notes(self, notes_data: List[Dict[Text, Any]]) -> None: + def create_notes(self, notes_data: List[Dict[str, Any]]) -> None: if not notes_data: raise ValueError("No note data provided for creation") @@ -3808,7 +3808,7 @@ def create_notes(self, notes_data: List[Dict[Text, Any]]) -> None: logger.error(f"Notes creation failed: {str(e)}") raise Exception(f"Notes creation failed {e}") - def update_notes(self, note_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_notes(self, note_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() notes_to_update = [HubSpotObjectBatchInput(id=nid, properties=values_to_update) for nid in note_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=notes_to_update) @@ -3820,7 +3820,7 @@ def update_notes(self, note_ids: List[Text], values_to_update: Dict[Text, Any]) except Exception as e: raise Exception(f"Notes update failed {e}") - def delete_notes(self, note_ids: List[Text]) -> None: + def delete_notes(self, note_ids: List[str]) -> None: hubspot = self.handler.connect() notes_to_delete = [HubSpotObjectId(id=nid) for nid in note_ids] batch_input = BatchInputSimplePublicObjectId(inputs=notes_to_delete) @@ -3918,7 +3918,7 @@ def remove(self, conditions: List[FilterCondition]) -> None: logger.info(f"Deleting {len(lead_ids)} lead(s) matching WHERE conditions") self.delete_leads(lead_ids) - def get_columns(self) -> List[Text]: + def get_columns(self) -> List[str]: return self._get_default_lead_columns() @staticmethod @@ -4027,7 +4027,7 @@ def _lead_to_dict( row = enrich_object_with_associations(lead, "leads", row) return row - def create_leads(self, leads_data: List[Dict[Text, Any]]) -> None: + def create_leads(self, leads_data: List[Dict[str, Any]]) -> None: if not leads_data: raise ValueError("No lead data provided for creation") @@ -4048,7 +4048,7 @@ def create_leads(self, leads_data: List[Dict[Text, Any]]) -> None: logger.error(f"Leads creation failed: {str(e)}") raise Exception(f"Leads creation failed {e}") - def update_leads(self, lead_ids: List[Text], values_to_update: Dict[Text, Any]) -> None: + def update_leads(self, lead_ids: List[str], values_to_update: Dict[str, Any]) -> None: hubspot = self.handler.connect() leads_to_update = [HubSpotObjectBatchInput(id=lid, properties=values_to_update) for lid in lead_ids] batch_input = BatchInputSimplePublicObjectBatchInput(inputs=leads_to_update) @@ -4060,7 +4060,7 @@ def update_leads(self, lead_ids: List[Text], values_to_update: Dict[Text, Any]) except Exception as e: raise Exception(f"Leads update failed {e}") - def delete_leads(self, lead_ids: List[Text]) -> None: + def delete_leads(self, lead_ids: List[str]) -> None: hubspot = self.handler.connect() leads_to_delete = [HubSpotObjectId(id=lid) for lid in lead_ids] batch_input = BatchInputSimplePublicObjectId(inputs=leads_to_delete) From 5e5b706cfedeb6cb1d83047335f7ffc595a31eb7 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 27 Feb 2026 19:42:26 +0300 Subject: [PATCH 024/125] fix error "cannot access local variable 'column_quoted' where it is not associated with a value" --- mindsdb/api/executor/sql_query/steps/subselect_step.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mindsdb/api/executor/sql_query/steps/subselect_step.py b/mindsdb/api/executor/sql_query/steps/subselect_step.py index 8e4e5725cf9..8ebf7c0ba39 100644 --- a/mindsdb/api/executor/sql_query/steps/subselect_step.py +++ b/mindsdb/api/executor/sql_query/steps/subselect_step.py @@ -183,6 +183,8 @@ def check_fields(node, is_target=None, **kwargs): "version for the right syntax to use near '$$' at line 1" ) + key, column_quoted = (), False + match node.parts, node.is_quoted: case [column_name], [column_quoted]: if column_name in aliases: From 35a32ceb995374f8b7d3e61a3c15a851d193a132 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 2 Mar 2026 16:17:38 +0300 Subject: [PATCH 025/125] Added timeout (5 seconds), if it passed: don't try to fetch sample of data and fetch all records --- .../executor/sql_query/steps/fetch_dataframe_partition.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py b/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py index 30de48b9442..819f7f1d563 100644 --- a/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +++ b/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py @@ -1,3 +1,4 @@ +import time import copy import pandas as pd from typing import List @@ -105,6 +106,7 @@ def repeat_till_reach_limit(self, step, limit): query, context_callback = query_context_controller.handle_db_context_vars(query, dn, self.session) try_num = 1 + started_at = time.time() while True: self.substeps = copy.deepcopy(step.steps) query2 = copy.deepcopy(query) @@ -126,7 +128,8 @@ def repeat_till_reach_limit(self, step, limit): result = result[:limit] break - if try_num > 3: + # break if process is too long or to many tries + if try_num > 3 or started_at - time.time() > 5: # the last try without the limit first_table_limit = None continue From a37b41122f49b9abd59eac2fb91fb503178aa7df Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 2 Mar 2026 16:24:56 +0300 Subject: [PATCH 026/125] fix selecting from view --- mindsdb/interfaces/database/projects.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/mindsdb/interfaces/database/projects.py b/mindsdb/interfaces/database/projects.py index d51811d3f04..bd7b443f833 100644 --- a/mindsdb/interfaces/database/projects.py +++ b/mindsdb/interfaces/database/projects.py @@ -8,7 +8,7 @@ import numpy as np from mindsdb_sql_parser.ast.base import ASTNode -from mindsdb_sql_parser.ast import Select, Star, Constant, Identifier, BinaryOperation +from mindsdb_sql_parser.ast import Select, Star, Constant, Identifier, BinaryOperation, Join from mindsdb_sql_parser import parse_sql from mindsdb.interfaces.storage import db @@ -185,28 +185,34 @@ def get_conditions_to_move(node): # column is not in black list AND (query has star(*) OR column in white list) has_star = False - white_list, black_list = [], [] + white_list, black_list = {}, [] for target in view_query.targets: if isinstance(target, Star): has_star = True if isinstance(target, Identifier): name = target.parts[-1].lower() if target.alias is None or target.alias.parts[-1].lower() == name: - white_list.append(name) + white_list[name] = target elif target.alias is not None: black_list.append(target.alias.parts[-1].lower()) + is_join = isinstance(view_query.from_table, Join) view_where = view_query.where for condition in conditions: arg1, arg2 = condition.args if isinstance(arg1, Identifier): name = arg1.parts[-1].lower() - if name in black_list or not (has_star or name in white_list): + if name in white_list: + arg1 = white_list[name] + # don't move condition for join with Star + elif name in black_list or not (has_star and not is_join): continue if isinstance(arg2, Identifier): name = arg2.parts[-1].lower() - if name in black_list or not (has_star or name in white_list): + if name in white_list: + arg2 = white_list[name] + if name in black_list or not (has_star and not is_join): continue # condition can be moved into view @@ -224,7 +230,13 @@ def get_conditions_to_move(node): # combine outer query with view's query view_query.parentheses = True + + # keep alias (column of the query might relate to it) + alias = query.from_table.alias if query.from_table.alias is not None else query.from_table + view_query.alias = Identifier(parts=[alias.parts[-1]]) + query.from_table = view_query + return query def query_view(self, query: Select, session) -> pd.DataFrame: From 7f0254208fcd68f99ce3297a63bb4c8520de662a Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 2 Mar 2026 16:25:11 +0300 Subject: [PATCH 027/125] fix join with cte --- mindsdb/api/executor/planner/plan_join.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/mindsdb/api/executor/planner/plan_join.py b/mindsdb/api/executor/planner/plan_join.py index 603528ac6f8..a7eb26800ef 100644 --- a/mindsdb/api/executor/planner/plan_join.py +++ b/mindsdb/api/executor/planner/plan_join.py @@ -358,6 +358,7 @@ def _check_identifiers(node, is_table, **kwargs): else: self.has_ambiguous_columns = True + query.cte = None # already used before query_traversal(query, _check_identifiers) self.check_query_conditions(query) @@ -371,6 +372,8 @@ def _check_identifiers(node, is_table, **kwargs): # create plan # TODO add optimization: one integration without predictor + planned_steps_before_join = len(self.planner.plan.steps) + self.step_stack = [] for item in join_sequence: if isinstance(item, TableInfo): @@ -400,20 +403,25 @@ def _check_identifiers(node, is_table, **kwargs): query_in.where = query.where if self.query_context["optimize_inner_join"]: - self.planner.plan.steps = self.optimize_inner_join(self.planner.plan.steps) + self.planner.plan.steps = self.optimize_inner_join(self.planner.plan.steps, planned_steps_before_join) self.close_partition() return self.planner.plan.steps[-1] - def optimize_inner_join(self, steps_in): + def optimize_inner_join(self, steps_in, min_step_num): steps_out = [] partition_step = None partition_used = False - for step in steps_in: + for i, step in enumerate(steps_in): if partition_step is None: - if isinstance(step, FetchDataframeStep) and not partition_used and step.query.limit is not None: + if ( + i >= min_step_num + and isinstance(step, FetchDataframeStep) + and not partition_used + and step.query.limit is not None + ): limit = step.query.limit.value step.query.limit = None partition_used = True From c00a2eeb8b9cae2a6e10ee1de968e7f463a7ca3c Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 2 Mar 2026 16:25:17 +0300 Subject: [PATCH 028/125] tests --- tests/unit/executor/test_base_queires.py | 21 +++++++++ tests/unit/planner/test_join_tables.py | 59 ++++++++++++++++++------ 2 files changed, 65 insertions(+), 15 deletions(-) diff --git a/tests/unit/executor/test_base_queires.py b/tests/unit/executor/test_base_queires.py index 5fbece5c4d3..468312d40a5 100644 --- a/tests/unit/executor/test_base_queires.py +++ b/tests/unit/executor/test_base_queires.py @@ -899,6 +899,27 @@ def test_subselect_1row_aggregate(self, data_handler): assert len(ret) == 1 assert ret["result"][0] == 1 + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_cte_join(self, data_handler): + self.set_handler(data_handler, name="pg", tables={"stores": get_stores_df()}) + self.save_file("regions", get_regions_df()) + + ret = self.run_sql(""" + WITH regions AS ( + SELECT DISTINCT id, name FROM files.regions + ), + stores AS ( + SELECT * FROM pg.stores + LIMIT 10 + ) + SELECT format, region_id FROM pg.stores s + JOIN regions r on r.id = s.region_id + WHERE s.format IN (SELECT format FROM stores WHERE format='a') + LIMIT 100; + """) + assert len(ret) > 1 + assert ret["format"][0] == "a" + class TestSet(BaseExecutorTest): @pytest.mark.parametrize("var", ["var", "@@var", "@@session.var", "session var"]) diff --git a/tests/unit/planner/test_join_tables.py b/tests/unit/planner/test_join_tables.py index 24cef73b8fa..7bd8a463d7a 100644 --- a/tests/unit/planner/test_join_tables.py +++ b/tests/unit/planner/test_join_tables.py @@ -11,6 +11,7 @@ Star, BinaryOperation, Function, + Parameter, ) from mindsdb_sql_parser.utils import JoinType @@ -319,43 +320,71 @@ def test_join_tables_plan_limit_offset(self): def test_join_tables_plan_order_by(self): query = parse_sql(""" + WITH tab2 AS ( + SELECT * FROM int2.tab2 limit 100 + ), + categories as ( + SELECT * FROM int3.cats + ) SELECT tab1.column1, tab2.column1, tab2.column2 - FROM int.tab1 INNER - JOIN int2.tab2 ON tab1.column1 > tab2.column1 + FROM int.tab1 tab1 + INNER JOIN tab2 ON tab1.column1 > tab2.column1 + WHERE tab2.category_id = (SELECT id FROM categories WHERE name='book') ORDER BY tab1.column1 LIMIT 10 """) subquery = copy.deepcopy(query) + subquery.cte = None subquery.from_table = None subquery.offset = None + subquery.where.args[1] = Parameter(Result(2)) - plan = plan_query(query, integrations=["int", "int2"]) + plan = plan_query(query, integrations=["int", "int2", "int3"], default_namespace="mindsdb") expected_plan = QueryPlan( integrations=["int"], steps=[ - FetchDataframeStepPartition( + FetchDataframeStep( step_num=0, + integration="int2", + query=parse_sql("select * from tab2 limit 100"), + ), + FetchDataframeStep( + step_num=1, + integration="int3", + query=parse_sql("select * from cats"), + ), + SubSelectStep( + step_num=2, + query=Select( + targets=[Identifier("id")], + where=BinaryOperation(op="=", args=[Identifier("name"), Constant("book")]), + ), + dataframe=Result(1), + table_name="categories", + ), + FetchDataframeStepPartition( + step_num=3, integration="int", - query=parse_sql("select column1 AS column1 from tab1 order by column1"), + query=parse_sql("select column1 AS column1 from tab1 AS tab1 order by column1"), condition={"limit": 10}, steps=[ - FetchDataframeStep( - step_num=1, - integration="int2", + SubSelectStep( + step_num=4, + dataframe=Result(0), query=Select( targets=[ - Identifier("column1", alias=Identifier("column1")), - Identifier("column2", alias=Identifier("column2")), + Star(), ], # Column pruning - from_table=Identifier("tab2"), + where=BinaryOperation(op="=", args=[Identifier("category_id"), Parameter(Result(2))]), ), + table_name="tab2", ), JoinStep( - step_num=2, - left=Result(0), - right=Result(1), + step_num=5, + left=Result(3), + right=Result(4), query=Join( left=Identifier("tab1"), right=Identifier("tab2"), @@ -367,7 +396,7 @@ def test_join_tables_plan_order_by(self): ), ], ), - QueryStep(subquery, from_table=Result(0), strict_where=False), + QueryStep(subquery, from_table=Result(3), strict_where=False), ], ) From 5c1c235474f330f4ef4bd4d5f151a9a1ca04da21 Mon Sep 17 00:00:00 2001 From: andrew Date: Mon, 2 Mar 2026 18:01:43 +0300 Subject: [PATCH 029/125] fix tests --- .../interfaces/query_context/last_query.py | 139 +++++++++++------- 1 file changed, 82 insertions(+), 57 deletions(-) diff --git a/mindsdb/interfaces/query_context/last_query.py b/mindsdb/interfaces/query_context/last_query.py index 1df233d4405..0043c55aa1a 100644 --- a/mindsdb/interfaces/query_context/last_query.py +++ b/mindsdb/interfaces/query_context/last_query.py @@ -3,7 +3,17 @@ from collections import defaultdict from mindsdb_sql_parser.ast import ( - Identifier, Select, BinaryOperation, Last, Constant, Star, ASTNode, NullConstant, OrderBy, Function, TypeCast + Identifier, + Select, + BinaryOperation, + Last, + Constant, + Star, + ASTNode, + NullConstant, + OrderBy, + Function, + TypeCast, ) from mindsdb.integrations.utilities.query_traversal import query_traversal @@ -34,21 +44,21 @@ def __init__(self, query: ASTNode): def _find_last_columns(self, query: ASTNode) -> Union[dict, None]: """ - This function: - - Searches LAST column in the input query - - Replaces it with constants and memorises link to these constants - - Link to constants will be used to inject values to query instead of LAST - - Provide checks: - - if it is possible to find the table for column - - if column in select target - - Generates and returns last_column variable which is dict - last_columns[table_name] = { - 'table': , - 'column': , - 'links': [, ... ], - 'target_idx': , - 'gen_init_query': if true: to generate query to initial values for LAST - } + This function: + - Searches LAST column in the input query + - Replaces it with constants and memorises link to these constants + - Link to constants will be used to inject values to query instead of LAST + - Provide checks: + - if it is possible to find the table for column + - if column in select target + - Generates and returns last_column variable which is dict + last_columns[table_name] = { + 'table':
, + 'column': , + 'links': [, ... ], + 'target_idx': , + 'gen_init_query': if true: to generate query to initial values for LAST + } """ # index last variables in query @@ -76,7 +86,6 @@ def replace_last_in_tree(node: ASTNode, injected: Constant): return found def index_query(node, is_table, parent_query, **kwargs): - parent_query_id = id(parent_query) last = None if is_table and isinstance(node, Identifier): @@ -105,13 +114,15 @@ def index_query(node, is_table, parent_query, **kwargs): if last is not None: # memorize - conditions.append({ - 'query_id': parent_query_id, - 'condition': node, - 'last': last, - 'column': col, - 'gen_init_query': gen_init_query # generate query to fetch initial last values from table - }) + conditions.append( + { + "query_id": parent_query_id, + "condition": node, + "last": last, + "column": col, + "gen_init_query": gen_init_query, # generate query to fetch initial last values from table + } + ) # find lasts query_traversal(query, index_query) @@ -122,7 +133,7 @@ def index_query(node, is_table, parent_query, **kwargs): self.query_orig = copy.deepcopy(query) for info in conditions: - self.last_idx[info['query_id']].append(info) + self.last_idx[info["query_id"]].append(info) # index query targets query_id = id(query) @@ -152,21 +163,20 @@ def index_query(node, is_table, parent_query, **kwargs): last_columns = {} for parent_query_id, items in self.last_idx.items(): for info in items: - col = info['column'] - last = info['last'] + col = info["column"] + last = info["last"] tables = tables_idx[parent_query_id] uniq_tables = len(set([id(v) for v in tables.values()])) if len(col.parts) > 1: - table = tables.get(col.parts[-2]) if table is None: - raise ValueError('cant find table') + raise ValueError("cant find table") elif uniq_tables == 1: table = list(tables.values())[0] else: # or just skip it? - raise ValueError('cant find table') + raise ValueError("cant find table") col_name = col.parts[-1] @@ -179,29 +189,46 @@ def index_query(node, is_table, parent_query, **kwargs): # will try to get by name ... else: - raise ValueError('Last value should be in query target') + raise ValueError("Last value should be in query target") last_columns[table_name] = { - 'table': table, - 'column': col_name, - 'links': [last], - 'target_idx': target_idx, - 'gen_init_query': info['gen_init_query'] + "table": table, + "column": col_name, + "links": [last], + "target_idx": target_idx, + "gen_init_query": info["gen_init_query"], } - elif last_columns[table_name]['column'] == col_name: - last_columns[table_name]['column'].append(last) + elif last_columns[table_name]["column"] == col_name: + last_columns[table_name]["column"].append(last) else: - raise ValueError('possible to use only one column') + raise ValueError("possible to use only one column") return last_columns def to_string(self) -> str: """ - String representation of the query - Used to identify query in query_context table + String representation of the query + Used to identify query in query_context table """ - return self.query_orig.to_string() + query = self.query_orig + if isinstance(query.from_table, Select) and query.targets == [Star()]: + # simplify nested query + if ( + query.group_by is None + and query.order_by is None + and query.having is None + and query.distinct is False + and query.where is None + and query.limit is None + and query.offset is None + and query.cte is None + ): + query = query.from_table + query.parentheses = False + query.alias = None + + return query.to_string() def get_last_columns(self) -> List[dict]: """ @@ -210,11 +237,11 @@ def get_last_columns(self) -> List[dict]: """ return [ { - 'table': info['table'], - 'table_name': table_name, - 'column_name': info['column'], - 'target_idx': info['target_idx'], - 'gen_init_query': info['gen_init_query'], + "table": info["table"], + "table_name": table_name, + "column_name": info["column"], + "target_idx": info["target_idx"], + "gen_init_query": info["gen_init_query"], } for table_name, info in self.last_tables.items() ] @@ -224,8 +251,8 @@ def apply_values(self, values: dict) -> ASTNode: Fills query with new values and return it """ for table_name, info in self.last_tables.items(): - value = values.get(table_name, {}).get(info['column']) - for last in info['links']: + value = values.get(table_name, {}).get(info["column"]) + for last in info["links"]: last.value = value return self.query @@ -239,9 +266,9 @@ def get_init_queries(self): # replace values for items in self.last_idx.values(): for info in items: - node = info['condition'] + node = info["condition"] back_up_values.append([node.op, node.args[1]]) - node.op = 'is not' + node.op = "is not" node.args[1] = NullConstant() query2 = copy.deepcopy(self.query) @@ -249,18 +276,16 @@ def get_init_queries(self): # return values for items in self.last_idx.values(): for info in items: - node = info['condition'] + node = info["condition"] op, arg1 = back_up_values.pop(0) node.op = op node.args[1] = arg1 for info in self.get_last_columns(): - if not info['gen_init_query']: + if not info["gen_init_query"]: continue - col = Identifier(info['column_name']) + col = Identifier(info["column_name"]) query2.targets = [col] - query2.order_by = [ - OrderBy(col, direction='DESC') - ] + query2.order_by = [OrderBy(col, direction="DESC")] query2.limit = Constant(1) yield query2, info From 90b9007a87abcc6b21f7d3897c2b3cd7ceb2ca83 Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 4 Mar 2026 16:03:10 +0300 Subject: [PATCH 030/125] prevent duplicated columns in view --- mindsdb/interfaces/database/projects.py | 12 +++++++++++- tests/unit/executor/test_base_queires.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/mindsdb/interfaces/database/projects.py b/mindsdb/interfaces/database/projects.py index bd7b443f833..6e063d603ad 100644 --- a/mindsdb/interfaces/database/projects.py +++ b/mindsdb/interfaces/database/projects.py @@ -125,7 +125,17 @@ def create_view(self, name: str, query: str, session): query_context_controller.set_context(query_context_controller.IGNORE_CONTEXT) try: - SQLQuery(ast_query, session=session, database=self.name) + query = SQLQuery(ast_query, session=session, database=self.name) + columns = [col.name for col in query.fetched_data.columns] + seen, duplicates = set(), set() + for col in columns: + if col in seen: + duplicates.add(col) + else: + seen.add(col) + if len(duplicates) > 0: + raise ValueError(f"Found duplicated columns in the view: {', '.join(duplicates)}") + finally: query_context_controller.release_context(query_context_controller.IGNORE_CONTEXT) diff --git a/tests/unit/executor/test_base_queires.py b/tests/unit/executor/test_base_queires.py index 468312d40a5..0a0e3c2ab79 100644 --- a/tests/unit/executor/test_base_queires.py +++ b/tests/unit/executor/test_base_queires.py @@ -920,6 +920,19 @@ def test_cte_join(self, data_handler): assert len(ret) > 1 assert ret["format"][0] == "a" + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_view_duplicated_cols(self, data_handler): + self.set_handler(data_handler, name="pg", tables={"stores": get_stores_df(), "regions": get_regions_df()}) + + with pytest.raises(Exception): + # `id` exists in both tables, should raise an exception + self.run_sql(""" + create view v1 ( + select * from pg.stores s + join pg.regions r on r.id = s.region_id + ) + """) + class TestSet(BaseExecutorTest): @pytest.mark.parametrize("var", ["var", "@@var", "@@session.var", "session var"]) From 7f2cc884cea32801be43d41905fd56fb12654dd4 Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 4 Mar 2026 17:47:55 +0300 Subject: [PATCH 031/125] fix --- mindsdb/interfaces/database/projects.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindsdb/interfaces/database/projects.py b/mindsdb/interfaces/database/projects.py index 6e063d603ad..a8f96092f24 100644 --- a/mindsdb/interfaces/database/projects.py +++ b/mindsdb/interfaces/database/projects.py @@ -125,8 +125,8 @@ def create_view(self, name: str, query: str, session): query_context_controller.set_context(query_context_controller.IGNORE_CONTEXT) try: - query = SQLQuery(ast_query, session=session, database=self.name) - columns = [col.name for col in query.fetched_data.columns] + resp = SQLQuery(ast_query, session=session, database=self.name) + columns = [col.name for col in resp.fetched_data.columns] seen, duplicates = set(), set() for col in columns: if col in seen: From bb9dddf45d184c16bec148303f06dd070bcb9190 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Thu, 5 Mar 2026 18:25:34 +0300 Subject: [PATCH 032/125] Ability for data handlers to stream data (#12018) --- Makefile | 2 +- docs/contribute/app-handlers.mdx | 4 +- docs/contribute/data-handlers.mdx | 132 +++- docs/mindsdb-handlers.mdx | 2 +- docs/rest/sql.mdx | 67 +- mindsdb/api/executor/command_executor.py | 3 +- mindsdb/api/executor/data_types/sql_answer.py | 129 ++++ .../api/executor/datahub/classes/response.py | 14 - .../executor/datahub/datanodes/datanode.py | 5 +- .../datanodes/information_schema_datanode.py | 16 +- .../datahub/datanodes/integration_datanode.py | 76 +- .../datahub/datanodes/project_datanode.py | 27 +- mindsdb/api/executor/sql_query/result_set.py | 87 ++- mindsdb/api/executor/sql_query/sql_query.py | 12 +- .../sql_query/steps/apply_predictor_step.py | 3 +- .../sql_query/steps/fetch_dataframe.py | 20 +- .../executor/sql_query/steps/insert_step.py | 3 +- .../executor/sql_query/steps/prepare_steps.py | 28 +- .../sql_query/steps/subselect_step.py | 3 +- mindsdb/api/http/namespaces/sql.py | 106 +-- .../mysql_proxy/executor/mysql_executor.py | 2 +- mindsdb/api/mysql/mysql_proxy/mysql_proxy.py | 45 +- .../api/mysql/mysql_proxy/utilities/dump.py | 3 +- .../dummy_data_handler/dummy_data_handler.py | 2 +- .../handlers/mysql_handler/mysql_handler.py | 192 +++-- .../handlers/oracle_handler/oracle_handler.py | 245 ++++--- .../pgvector_handler/pgvector_handler.py | 4 +- .../postgres_handler/postgres_handler.py | 221 ++++-- .../snowflake_handler/snowflake_handler.py | 348 +++++---- mindsdb/integrations/libs/base.py | 118 ++- .../integrations/libs/keyword_search_base.py | 2 +- mindsdb/integrations/libs/ml_exec_base.py | 2 +- mindsdb/integrations/libs/response.py | 479 ++++++++++++- .../libs/vectordatabase_handler.py | 42 +- .../utilities/rag/retrievers/sql_retriever.py | 10 +- mindsdb/interfaces/database/log.py | 16 +- mindsdb/interfaces/jobs/jobs_controller.py | 5 +- .../query_context/context_controller.py | 13 +- mindsdb/utilities/config.py | 3 + mindsdb/utilities/types/__init__.py | 0 mindsdb/utilities/types/column.py | 30 + tests/unit/api/http/test_sql_query.py | 145 ++++ tests/unit/executor/test_api_handler.py | 5 +- tests/unit/executor/test_knowledge_base.py | 20 +- tests/unit/executor_test_base.py | 12 +- tests/unit/handlers/base_handler_test.py | 18 +- tests/unit/handlers/test_bigquery.py | 22 +- tests/unit/handlers/test_clickhouse.py | 18 +- tests/unit/handlers/test_confluence.py | 20 +- tests/unit/handlers/test_databricks.py | 55 +- tests/unit/handlers/test_dynamodb.py | 90 +-- tests/unit/handlers/test_mariadb.py | 26 +- tests/unit/handlers/test_mongodb.py | 48 +- tests/unit/handlers/test_mssql.py | 68 +- tests/unit/handlers/test_mysql.py | 156 ++-- tests/unit/handlers/test_oracle.py | 104 ++- tests/unit/handlers/test_postgres.py | 212 +++--- tests/unit/handlers/test_redshift.py | 79 +-- tests/unit/handlers/test_s3.py | 114 ++- tests/unit/handlers/test_salesforce.py | 12 +- tests/unit/handlers/test_slack.py | 6 +- tests/unit/handlers/test_snowflake.py | 83 +-- tests/unit/handlers/test_timescaledb.py | 21 +- tests/unit/integrations/__init__.py | 0 tests/unit/integrations/libs/__init__.py | 0 tests/unit/integrations/libs/test_response.py | 671 ++++++++++++++++++ 66 files changed, 3156 insertions(+), 1370 deletions(-) create mode 100644 mindsdb/api/executor/data_types/sql_answer.py delete mode 100644 mindsdb/api/executor/datahub/classes/response.py create mode 100644 mindsdb/utilities/types/__init__.py create mode 100644 mindsdb/utilities/types/column.py create mode 100644 tests/unit/api/http/test_sql_query.py create mode 100644 tests/unit/integrations/__init__.py create mode 100644 tests/unit/integrations/libs/__init__.py create mode 100644 tests/unit/integrations/libs/test_response.py diff --git a/Makefile b/Makefile index 05cea89b906..26d90872dc5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -PYTEST_ARGS = -v -rs --disable-warnings -n auto --dist loadfile +PYTEST_ARGS = -v -xrs --disable-warnings -n 1 --dist loadfile PYTEST_ARGS_DEBUG = --runslow -vs -rs DSI_PYTEST_ARGS = --run-dsi-tests DSI_REPORT_ARGS = --json-report --json-report-file=reports/report.json diff --git a/docs/contribute/app-handlers.mdx b/docs/contribute/app-handlers.mdx index 0c0a24639e1..040d3e2bc37 100644 --- a/docs/contribute/app-handlers.mdx +++ b/docs/contribute/app-handlers.mdx @@ -118,13 +118,13 @@ Here is a step-by-step guide: The `native_query()` method runs commands of the native API syntax. ```py - def native_query(self, query: Any) -> HandlerResponse: + def native_query(self, query: Any) -> TableResponse | OkResponse | ErrorResponse: """Receive raw query and act upon it somehow. Args: query (Any): query in native format (str for sql databases, api's json etc) Returns: - HandlerResponse + TableResponse | OkResponse | ErrorResponse """ ``` diff --git a/docs/contribute/data-handlers.mdx b/docs/contribute/data-handlers.mdx index ca796627a7c..cb13aa0621d 100644 --- a/docs/contribute/data-handlers.mdx +++ b/docs/contribute/data-handlers.mdx @@ -45,7 +45,15 @@ Authors can opt for adding private methods, new files and folders, or any combin Under the `mindsdb.integrations.libs.utils` library, contributors can find various methods that may be useful while implementing new handlers. - Also, there are wrapper classes for the `DatabaseHandler` instances called [HandlerResponse](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py#L7) and [HandlerStatusResponse](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py#L32). You should use them to ensure proper output formatting. + For response formatting, use the following classes from `mindsdb.integrations.libs.response`: + - [TableResponse](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py) - for queries returning data (SELECT, SHOW, etc.) + - [OkResponse](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py) - for successful operations without data (CREATE, DROP, INSERT, etc.) + - [ErrorResponse](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py) - for error cases + - [HandlerStatusResponse](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py) - for connection status checks + + + The legacy `HandlerResponse` class is deprecated. Use `TableResponse`, `OkResponse`, or `ErrorResponse` instead. + ### Implementation @@ -124,13 +132,13 @@ Here is a step-by-step guide: The `native_query()` method runs commands of the native database language. ```py - def native_query(self, query: Any) -> HandlerResponse: + def native_query(self, query: Any) -> TableResponse | OkResponse | ErrorResponse: """Receive raw query and act upon it somehow. Args: query (Any): query in native format (str for sql databases, etc) Returns: - HandlerResponse + TableResponse | OkResponse | ErrorResponse """ ``` @@ -139,13 +147,13 @@ Here is a step-by-step guide: The query method runs parsed SQL commands. ```py - def query(self, query: ASTNode) -> HandlerResponse: + def query(self, query: ASTNode) -> TableResponse | OkResponse | ErrorResponse: """Receive query as AST (abstract syntax tree) and act upon it somehow. Args: query (ASTNode): sql query represented as AST. May be any kind of query: SELECT, INSERT, DELETE, etc Returns: - HandlerResponse + TableResponse | OkResponse | ErrorResponse """ ``` @@ -154,11 +162,11 @@ Here is a step-by-step guide: The `get_tables()` method lists all the available tables. ```py - def get_tables(self) -> HandlerResponse: + def get_tables(self) -> TableResponse | ErrorResponse: """ Return list of entities Return a list of entities that will be accessible as tables. Returns: - HandlerResponse: should have the same columns as information_schema.tables + TableResponse | ErrorResponse: should have the same columns as information_schema.tables (https://dev.mysql.com/doc/refman/8.0/en/information-schema-tables-table.html) Column 'TABLE_NAME' is mandatory, other is optional. """ @@ -169,12 +177,12 @@ Here is a step-by-step guide: The `get_columns()` method lists all columns of a specified table. ```py - def get_columns(self, table_name: str) -> HandlerResponse: + def get_columns(self, table_name: str) -> TableResponse | ErrorResponse: """ Returns a list of entity columns Args: table_name (str): name of one of tables returned by self.get_tables() Returns: - HandlerResponse: should have the same columns as information_schema.columns + TableResponse | ErrorResponse: data should have the same columns as information_schema.columns (https://dev.mysql.com/doc/refman/8.0/en/information-schema-columns-table.html) Column 'COLUMN_NAME' is mandatory, other is optional. Highly recommended to define also 'DATA_TYPE': it should be one of @@ -182,6 +190,112 @@ Here is a step-by-step guide: """ ``` +### Response Classes + +The data-returning methods (`native_query()`, `query()`, `get_tables()`, `get_columns()`) should return one of the following response classes from `mindsdb.integrations.libs.response`: + +| Response Class | Use Case | Key Attributes | +|---------------|----------|----------------| +| `TableResponse` | Queries that return data (SELECT, SHOW, etc.) | `data`, `data_generator`, `columns`, `affected_rows` | +| `OkResponse` | Successful operations without data (CREATE, DROP, INSERT, UPDATE, DELETE) | `affected_rows` | +| `ErrorResponse` | Error cases | `error_code`, `error_message`, `is_expected_error` | + +#### TableResponse + +`TableResponse` is used when returning data from queries. It supports two modes of data delivery: + +1. **Immediate data**: Pass all data at once via the `data` parameter (pandas DataFrame) +2. **Streaming data**: Pass a generator via the `data_generator` parameter for lazy loading + +```py +from mindsdb.integrations.libs.response import TableResponse, OkResponse, ErrorResponse + +# Immediate data response +def native_query(self, query: str) -> TableResponse: + result = self.execute_query(query) + df = pd.DataFrame(result) + return TableResponse(data=df) + +# Streaming data response (for large datasets) +def native_query(self, query: str) -> TableResponse: + def data_generator(): + cursor = self.connection.cursor() + cursor.execute(query) + while batch := cursor.fetchmany(size=1000): + yield pd.DataFrame(batch) + + return TableResponse(data_generator=data_generator()) +``` + +#### OkResponse + +`OkResponse` is used for operations that don't return data: + +```py +def native_query(self, query: str) -> OkResponse: + cursor = self.connection.cursor() + cursor.execute(query) + self.connection.commit() + return OkResponse(affected_rows=cursor.rowcount) +``` + +#### ErrorResponse + +`ErrorResponse` is used to report errors: + +```py +def native_query(self, query: str) -> ErrorResponse: + try: + # ... execute query + except DatabaseError as e: + return ErrorResponse( + error_code=e.code, + error_message=str(e), + is_expected_error=True # Set to True for user errors (syntax, permissions, etc.) + ) +``` + +### Streaming Support + +For handlers that deal with large datasets, implementing streaming support is recommended. This allows data to be returned in chunks rather than loading everything into memory at once. + +To enable streaming: + +1. Set the `stream_response` class attribute to `True`: + + ```py + class MyDatabaseHandler(DatabaseHandler): + name = "mydatabase" + stream_response = True # Indicates that handler can return data as a generator + ``` + +2. Implement `native_query()` to return a `TableResponse` with a `data_generator`: + + ```py + def native_query(self, query: str, stream: bool = True) -> TableResponse | OkResponse | ErrorResponse: + if stream: + return self._execute_streaming(query) + else: + return self._execute_immediate(query) + + def _execute_streaming(self, query: str) -> TableResponse: + """Execute query and return results as a stream.""" + cursor = self.connection.cursor(name="server_side_cursor") + cursor.execute(query) + + columns = [Column(name=col.name, type=col.type) for col in cursor.description] + + def generate_data(): + while batch := cursor.fetchmany(size=1000): + yield pd.DataFrame(batch, columns=[c.name for c in columns]) + + return TableResponse(columns=columns, data_generator=generate_data()) + ``` + + +For a complete example of streaming implementation, see the [PostgreSQL handler](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py). + + ### Exporting the `connection_args` Dictionary The `connection_args` dictionary contains all of the arguments used to establish the connection along with their descriptions, types, labels, and whether they are required or not. diff --git a/docs/mindsdb-handlers.mdx b/docs/mindsdb-handlers.mdx index c69e09cee4c..0d9a1aaac36 100644 --- a/docs/mindsdb-handlers.mdx +++ b/docs/mindsdb-handlers.mdx @@ -76,7 +76,7 @@ Whenever you want to parse a string that contains SQL, we strongly recommend usi ### Formatting Output -In the case of data handlers, when it comes to building the response of the public methods, the output should be wrapped by the [mindsdb.integrations.libs.response.HandlerResponse](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py#L7) or [mindsdb.integrations.libs.response.HandlerStatusResponse](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py#L32) class. These classes are used by the MindsDB executioner to orchestrate and coordinate multiple handler instances in parallel. +In the case of data handlers, the data-returning methods (`native_query()`, `query()`, `get_tables()`, `get_columns()`) should return one of the response classes from [mindsdb.integrations.libs.response](https://github.com/mindsdb/mindsdb/blob/main/mindsdb/integrations/libs/response.py) And in the case of ML handlers, output wrapping is automatically done by an intermediate wrapper, the `BaseMLEngineExec` class, so the contributor wouldn't need to worry about it. diff --git a/docs/rest/sql.mdx b/docs/rest/sql.mdx index 0a8d5bd0f5b..4b8604931d0 100644 --- a/docs/rest/sql.mdx +++ b/docs/rest/sql.mdx @@ -20,6 +20,29 @@ String that contains the SQL query that needs to be executed. + + +Format of the response. Available options: +- `null` (default) - returns all data in a single JSON response +- `"sse"` - returns data as Server-Sent Events stream +- `"jsonlines"` - returns data as JSON Lines stream (one JSON object per line) + +Use `"sse"` or `"jsonlines"` for streaming large result sets to avoid loading all data into memory at once. + + + + + +Optional context object, e.g., `{"db": "mindsdb"}` to specify the database. + + + + + +Optional parameters for parameterized queries, e.g., `{"name": "value"}`. + + + ### Response @@ -55,9 +78,32 @@ curl --request POST \ { "query": "SELECT * FROM example_db.demo_data.home_rentals LIMIT 10;" } +' +``` +```shell Shell (Streaming with SSE) +curl --request POST \ + --url https://cloud.mindsdb.com/api/sql/query \ + --header 'Content-Type: application/json' \ + --data ' +{ + "query": "SELECT * FROM example_db.demo_data.home_rentals;", + "response_format": "sse" +} +' ``` +```shell Shell (Streaming with JSON Lines) +curl --request POST \ + --url https://cloud.mindsdb.com/api/sql/query \ + --header 'Content-Type: application/json' \ + --data ' +{ + "query": "SELECT * FROM example_db.demo_data.home_rentals;", + "response_format": "jsonlines" +} +' +``` ```python Python import requests @@ -70,8 +116,8 @@ resp = requests.post(url, json={'query': -```json Response - { +```json Response (Default) +{ "column_names": [ "sqft", "rental_price" @@ -90,7 +136,22 @@ resp = requests.post(url, json={'query': ] ], "type": "table" - } +} +``` + +```text Response (SSE format) +data: {"type": "table", "column_names": ["sqft", "rental_price"], "context": {"db": "mindsdb"}} + +data: [[917, 3901], [194, 2042]] + +data: [[543, 1871], [289, 1563]] + +``` + +```text Response (JSON Lines format) +{"type": "table", "column_names": ["sqft", "rental_price"], "context": {"db": "mindsdb"}} +[[917, 3901], [194, 2042]] +[[543, 1871], [289, 1563]] ``` diff --git a/mindsdb/api/executor/command_executor.py b/mindsdb/api/executor/command_executor.py index 3cc2e5f50b5..deacf21c0cf 100644 --- a/mindsdb/api/executor/command_executor.py +++ b/mindsdb/api/executor/command_executor.py @@ -75,7 +75,8 @@ import mindsdb.utilities.profiler as profiler -from mindsdb.api.executor.sql_query.result_set import Column, ResultSet +from mindsdb.api.executor.sql_query.result_set import ResultSet +from mindsdb.utilities.types.column import Column from mindsdb.api.executor.sql_query import SQLQuery from mindsdb.api.executor.data_types.answer import ExecuteAnswer from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import ( diff --git a/mindsdb/api/executor/data_types/sql_answer.py b/mindsdb/api/executor/data_types/sql_answer.py new file mode 100644 index 00000000000..0a8b6087dbf --- /dev/null +++ b/mindsdb/api/executor/data_types/sql_answer.py @@ -0,0 +1,129 @@ +from typing import Generator +from dataclasses import dataclass + +import orjson +import numpy as np +import pandas as pd + +from mindsdb.utilities.json_encoder import CustomJSONEncoder +from mindsdb.api.executor.sql_query.result_set import ResultSet +from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE +from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE + + +@dataclass +class SQLAnswer: + """Container for SQL query execution results and metadata. + + Attributes: + resp_type: Type of response (OK, ERROR, TABLE, COLUMNS_TABLE). + result_set: Query result data as a ResultSet object. + status: Status code for the response. + state_track: List of state tracking information. + error_code: Error code if query execution failed. + error_message: Human-readable error message if query failed. + affected_rows: Number of rows affected by the query (for DML operations). + mysql_types: List of MySQL data types for result columns. + """ + + resp_type: RESPONSE_TYPE = RESPONSE_TYPE.OK + result_set: ResultSet | None = None + status: int | None = None + state_track: list[list] | None = None + error_code: int | None = None + error_message: str | None = None + affected_rows: int | None = None + mysql_types: list[MYSQL_DATA_TYPE] | None = None + + @property + def type(self) -> RESPONSE_TYPE: + """Get the response type. + + Returns: + RESPONSE_TYPE: The type of this SQL response. + """ + return self.resp_type + + def stream_http_response_sse(self, context: dict | None) -> Generator[str, None, None]: + """Stream response in Server-Sent Events (SSE) format. + + Args: + context: Optional context information. + + Yields: + str: SSE-formatted data lines (prefixed with "data: "). + """ + for piece in self.stream_http_response_jsonlines(context=context): + yield f"data: {piece}\n" + + def stream_http_response_jsonlines(self, context: dict | None) -> Generator[str, None, None]: + """Stream response as newline-delimited JSON (JSONL). + + Args: + context: Optional context information. + + Yields: + str: JSON-encoded lines terminated with newline characters. + """ + _default_json = CustomJSONEncoder().default + + if self.resp_type in (RESPONSE_TYPE.OK, RESPONSE_TYPE.ERROR): + response = self.dump_http_response(context=context) + yield orjson.dumps(response).decode() + "\n" + return + + yield ( + orjson.dumps( + { + "type": RESPONSE_TYPE.TABLE, + "column_names": [column.alias or column.name for column in self.result_set.columns], + } + ).decode() + + "\n" + ) + + for el in self.result_set.stream_data(): + el.replace([np.nan, pd.NA, pd.NaT], None, inplace=True) + yield ( + orjson.dumps( + el.to_dict("split")["data"], + default=_default_json, + option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME, + ).decode() + + "\n" + ) + + def dump_http_response(self, context: dict | None = None) -> dict: + """Serialize the complete response as a single dictionary. + + Args: + context: Optional context information. + + Returns: + dict: Serialized response. + """ + if context is None: + context = {} + if self.resp_type == RESPONSE_TYPE.OK: + return { + "type": self.resp_type, + "affected_rows": self.affected_rows, + "context": context, + } + elif self.resp_type in (RESPONSE_TYPE.TABLE, RESPONSE_TYPE.COLUMNS_TABLE): + data = self.result_set.to_lists(json_types=True) + return { + "type": RESPONSE_TYPE.TABLE, + "data": data, + "column_names": [column.alias or column.name for column in self.result_set.columns], + "context": context, + } + elif self.resp_type == RESPONSE_TYPE.ERROR: + return { + "type": RESPONSE_TYPE.ERROR, + "error_code": self.error_code or 0, + "error_message": self.error_message, + "context": context, + } + else: + raise ValueError(f"Unsupported response type for dump HTTP response: {self.resp_type}") diff --git a/mindsdb/api/executor/datahub/classes/response.py b/mindsdb/api/executor/datahub/classes/response.py deleted file mode 100644 index cd0e990ed71..00000000000 --- a/mindsdb/api/executor/datahub/classes/response.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass, field -from typing import List, Dict - -import pandas as pd - -from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE - - -@dataclass -class DataHubResponse: - data_frame: pd.DataFrame = field(default_factory=pd.DataFrame) - columns: List[Dict] = field(default_factory=list) - affected_rows: int | None = None - mysql_types: list[MYSQL_DATA_TYPE] | None = None diff --git a/mindsdb/api/executor/datahub/datanodes/datanode.py b/mindsdb/api/executor/datahub/datanodes/datanode.py index 256760fc959..8be9e355949 100644 --- a/mindsdb/api/executor/datahub/datanodes/datanode.py +++ b/mindsdb/api/executor/datahub/datanodes/datanode.py @@ -1,10 +1,11 @@ from pandas import DataFrame -from mindsdb.api.executor.datahub.classes.response import DataHubResponse +from mindsdb.integrations.libs.response import DataHandlerResponse class DataNode: type = "meta" + has_support_stream = False def __init__(self): pass @@ -21,5 +22,5 @@ def get_table_columns_df(self, table_name: str, schema_name: str | None = None) def get_table_columns_names(self, table_name: str, schema_name: str | None = None) -> list[str]: pass - def query(self, query=None, session=None) -> DataHubResponse: + def query(self, query=None, session=None) -> DataHandlerResponse: pass diff --git a/mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py b/mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py index 4eabef3d7d7..ac309f72e6d 100644 --- a/mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py +++ b/mindsdb/api/executor/datahub/datanodes/information_schema_datanode.py @@ -6,12 +6,13 @@ from mindsdb.api.executor.datahub.datanodes.datanode import DataNode from mindsdb.api.executor.datahub.datanodes.integration_datanode import IntegrationDataNode from mindsdb.api.executor.datahub.datanodes.project_datanode import ProjectDataNode -from mindsdb.api.executor import exceptions as exc +from mindsdb.api.executor.datahub.classes.tables_row import TablesRow from mindsdb.api.executor.utilities.sql import query_df from mindsdb.api.executor.utilities.sql import get_query_tables +from mindsdb.api.executor import exceptions as exc from mindsdb.interfaces.database.projects import ProjectController -from mindsdb.api.executor.datahub.classes.response import DataHubResponse -from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES +from mindsdb.integrations.libs.response import TableResponse, INF_SCHEMA_COLUMNS_NAMES +from mindsdb.utilities.types.column import Column from mindsdb.utilities import log from .system_tables import ( @@ -47,8 +48,6 @@ TriggersTable, ) -from mindsdb.api.executor.datahub.classes.tables_row import TablesRow - logger = log.getLogger(__name__) @@ -206,7 +205,7 @@ def get_tables(self): def get_tree_tables(self): return {name: table for name, table in self.tables.items() if table.visible} - def query(self, query: ASTNode, session=None) -> DataHubResponse: + def query(self, query: ASTNode, session=None) -> TableResponse: query_tables = [x[1] for x in get_query_tables(query)] if len(query_tables) != 1: @@ -225,9 +224,8 @@ def query(self, query: ASTNode, session=None) -> DataHubResponse: dataframe = self._get_empty_table(tbl) data = query_df(dataframe, query, session=self.session) - columns_info = [{"name": k, "type": v} for k, v in data.dtypes.items()] - - return DataHubResponse(data_frame=data, columns=columns_info, affected_rows=0) + columns = [Column(name=k, dtype=v) for k, v in data.dtypes.items()] + return TableResponse(data=data, columns=columns, affected_rows=0) def _get_empty_table(self, table): columns = table.columns diff --git a/mindsdb/api/executor/datahub/datanodes/integration_datanode.py b/mindsdb/api/executor/datahub/datanodes/integration_datanode.py index 228bd29468c..2175db5d1a3 100644 --- a/mindsdb/api/executor/datahub/datanodes/integration_datanode.py +++ b/mindsdb/api/executor/datahub/datanodes/integration_datanode.py @@ -2,27 +2,24 @@ import inspect import functools from dataclasses import astuple -from typing import Iterable, List -import numpy as np import pandas as pd from sqlalchemy.types import Integer, Float from mindsdb_sql_parser.ast.base import ASTNode from mindsdb_sql_parser.ast import Insert, Identifier, CreateTable, TableColumn, DropTables -from mindsdb.api.executor.datahub.classes.response import DataHubResponse from mindsdb.api.executor.datahub.datanodes.datanode import DataNode +from mindsdb.api.executor.datahub.datanodes.system_tables import infer_mysql_type from mindsdb.api.executor.datahub.classes.tables_row import TablesRow from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE from mindsdb.api.executor.sql_query.result_set import ResultSet -from mindsdb.integrations.libs.response import HandlerResponse, INF_SCHEMA_COLUMNS_NAMES +from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES, DataHandlerResponse, ErrorResponse, OkResponse from mindsdb.integrations.utilities.utils import get_class_name from mindsdb.metrics import metrics from mindsdb.utilities import log from mindsdb.utilities.profiler import profiler from mindsdb.utilities.exception import QueryError -from mindsdb.api.executor.datahub.datanodes.system_tables import infer_mysql_type logger = log.getLogger(__name__) @@ -57,9 +54,9 @@ def wrapper(self, *args, **kwargs): query_time_with_labels = metrics.INTEGRATION_HANDLER_QUERY_TIME.labels(handler_class_name, result.type) query_time_with_labels.observe(elapsed_seconds) - num_rows = 0 - if result.data_frame is not None: - num_rows = len(result.data_frame.index) + num_rows = getattr(result, "affected_rows", None) + if num_rows is None: + num_rows = -1 response_size_with_labels = metrics.INTEGRATION_HANDLER_RESPONSE_SIZE.labels( handler_class_name, result.type ) @@ -164,12 +161,12 @@ def create_table( self, table_name: Identifier, result_set: ResultSet = None, - columns: List[TableColumn] = None, + columns: list[TableColumn] = None, is_replace: bool = False, is_create: bool = False, raise_if_exists: bool = True, **kwargs, - ) -> DataHubResponse: + ) -> OkResponse: # is_create - create table # if !raise_if_exists: error will be skipped # is_replace - drop table if exists @@ -197,18 +194,18 @@ def create_table( if result_set is None: # it is just a 'create table' - return DataHubResponse() + return OkResponse() # native insert if hasattr(self.integration_handler, "insert"): df = result_set.to_df() - result: HandlerResponse = self.integration_handler.insert(table_name.parts[-1], df) + result: DataHandlerResponse = self.integration_handler.insert(table_name.parts[-1], df) if result is not None: affected_rows = result.affected_rows else: affected_rows = None - return DataHubResponse(affected_rows=affected_rows) + return OkResponse(affected_rows=affected_rows) insert_columns = [Identifier(parts=[x.alias]) for x in result_set.columns] @@ -232,29 +229,28 @@ def create_table( if len(values) == 0: # not need to insert - return DataHubResponse() + return OkResponse() insert_ast = Insert(table=table_name, columns=insert_columns, values=values, is_plain=True) try: - result: DataHubResponse = self.query(insert_ast) + result: DataHandlerResponse = self.query(insert_ast) except Exception as e: msg = f"[{self.ds_type}/{self.integration_name}]: {str(e)}" raise DBHandlerException(msg) from e - return DataHubResponse(affected_rows=result.affected_rows) + return OkResponse(affected_rows=result.affected_rows) def has_support_stream(self) -> bool: - # checks if data handler has query_stream method - return hasattr(self.integration_handler, "query_stream") and callable(self.integration_handler.query_stream) + """Check if the integration handler supports streaming responses. - @profiler.profile() - def query_stream(self, query: ASTNode, fetch_size: int = None) -> Iterable: - # returns generator of results from handler (split by chunks) - return self.integration_handler.query_stream(query, fetch_size=fetch_size) + Returns: + bool: True if the integration handler supports streaming responses, False otherwise. + """ + return getattr(self.integration_handler, "stream_response", False) @profiler.profile() - def query(self, query: ASTNode | str = None, session=None) -> DataHubResponse: + def query(self, query: ASTNode | str = None, session=None) -> DataHandlerResponse: """Execute a query against the integration data source. This method processes SQL queries either as ASTNode objects or raw SQL strings @@ -266,20 +262,20 @@ def query(self, query: ASTNode | str = None, session=None) -> DataHubResponse: session: Session object (currently unused but kept for compatibility) Returns: - DataHubResponse: Response object + DataHandlerResponse: Response object Raises: NotImplementedError: If query is not ASTNode or str type Exception: If the query execution fails with an error response """ if isinstance(query, ASTNode): - result: HandlerResponse = self.query_integration_handler(query=query) + result: DataHandlerResponse = self.query_integration_handler(query=query) elif isinstance(query, str): - result: HandlerResponse = self.native_query_integration(query=query) + result: DataHandlerResponse = self.native_query_integration(query=query) else: raise NotImplementedError("Thew query argument must be ASTNode or string type") - if result.type == RESPONSE_TYPE.ERROR: + if type(result) is ErrorResponse: if isinstance(query, ASTNode): try: query_str = query.to_string() @@ -302,32 +298,12 @@ def query(self, query: ASTNode | str = None, session=None) -> DataHubResponse: else: raise exception from result.exception - if result.type == RESPONSE_TYPE.OK: - return DataHubResponse(affected_rows=result.affected_rows) - - df = result.data_frame - # region clearing df from NaN values - # recursion error appears in pandas 1.5.3 https://github.com/pandas-dev/pandas/pull/45749 - if isinstance(df, pd.Series): - df = df.to_frame() - - columns_info = [{"name": k, "type": v} for k, v in df.dtypes.items()] - try: - # replace python's Nan, np.nan and pd.NA to None - # TODO keep all NAN to the end of processing, bacause replacing also changes dtypes - df.replace([np.nan, pd.NA, pd.NaT], None, inplace=True) - except Exception: - logger.exception("Issue with clearing DF from NaN values:") - # endregion - - return DataHubResponse( - data_frame=df, columns=columns_info, affected_rows=result.affected_rows, mysql_types=result.mysql_types - ) + return result @collect_metrics - def query_integration_handler(self, query: ASTNode) -> HandlerResponse: + def query_integration_handler(self, query: ASTNode) -> DataHandlerResponse: return self.integration_handler.query(query) @collect_metrics - def native_query_integration(self, query: str) -> HandlerResponse: + def native_query_integration(self, query: str) -> DataHandlerResponse: return self.integration_handler.native_query(query) diff --git a/mindsdb/api/executor/datahub/datanodes/project_datanode.py b/mindsdb/api/executor/datahub/datanodes/project_datanode.py index 12dd98d7d23..21e07d65d83 100644 --- a/mindsdb/api/executor/datahub/datanodes/project_datanode.py +++ b/mindsdb/api/executor/datahub/datanodes/project_datanode.py @@ -13,12 +13,12 @@ Delete, ) -from mindsdb.utilities.exception import EntityNotExistsError from mindsdb.api.executor.datahub.datanodes.datanode import DataNode from mindsdb.api.executor.datahub.classes.tables_row import TablesRow -from mindsdb.api.executor.datahub.classes.response import DataHubResponse +from mindsdb.utilities.exception import EntityNotExistsError +from mindsdb.utilities.types.column import Column from mindsdb.utilities.partitioning import process_dataframe_in_partitions -from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES +from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES, DataHandlerResponse, OkResponse, TableResponse class ProjectDataNode(DataNode): @@ -100,7 +100,7 @@ def callback(chunk): return ml_handler.predict(model_name, df, project_name=self.project.name, version=version, params=params) - def query(self, query: ASTNode | str = None, session=None) -> DataHubResponse: + def query(self, query: ASTNode | str = None, session=None) -> DataHandlerResponse: if isinstance(query, str): query = parse_sql(query) @@ -110,7 +110,7 @@ def query(self, query: ASTNode | str = None, session=None) -> DataHubResponse: if kb_table: # this is the knowledge db kb_table.update_query(query) - return DataHubResponse() + return OkResponse() raise NotImplementedError(f"Can't update object: {query_table}") @@ -120,7 +120,7 @@ def query(self, query: ASTNode | str = None, session=None) -> DataHubResponse: if kb_table: # this is the knowledge db kb_table.delete_query(query) - return DataHubResponse() + return OkResponse() raise NotImplementedError(f"Can't delete object: {query_table}") @@ -157,17 +157,15 @@ def query(self, query: ASTNode | str = None, session=None) -> DataHubResponse: # this is the view df = self.project.query_view(query, session) - columns_info = [{"name": k, "type": v} for k, v in df.dtypes.items()] - - return DataHubResponse(data_frame=df, columns=columns_info) + columns = [Column(name=k, dtype=v) for k, v in df.dtypes.items()] + return TableResponse(data=df, columns=columns) kb_table = session.kb_controller.get_table(query_table, self.project.id) if kb_table: # this is the knowledge db df = kb_table.select_query(query) - columns_info = [{"name": k, "type": v} for k, v in df.dtypes.items()] - - return DataHubResponse(data_frame=df, columns=columns_info) + columns = [Column(name=k, dtype=v) for k, v in df.dtypes.items()] + return TableResponse(data=df, columns=columns) raise EntityNotExistsError(f"Table '{query_table}' not found in database", self.project.name) else: @@ -175,7 +173,7 @@ def query(self, query: ASTNode | str = None, session=None) -> DataHubResponse: def create_table( self, table_name: Identifier, result_set=None, is_replace=False, params=None, is_create=None, **kwargs - ) -> DataHubResponse: + ) -> OkResponse: # is_create - create table # is_replace - drop table if exists # is_create==False and is_replace==False: just insert @@ -196,6 +194,5 @@ def create_table( df = result_set.to_df() kb_table.insert(df, params=params) - return DataHubResponse() - + return OkResponse() raise ValueError(f"Table or Knowledge Base '{table_name}' doesn't exist") diff --git a/mindsdb/api/executor/sql_query/result_set.py b/mindsdb/api/executor/sql_query/result_set.py index 4d037af7bff..f3b22e13e63 100644 --- a/mindsdb/api/executor/sql_query/result_set.py +++ b/mindsdb/api/executor/sql_query/result_set.py @@ -1,7 +1,6 @@ import copy from array import array -from typing import Any -from dataclasses import dataclass, field, MISSING +from typing import Any, Generator import numpy as np import pandas as pd @@ -11,8 +10,10 @@ from mindsdb_sql_parser.ast import TableColumn from mindsdb.utilities import log +from mindsdb.utilities.types.column import Column from mindsdb.api.executor.exceptions import WrongArgumentError from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE +from mindsdb.integrations.libs.response import TableResponse logger = log.getLogger(__name__) @@ -57,31 +58,6 @@ def _dump_vector(value: Any) -> Any: return value -@dataclass(kw_only=True, slots=True) -class Column: - name: str = field(default=MISSING) - alias: str | None = None - table_name: str | None = None - table_alias: str | None = None - type: MYSQL_DATA_TYPE | None = None - database: str | None = None - flags: dict = None - charset: str | None = None - - def __post_init__(self): - if self.alias is None: - self.alias = self.name - if self.table_alias is None: - self.table_alias = self.table_name - - def get_hash_name(self, prefix): - table_name = self.table_name if self.table_alias is None else self.table_alias - name = self.name if self.alias is None else self.alias - - name = f"{prefix}_{table_name}_{name}" - return name - - def rename_df_columns(df: pd.DataFrame, names: list | None = None) -> None: """Inplace rename of dataframe columns @@ -104,6 +80,7 @@ def __init__( affected_rows: int | None = None, is_prediction: bool = False, mysql_types: list[MYSQL_DATA_TYPE] | None = None, + table_response: TableResponse = None, ): """ Args: @@ -112,9 +89,13 @@ def __init__( df (pd.DataFrame): injected dataframe, have to have enumerated columns and length equal to columns affected_rows (int): number of affected rows """ - if columns is None: - columns = [] - self._columns = columns + self._table_response: TableResponse = table_response + if table_response: + self._columns = table_response.columns + elif columns is None: + self._columns = [] + else: + self._columns = columns if df is None: if values is None: @@ -132,15 +113,19 @@ def __init__( def __repr__(self): col_names = ", ".join([col.name for col in self._columns]) + if self._table_response is not None: + return f"{self.__class__.__name__}(table response, cols: {col_names})" return f"{self.__class__.__name__}({self.length()} rows, cols: {col_names})" def __len__(self) -> int: + self._load_table_response() if self._df is None: return 0 return len(self._df) def __getitem__(self, slice_val): # return resultSet with sliced dataframe + self._load_table_response() df = self._df[slice_val] return ResultSet(columns=self.columns, df=df) @@ -170,6 +155,10 @@ def from_df( rename_df_columns(df) return cls(df=df, columns=columns, is_prediction=is_prediction, mysql_types=mysql_types) + @classmethod + def from_table_response(cls, table_response): + return cls(table_response=table_response) + @classmethod def from_df_cols(cls, df: pd.DataFrame, columns_dict: dict[str, Column], strict: bool = True) -> "ResultSet": """Create ResultSet from dataframe and dictionary of columns @@ -251,6 +240,7 @@ def get_col_index(self, col): return col_idx def add_column(self, col, values=None): + self._load_table_response() self._columns.append(col) col_idx = len(self._columns) - 1 @@ -259,6 +249,7 @@ def add_column(self, col, values=None): return col_idx def del_column(self, col): + self._load_table_response() idx = self.get_col_index(col) self._columns.pop(idx) @@ -296,27 +287,56 @@ def copy_column_to(self, col, result_set2): return col2 def set_col_type(self, col_idx, type_name): + self._load_table_response() self.columns[col_idx].type = type_name if self._df is not None: self._df[col_idx] = self._df[col_idx].astype(type_name) # --- records --- + def _load_table_response(self): + """Fully load the table response by fetching all data from the table response and storing it in the _df attribute.""" + if self._table_response is None: + return + + self._table_response.fetchall() + if self._df is None: + self._df = self._table_response._data + else: + self._df = pd.concat([self._df, self._table_response._data]) + self._table_response = None + + def stream_data(self) -> Generator[pd.DataFrame, None, None]: + """Stream data from the result set. + + Yields: + pd.DataFrame: Data frame. + """ + if self._df is not None: + yield self._df + else: + for el in self._table_response.iterate_no_save(): + yield el + def get_raw_df(self): + self._load_table_response() + names = range(len(self._columns)) if self._df is None: - names = range(len(self._columns)) return pd.DataFrame([], columns=names) + self._df.columns = names return self._df def add_raw_df(self, df): if len(df.columns) != len(self._columns): raise WrongArgumentError(f"Record length mismatch columns length: {len(df.columns)} != {len(self.columns)}") + self._load_table_response() rename_df_columns(df) if self._df is None: self._df = df else: + rename_df_columns(self._df) self._df = pd.concat([self._df, df], ignore_index=True) def add_raw_values(self, values): @@ -341,6 +361,7 @@ def get_ast_columns(self) -> list[TableColumn]: list[TableColumn]: A list of TableColumn objects with properly mapped SQLAlchemy types """ columns: list[TableColumn] = [] + self._load_table_response() type_mapping = { MYSQL_DATA_TYPE.TINYINT: sqlalchemy_types.INTEGER, @@ -382,6 +403,7 @@ def to_lists(self, json_types=False): array->list, datetime64->str :return: list of lists """ + self._load_table_response() if len(self.get_raw_df()) == 0: return [] @@ -408,6 +430,7 @@ def get_column_values(self, col_idx): def set_column_values(self, col_name, values): # values is one value or list of values + self._load_table_response() cols = self.find_columns(col_name) if len(cols) == 0: col_idx = self.add_column(Column(name=col_name)) @@ -424,7 +447,7 @@ def add_from_result_set(self, rs): for name in self.get_column_names(): col_sequence.append(source_names.index(name)) - raw_df = rs.get_raw_df()[col_sequence] + raw_df = rs.get_raw_df().iloc[:, col_sequence] self.add_raw_df(raw_df) diff --git a/mindsdb/api/executor/sql_query/sql_query.py b/mindsdb/api/executor/sql_query/sql_query.py index 7adecf15a86..763db9cceac 100644 --- a/mindsdb/api/executor/sql_query/sql_query.py +++ b/mindsdb/api/executor/sql_query/sql_query.py @@ -33,15 +33,15 @@ UnknownError, LogicError, ) +from mindsdb.interfaces.query_context.context_controller import query_context_controller import mindsdb.utilities.profiler as profiler from mindsdb.utilities.fs import create_process_mark, delete_process_mark from mindsdb.utilities.exception import EntityNotExistsError -from mindsdb.interfaces.query_context.context_controller import query_context_controller from mindsdb.utilities.context import context as ctx - +from mindsdb.utilities.types.column import Column from . import steps -from .result_set import ResultSet, Column +from .result_set import ResultSet from .steps.base import BaseStepCall @@ -288,7 +288,7 @@ def execute_query(self): ctx.run_query_id = self.run_query.record.id - step_result = None + step_result: list[ResultSet] = None process_mark = None try: steps_classes = (x.__class__ for x in steps) @@ -323,10 +323,6 @@ def execute_query(self): self.fetched_data = step_result try: - if hasattr(self, "columns_list") is False: - # how it becomes False? - self.columns_list = self.fetched_data.columns - if self.columns_list is None: self.columns_list = self.fetched_data.columns diff --git a/mindsdb/api/executor/sql_query/steps/apply_predictor_step.py b/mindsdb/api/executor/sql_query/steps/apply_predictor_step.py index 50a0c646e41..a12e56f80fb 100644 --- a/mindsdb/api/executor/sql_query/steps/apply_predictor_step.py +++ b/mindsdb/api/executor/sql_query/steps/apply_predictor_step.py @@ -19,7 +19,8 @@ ApplyPredictorStep, ) -from mindsdb.api.executor.sql_query.result_set import ResultSet, Column +from mindsdb.api.executor.sql_query.result_set import ResultSet +from mindsdb.utilities.types.column import Column from mindsdb.utilities.cache import get_cache, dataframe_checksum from .base import BaseStepCall diff --git a/mindsdb/api/executor/sql_query/steps/fetch_dataframe.py b/mindsdb/api/executor/sql_query/steps/fetch_dataframe.py index b81215b01cb..d73666e49e3 100644 --- a/mindsdb/api/executor/sql_query/steps/fetch_dataframe.py +++ b/mindsdb/api/executor/sql_query/steps/fetch_dataframe.py @@ -11,12 +11,12 @@ ) from mindsdb.api.executor.planner.steps import FetchDataframeStep -from mindsdb.api.executor.datahub.classes.response import DataHubResponse from mindsdb.api.executor.sql_query.result_set import ResultSet from mindsdb.api.executor.planner.step_result import Result from mindsdb.api.executor.exceptions import UnknownError -from mindsdb.integrations.utilities.query_traversal import query_traversal from mindsdb.interfaces.query_context.context_controller import query_context_controller +from mindsdb.integrations.utilities.query_traversal import query_traversal +from mindsdb.integrations.libs.response import TableResponse from .base import BaseStepCall @@ -92,7 +92,7 @@ def call(self, step): if query is None: table_alias = (self.context.get("database"), "result", "result") - response: DataHubResponse = dn.query(step.raw_query, session=self.session) + response: TableResponse = dn.query(step.raw_query, session=self.session) df = response.data_frame else: if isinstance(step.query, (Union, Intersect)): @@ -108,11 +108,15 @@ def call(self, step): query, context_callback = query_context_controller.handle_db_context_vars(query, dn, self.session) - response: DataHubResponse = dn.query(query=query, session=self.session) - df = response.data_frame - + response: TableResponse = dn.query(query=query, session=self.session) + response.set_columns_attrs( + table_name=table_alias[1], + table_alias=table_alias[2], + database=table_alias[0], + ) if context_callback: - context_callback(df, response.columns) + context_callback(response.data_frame, response.columns) + return ResultSet.from_table_response(response) # if query registered, set progress if self.sql_query.run_query is not None: @@ -122,5 +126,5 @@ def call(self, step): table_name=table_alias[1], table_alias=table_alias[2], database=table_alias[0], - mysql_types=response.mysql_types, + mysql_types=[column.type for column in response.columns], ) diff --git a/mindsdb/api/executor/sql_query/steps/insert_step.py b/mindsdb/api/executor/sql_query/steps/insert_step.py index 2144521dca7..d7ea17cd6cb 100644 --- a/mindsdb/api/executor/sql_query/steps/insert_step.py +++ b/mindsdb/api/executor/sql_query/steps/insert_step.py @@ -1,7 +1,8 @@ from mindsdb_sql_parser.ast import Identifier, Function from mindsdb.api.executor.planner.steps import SaveToTable, InsertToTable, CreateTableStep -from mindsdb.api.executor.sql_query.result_set import ResultSet, Column +from mindsdb.api.executor.sql_query.result_set import ResultSet +from mindsdb.utilities.types.column import Column from mindsdb.utilities.exception import EntityNotExistsError from mindsdb.api.executor.exceptions import NotSupportedYet, LogicError from mindsdb.integrations.libs.response import INF_SCHEMA_COLUMNS_NAMES diff --git a/mindsdb/api/executor/sql_query/steps/prepare_steps.py b/mindsdb/api/executor/sql_query/steps/prepare_steps.py index b846d4f66b2..7b2950a8e5f 100644 --- a/mindsdb/api/executor/sql_query/steps/prepare_steps.py +++ b/mindsdb/api/executor/sql_query/steps/prepare_steps.py @@ -9,18 +9,18 @@ GetTableColumns, ) -from mindsdb.api.executor.sql_query.result_set import ResultSet, Column +from mindsdb.api.executor.sql_query.result_set import ResultSet +from mindsdb.utilities.types.column import Column from mindsdb.utilities.config import config from .base import BaseStepCall class GetPredictorColumnsCall(BaseStepCall): - bind = GetPredictorColumns def call(self, step): - mindsdb_database_name = config.get('default_project') + mindsdb_database_name = config.get("default_project") predictor_name = step.predictor.parts[-1] dn = self.session.datahub.get(mindsdb_database_name) @@ -28,20 +28,14 @@ def call(self, step): data = ResultSet() for column_name in columns_names: - data.add_column(Column( - name=column_name, - table_name=predictor_name, - database=mindsdb_database_name - )) + data.add_column(Column(name=column_name, table_name=predictor_name, database=mindsdb_database_name)) return data class GetTableColumnsCall(BaseStepCall): - bind = GetTableColumns def call(self, step): - table = step.table dn = self.session.datahub.get(step.namespace) ds_query = Select(from_table=Identifier(table), targets=[Star()], limit=Constant(0)) @@ -50,10 +44,12 @@ def call(self, step): data = ResultSet() for column in response.columns: - data.add_column(Column( - name=column['name'], - type=column.get('type'), - table_name=table, - database=self.context.get('database') - )) + data.add_column( + Column( + name=column["name"], + type=column.get("type"), + table_name=table, + database=self.context.get("database"), + ) + ) return data diff --git a/mindsdb/api/executor/sql_query/steps/subselect_step.py b/mindsdb/api/executor/sql_query/steps/subselect_step.py index 8e4e5725cf9..ac07cd47d3f 100644 --- a/mindsdb/api/executor/sql_query/steps/subselect_step.py +++ b/mindsdb/api/executor/sql_query/steps/subselect_step.py @@ -15,7 +15,8 @@ from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import SERVER_VARIABLES from mindsdb.api.executor.planner.step_result import Result from mindsdb.api.executor.planner.steps import SubSelectStep, QueryStep -from mindsdb.api.executor.sql_query.result_set import ResultSet, Column +from mindsdb.api.executor.sql_query.result_set import ResultSet +from mindsdb.utilities.types.column import Column from mindsdb.api.executor.utilities.sql import query_df from mindsdb.api.executor.exceptions import KeyColumnDoesNotExist from mindsdb.integrations.utilities.query_traversal import query_traversal diff --git a/mindsdb/api/http/namespaces/sql.py b/mindsdb/api/http/namespaces/sql.py index 934f89dbbe9..39e53cc431c 100644 --- a/mindsdb/api/http/namespaces/sql.py +++ b/mindsdb/api/http/namespaces/sql.py @@ -1,8 +1,9 @@ import time +from enum import Enum from http import HTTPStatus from collections import defaultdict -from flask import request +from flask import request, Response from flask_restx import Resource from mindsdb_sql_parser import parse_sql @@ -12,15 +13,12 @@ import mindsdb.utilities.profiler as profiler from mindsdb.api.http.utils import http_error from mindsdb.api.http.namespaces.configs.sql import ns_conf -from mindsdb.api.mysql.mysql_proxy.mysql_proxy import SQLAnswer from mindsdb.api.mysql.mysql_proxy.classes.fake_mysql_proxy import FakeMysqlProxy -from mindsdb.api.executor.data_types.response_type import ( - RESPONSE_TYPE as SQL_RESPONSE_TYPE, -) +from mindsdb.api.executor.data_types.sql_answer import SQLAnswer +from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE as SQL_RESPONSE_TYPE from mindsdb.api.executor.sql_query.result_set import ResultSet - -from mindsdb.integrations.utilities.query_traversal import query_traversal from mindsdb.api.executor.exceptions import ExecutorException, UnknownError +from mindsdb.integrations.utilities.query_traversal import query_traversal from mindsdb.metrics.metrics import api_endpoint_metrics from mindsdb.utilities import log from mindsdb.utilities.config import Config @@ -32,6 +30,12 @@ logger = log.getLogger(__name__) +class ReponseFormat(Enum): + DEFAULT = None + SSE = "sse" + JSONLINES = "jsonlines" + + @ns_conf.route("/query") @ns_conf.param("query", "Execute query") class Query(Resource): @@ -45,8 +49,15 @@ def post(self): start_time = time.time() query = request.json["query"] context = request.json.get("context", {}) + if "params" in request.json: ctx.params = request.json["params"] + + try: + response_format = ReponseFormat(request.json.get("response_format", None)) + except ValueError: + return http_error(HTTPStatus.BAD_REQUEST, "Invalid stream format", "Please provide a valid stream format.") + if isinstance(query, str) is False or isinstance(context, dict) is False: return http_error(HTTPStatus.BAD_REQUEST, "Wrong arguments", 'Please provide "query" with the request.') logger.debug(f"Incoming query: {query}") @@ -55,8 +66,6 @@ def post(self): profiler.enable() error_type = None - error_code = None - error_text = None error_traceback = None profiler.set_meta(query=query, api="http", environment=Config().get("environment")) @@ -95,58 +104,49 @@ def post(self): } query_response["context"] = mysql_proxy.get_context() - + query_response = query_response, 200 else: try: result: SQLAnswer = mysql_proxy.process_query(query) - query_response: dict = result.dump_http_response() except ExecutorException as e: # classified error error_type = "expected" - query_response = { - "type": SQL_RESPONSE_TYPE.ERROR, - "error_code": 0, - "error_message": str(e), - } + result = SQLAnswer( + resp_type=SQL_RESPONSE_TYPE.ERROR, + error_code=0, + error_message=str(e), + ) logger.warning(f"Error query processing: {e}") except QueryError as e: error_type = "expected" if e.is_expected else "unexpected" - query_response = { - "type": SQL_RESPONSE_TYPE.ERROR, - "error_code": 0, - "error_message": str(e), - } + result = SQLAnswer( + resp_type=SQL_RESPONSE_TYPE.ERROR, + error_code=0, + error_message=str(e), + ) if e.is_expected: logger.warning(f"Query failed due to expected reason: {e}") else: logger.exception("Error query processing:") - except UnknownError as e: - # unclassified - error_type = "unexpected" - query_response = { - "type": SQL_RESPONSE_TYPE.ERROR, - "error_code": 0, - "error_message": str(e), - } - logger.exception("Error query processing:") - - except Exception as e: + except (UnknownError, Exception) as e: error_type = "unexpected" - query_response = { - "type": SQL_RESPONSE_TYPE.ERROR, - "error_code": 0, - "error_message": str(e), - } + result = SQLAnswer( + resp_type=SQL_RESPONSE_TYPE.ERROR, + error_code=0, + error_message=str(e), + ) logger.exception("Error query processing:") - if query_response.get("type") == SQL_RESPONSE_TYPE.ERROR: - error_type = "expected" - error_code = query_response.get("error_code") - error_text = query_response.get("error_message") - context = mysql_proxy.get_context() - query_response["context"] = context + if response_format == ReponseFormat.JSONLINES: + query_response = result.stream_http_response_jsonlines(context=context) + query_response = Response(query_response, mimetype="application/jsonlines") + elif response_format == ReponseFormat.SSE: + query_response = result.stream_http_response_sse(context=context) + query_response = Response(query_response, mimetype="text/event-stream") + else: + query_response = result.dump_http_response(context=context), 200 hooks.after_api_query( company_id=ctx.company_id, @@ -155,21 +155,23 @@ def post(self): command=None, payload=query, error_type=error_type, - error_code=error_code, - error_text=error_text, + error_code=result.error_code, + error_text=result.error_message, traceback=error_traceback, ) end_time = time.time() - log_msg = f"SQL processed in {(end_time - start_time):.2f}s ({end_time:.2f}-{start_time:.2f}), result is {query_response['type']}" - if query_response["type"] is SQL_RESPONSE_TYPE.TABLE: - log_msg += f" ({len(query_response['data'])} rows), " - elif query_response["type"] is SQL_RESPONSE_TYPE.ERROR: - log_msg += f" ({query_response['error_message']}), " - log_msg += f"used handlers {ctx.used_handlers}" + log_msg = f"SQL processed in {(end_time - start_time):.2f}s ({end_time:.2f}-{start_time:.2f}), result is {result.type}, " + if result.type is SQL_RESPONSE_TYPE.TABLE and response_format is ReponseFormat.DEFAULT: + log_msg += f" one-piece result ({len(query_response[0]['data'])} rows), " + elif result.type is SQL_RESPONSE_TYPE.TABLE: + log_msg += f" {response_format} result, " + elif result.type is SQL_RESPONSE_TYPE.ERROR: + log_msg += f" ({result.error_message}), " + log_msg += f"used handlers: {ctx.used_handlers}" logger.debug(log_msg) - return query_response, 200 + return query_response @ns_conf.route("/charter") diff --git a/mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py b/mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py index 6f3b06387e4..ec9c122f3d6 100644 --- a/mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py +++ b/mindsdb/api/mysql/mysql_proxy/executor/mysql_executor.py @@ -3,7 +3,7 @@ from mindsdb_sql_parser.ast.base import ASTNode import mindsdb.utilities.profiler as profiler from mindsdb.api.executor.sql_query import SQLQuery -from mindsdb.api.executor.sql_query.result_set import Column +from mindsdb.utilities.types.column import Column from mindsdb.api.executor.planner import utils as planner_utils from mindsdb.api.executor.data_types.answer import ExecuteAnswer from mindsdb.api.executor.command_executor import ExecuteCommands diff --git a/mindsdb/api/mysql/mysql_proxy/mysql_proxy.py b/mindsdb/api/mysql/mysql_proxy/mysql_proxy.py index 5fd02915246..8f691db994c 100644 --- a/mindsdb/api/mysql/mysql_proxy/mysql_proxy.py +++ b/mindsdb/api/mysql/mysql_proxy/mysql_proxy.py @@ -22,8 +22,6 @@ import traceback import logging from functools import partial -from typing import List -from dataclasses import dataclass import mindsdb.utilities.hooks as hooks import mindsdb.utilities.profiler as profiler @@ -65,11 +63,12 @@ getConstName, ) from mindsdb.api.executor.data_types.answer import ExecuteAnswer +from mindsdb.api.executor.data_types.sql_answer import SQLAnswer from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE from mindsdb.api.executor import exceptions as executor_exceptions from mindsdb.api.common.middleware import check_auth -from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE -from mindsdb.api.executor.sql_query.result_set import Column, ResultSet +from mindsdb.api.executor.sql_query.result_set import ResultSet +from mindsdb.utilities.types.column import Column from mindsdb.utilities import log from mindsdb.utilities.config import config from mindsdb.utilities.context import context as ctx @@ -93,44 +92,6 @@ def empty_fn(): pass -@dataclass -class SQLAnswer: - resp_type: RESPONSE_TYPE = RESPONSE_TYPE.OK - result_set: ResultSet | None = None - status: int | None = None - state_track: List[List] | None = None - error_code: int | None = None - error_message: str | None = None - affected_rows: int | None = None - mysql_types: list[MYSQL_DATA_TYPE] | None = None - - @property - def type(self): - return self.resp_type - - def dump_http_response(self) -> dict: - if self.resp_type == RESPONSE_TYPE.OK: - return { - "type": self.resp_type, - "affected_rows": self.affected_rows, - } - elif self.resp_type in (RESPONSE_TYPE.TABLE, RESPONSE_TYPE.COLUMNS_TABLE): - data = self.result_set.to_lists(json_types=True) - return { - "type": RESPONSE_TYPE.TABLE, - "data": data, - "column_names": [column.alias or column.name for column in self.result_set.columns], - } - elif self.resp_type == RESPONSE_TYPE.ERROR: - return { - "type": RESPONSE_TYPE.ERROR, - "error_code": self.error_code or 0, - "error_message": self.error_message, - } - else: - raise ValueError(f"Unsupported response type for dump HTTP response: {self.resp_type}") - - class MysqlTCPServer(SocketServer.ThreadingTCPServer): """ Custom TCP Server with increased request queue size diff --git a/mindsdb/api/mysql/mysql_proxy/utilities/dump.py b/mindsdb/api/mysql/mysql_proxy/utilities/dump.py index f580c7bf714..82fa0a5232f 100644 --- a/mindsdb/api/mysql/mysql_proxy/utilities/dump.py +++ b/mindsdb/api/mysql/mysql_proxy/utilities/dump.py @@ -9,7 +9,8 @@ import pandas as pd from pandas.api import types as pd_types -from mindsdb.api.executor.sql_query.result_set import ResultSet, get_mysql_data_type_from_series, Column +from mindsdb.api.executor.sql_query.result_set import ResultSet, get_mysql_data_type_from_series +from mindsdb.utilities.types.column import Column from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import ( MYSQL_DATA_TYPE, DATA_C_TYPE_MAP, diff --git a/mindsdb/integrations/handlers/dummy_data_handler/dummy_data_handler.py b/mindsdb/integrations/handlers/dummy_data_handler/dummy_data_handler.py index ec205cb9362..6bac43a3e0f 100644 --- a/mindsdb/integrations/handlers/dummy_data_handler/dummy_data_handler.py +++ b/mindsdb/integrations/handlers/dummy_data_handler/dummy_data_handler.py @@ -84,7 +84,7 @@ def get_tables(self) -> HandlerResponse: q = "SHOW TABLES;" result = self.native_query(q) df = result.data_frame - result.data_frame = df.rename(columns={df.columns[0]: "table_name"}) + result._data = df.rename(columns={df.columns[0]: "table_name"}) return result def get_columns(self, table_name: str) -> HandlerResponse: diff --git a/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py b/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py index 2b6c67c4eea..73d0450954d 100644 --- a/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +++ b/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, Generator import pandas as pd import mysql.connector @@ -12,11 +12,16 @@ from mindsdb.integrations.libs.response import ( HandlerStatusResponse as StatusResponse, HandlerResponse as Response, - RESPONSE_TYPE, + TableResponse, + OkResponse, + ErrorResponse, + DataHandlerResponse, ) from mindsdb.integrations.handlers.mysql_handler.settings import ConnectionConfig from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import C_TYPES, DATA_C_TYPE_MAP +from mindsdb.utilities.types.column import Column +from mindsdb.utilities.config import config as mindsdb_config logger = log.getLogger(__name__) @@ -37,57 +42,47 @@ def _map_type(mysql_type_text: str) -> MYSQL_DATA_TYPE: return MYSQL_DATA_TYPE.TEXT -def _make_table_response(result: List[Dict[str, Any]], cursor: mysql.connector.cursor.MySQLCursor) -> Response: - """Build response from result and cursor. +def _get_columns(cursor: mysql.connector.cursor.MySQLCursor) -> list[Column]: + """Get columns from cursor description. Args: - result (list[dict]): result of the query. cursor (mysql.connector.cursor.MySQLCursor): cursor object. Returns: - Response: response object. + list[Column]: List of Column objects with type and dtype info. """ description = cursor.description reverse_c_type_map = {v.code: k for k, v in DATA_C_TYPE_MAP.items() if v.code != C_TYPES.MYSQL_TYPE_BLOB} - mysql_types: list[MYSQL_DATA_TYPE] = [] + columns = [] for col in description: + column_name = col[0] type_int = col[1] - if isinstance(type_int, int) is False: - mysql_types.append(MYSQL_DATA_TYPE.TEXT) - continue - if type_int == C_TYPES.MYSQL_TYPE_TINY: + if isinstance(type_int, int) is False: + mysql_type = MYSQL_DATA_TYPE.TEXT + elif type_int == C_TYPES.MYSQL_TYPE_TINY: # There are 3 types that returns as TINYINT: TINYINT, BOOL, BOOLEAN. - mysql_types.append(MYSQL_DATA_TYPE.TINYINT) - continue - - if type_int in reverse_c_type_map: - mysql_types.append(reverse_c_type_map[type_int]) - continue - - if type_int == C_TYPES.MYSQL_TYPE_BLOB: + mysql_type = MYSQL_DATA_TYPE.TINYINT + elif type_int in reverse_c_type_map: + mysql_type = reverse_c_type_map[type_int] + elif type_int == C_TYPES.MYSQL_TYPE_BLOB: # region determine text/blob type by flags # Unfortunately, there is no way to determine particular type of text/blob column by flags. # Subtype have to be determined by 8-s element of description tuple, but mysql.conector # return the same value for all text types (TINYTEXT, TEXT, MEDIUMTEXT, LONGTEXT), and for # all blob types (TINYBLOB, BLOB, MEDIUMBLOB, LONGBLOB). - if col[7] == 16: # and col[8] == 45 - mysql_types.append(MYSQL_DATA_TYPE.TEXT) - elif col[7] == 144: # and col[8] == 63 - mysql_types.append(MYSQL_DATA_TYPE.BLOB) + if col[7] == 16: + mysql_type = MYSQL_DATA_TYPE.TEXT + elif col[7] == 144: + mysql_type = MYSQL_DATA_TYPE.BLOB else: logger.debug(f"MySQL handler: unknown type code {col[7]}, use TEXT as fallback.") - mysql_types.append(MYSQL_DATA_TYPE.TEXT) + mysql_type = MYSQL_DATA_TYPE.TEXT # endregion else: - logger.warning(f"MySQL handler: unknown type id={type_int} in column {col[0]}, use TEXT as fallback.") - mysql_types.append(MYSQL_DATA_TYPE.TEXT) + logger.warning(f"MySQL handler: unknown type id={type_int} in column {column_name}, use TEXT as fallback.") + mysql_type = MYSQL_DATA_TYPE.TEXT - # region cast int and bool to nullable types - serieses = [] - for i, mysql_type in enumerate(mysql_types): - expected_dtype = None - column_name = description[i][0] if mysql_type in ( MYSQL_DATA_TYPE.SMALLINT, MYSQL_DATA_TYPE.INT, @@ -98,12 +93,27 @@ def _make_table_response(result: List[Dict[str, Any]], cursor: mysql.connector.c expected_dtype = "Int64" elif mysql_type in (MYSQL_DATA_TYPE.BOOL, MYSQL_DATA_TYPE.BOOLEAN): expected_dtype = "boolean" - serieses.append(pd.Series([row[column_name] for row in result], dtype=expected_dtype, name=description[i][0])) - df = pd.concat(serieses, axis=1, copy=False) - # endregion + else: + expected_dtype = None + + columns.append(Column(name=column_name, type=mysql_type, dtype=expected_dtype)) + return columns + + +def _make_df(result: list[tuple[Any]], columns: list[Column]) -> pd.DataFrame: + """Make pandas DataFrame from result and columns. + + Args: + result (list[tuple[Any]]): result of the query (list of tuples). + columns (list[Column]): list of columns. - response = Response(RESPONSE_TYPE.TABLE, df, affected_rows=cursor.rowcount, mysql_types=mysql_types) - return response + Returns: + pd.DataFrame: pandas DataFrame. + """ + serieses = [] + for i, column in enumerate(columns): + serieses.append(pd.Series([row[i] for row in result], dtype=column.dtype, name=column.name)) + return pd.concat(serieses, axis=1, copy=False) class MySQLHandler(MetaDatabaseHandler): @@ -112,6 +122,7 @@ class MySQLHandler(MetaDatabaseHandler): """ name = "mysql" + stream_response = True def __init__(self, name: str, **kwargs: Any) -> None: super().__init__(name) @@ -229,41 +240,100 @@ def check_connection(self) -> StatusResponse: return result - def native_query(self, query: str) -> Response: - """ - Executes a SQL query on the MySQL database and returns the result. + def native_query(self, query: str, stream: bool = True, **kwargs) -> DataHandlerResponse: + """Executes a SQL query on the MySQL database and returns the result. Args: query (str): The SQL query to be executed. + stream (bool): Whether to stream the results of the query. + **kwargs: Additional keyword arguments. Returns: - Response: A response object containing the result of the query or an error message. + DataHandlerResponse: A response object containing the result of the query or an error message. """ - need_to_close = not self.is_connected - connection = None - try: - connection = self.connect() - with connection.cursor(dictionary=True, buffered=True) as cur: - cur.execute(query) - if cur.with_rows: - result = cur.fetchall() - response = _make_table_response(result, cur) - else: - response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount) - except mysql.connector.Error as e: - logger.error( - f"Error running query: {query} on {self.connection_data.get('database', 'unknown')}! Error: {e}" - ) - response = Response(RESPONSE_TYPE.ERROR, error_code=e.errno or 1, error_message=str(e)) - if connection is not None and connection.is_connected(): - connection.rollback() + if stream is False: + response = self._execute_fetchall(query) + else: + generator = self._execute_fetchmany(query) + try: + response: TableResponse = next(generator) + response.data_generator = generator + except StopIteration as e: + response = e.value + if isinstance(response, DataHandlerResponse) is False: + raise + return response - if need_to_close: - self.disconnect() + def _execute_fetchall(self, query: str) -> DataHandlerResponse: + """Executes a SQL query on the MySQL database and returns the full result at once. + + Args: + query (str): The SQL query to be executed. + Returns: + DataHandlerResponse: A response object containing the result of the query or an error message. + """ + connection = self.connect() + with connection.cursor(buffered=True) as cursor: + try: + cursor.execute(query) + if cursor.with_rows: + result = cursor.fetchall() + columns = _get_columns(cursor) + df = _make_df(result, columns) + response = TableResponse(data=df, affected_rows=cursor.rowcount, columns=columns) + else: + response = OkResponse(affected_rows=cursor.rowcount) + except Exception as e: + response = self._handle_query_exception(e, query, connection) return response - def query(self, query: ASTNode) -> Response: + def _execute_fetchmany( + self, query: str + ) -> Generator[TableResponse | pd.DataFrame, None, OkResponse | ErrorResponse]: + """Execute a SQL query on the MySQL database and return a generator of data frames. + + Args: + query (str): The SQL query to be executed. + + Returns: + Generator[TableResponse | pd.DataFrame, None, OkResponse | ErrorResponse]: Generator of data frames. + """ + connection = self.connect() + with connection.cursor(buffered=False) as cursor: + try: + cursor.execute(query) + if not cursor.with_rows: + return OkResponse(affected_rows=cursor.rowcount) + + columns = _get_columns(cursor) + yield TableResponse(affected_rows=cursor.rowcount, columns=columns) + + fetch_size = mindsdb_config["data_stream"]["fetch_size"] + while result := cursor.fetchmany(size=fetch_size): + yield _make_df(result, columns) + except Exception as e: + return self._handle_query_exception(e, query, connection) + + def _handle_query_exception(self, e: Exception, query: str, connection) -> ErrorResponse: + """Handle query execution errors with appropriate logging and rollback. + + Args: + e: The exception that was raised + query: The SQL query that failed + connection: The database connection to rollback + + Returns: + ErrorResponse with appropriate error details + """ + logger.error(f"Error running query: {query} on {self.connection_data.get('database', 'unknown')}! Error: {e}") + if connection is not None and connection.is_connected(): + connection.rollback() + if isinstance(e, mysql.connector.Error): + return ErrorResponse(error_code=e.errno or 1, error_message=str(e)) + return ErrorResponse(error_code=0, error_message=str(e)) + + def query(self, query: ASTNode) -> DataHandlerResponse: """ Retrieve the data from the SQL statement. """ diff --git a/mindsdb/integrations/handlers/oracle_handler/oracle_handler.py b/mindsdb/integrations/handlers/oracle_handler/oracle_handler.py index ad9c4cde578..79d4c342ff4 100644 --- a/mindsdb/integrations/handlers/oracle_handler/oracle_handler.py +++ b/mindsdb/integrations/handlers/oracle_handler/oracle_handler.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Text +from typing import Any, Generator import oracledb import pandas as pd @@ -10,9 +10,15 @@ HandlerStatusResponse as StatusResponse, HandlerResponse as Response, RESPONSE_TYPE, + TableResponse, + OkResponse, + ErrorResponse, + DataHandlerResponse, ) from mindsdb.utilities import log from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender +from mindsdb.utilities.config import config as mindsdb_config +from mindsdb.utilities.types.column import Column import mindsdb.utilities.profiler as profiler from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE @@ -80,43 +86,43 @@ def _map_type(internal_type_name: str) -> MYSQL_DATA_TYPE: return MYSQL_DATA_TYPE.VARCHAR -def _make_table_response(result: list[tuple[Any]], cursor: Cursor) -> Response: - """Build response from result and cursor. +def _get_colums(cursor: Cursor) -> list[Column]: + """Get columns from cursor. Args: - result (list[tuple[Any]]): result of the query. - cursor (oracledb.Cursor): cursor object. + cursor (psycopg.Cursor): cursor object. Returns: - Response: response object. + List of columns """ - description: list[tuple[Any]] = cursor.description - mysql_types: list[MYSQL_DATA_TYPE] = [] - for column in description: + columns = [] + for column in cursor.description: + column_name = column[0] db_type = column[1] precision = column[4] scale = column[5] + mysql_type = None if db_type is oracledb.DB_TYPE_JSON: - mysql_types.append(MYSQL_DATA_TYPE.JSON) + mysql_type = MYSQL_DATA_TYPE.JSON elif db_type is oracledb.DB_TYPE_VECTOR: - mysql_types.append(MYSQL_DATA_TYPE.VECTOR) + mysql_type = MYSQL_DATA_TYPE.VECTOR elif db_type is oracledb.DB_TYPE_NUMBER: if scale != 0: - mysql_types.append(MYSQL_DATA_TYPE.FLOAT) + mysql_type = MYSQL_DATA_TYPE.FLOAT else: # python max int is 19 digits, oracle can return more if precision > 18: - mysql_types.append(MYSQL_DATA_TYPE.DECIMAL) + mysql_type = MYSQL_DATA_TYPE.DECIMAL else: - mysql_types.append(MYSQL_DATA_TYPE.INT) + mysql_type = MYSQL_DATA_TYPE.INT elif db_type is oracledb.DB_TYPE_BINARY_FLOAT: - mysql_types.append(MYSQL_DATA_TYPE.FLOAT) + mysql_type = MYSQL_DATA_TYPE.FLOAT elif db_type is oracledb.DB_TYPE_BINARY_DOUBLE: - mysql_types.append(MYSQL_DATA_TYPE.FLOAT) + mysql_type = MYSQL_DATA_TYPE.FLOAT elif db_type is oracledb.DB_TYPE_BINARY_INTEGER: - mysql_types.append(MYSQL_DATA_TYPE.INT) + mysql_type = MYSQL_DATA_TYPE.INT elif db_type is oracledb.DB_TYPE_BOOLEAN: - mysql_types.append(MYSQL_DATA_TYPE.BOOLEAN) + mysql_type = MYSQL_DATA_TYPE.BOOLEAN elif db_type in ( oracledb.DB_TYPE_CHAR, oracledb.DB_TYPE_NCHAR, @@ -125,22 +131,35 @@ def _make_table_response(result: list[tuple[Any]], cursor: Cursor) -> Response: oracledb.DB_TYPE_VARCHAR, oracledb.DB_TYPE_LONG_NVARCHAR, ): - mysql_types.append(MYSQL_DATA_TYPE.TEXT) + mysql_type = MYSQL_DATA_TYPE.TEXT elif db_type in (oracledb.DB_TYPE_RAW, oracledb.DB_TYPE_LONG_RAW): - mysql_types.append(MYSQL_DATA_TYPE.BINARY) + mysql_type = MYSQL_DATA_TYPE.BINARY elif db_type is oracledb.DB_TYPE_DATE: - mysql_types.append(MYSQL_DATA_TYPE.DATE) + mysql_type = MYSQL_DATA_TYPE.DATE elif db_type is oracledb.DB_TYPE_TIMESTAMP: - mysql_types.append(MYSQL_DATA_TYPE.TIMESTAMP) + mysql_type = MYSQL_DATA_TYPE.TIMESTAMP else: # fallback - mysql_types.append(MYSQL_DATA_TYPE.TEXT) + mysql_type = MYSQL_DATA_TYPE.TEXT + + columns.append(Column(name=column_name, type=mysql_type)) + return columns + + +def _make_df(result: list[tuple[Any]], columns: list[Column]) -> pd.DataFrame: + """Make pandas DataFrame from result and columns. + + Args: + result (list[tuple[Any]]): result of the query. + columns (list[Column]): list of columns. - # region cast int and bool to nullable types + Returns: + pd.DataFrame: pandas DataFrame. + """ serieses = [] - for i, mysql_type in enumerate(mysql_types): + for i, column in enumerate(columns): expected_dtype = None - if mysql_type in ( + if column.type in ( MYSQL_DATA_TYPE.SMALLINT, MYSQL_DATA_TYPE.INT, MYSQL_DATA_TYPE.MEDIUMINT, @@ -148,13 +167,11 @@ def _make_table_response(result: list[tuple[Any]], cursor: Cursor) -> Response: MYSQL_DATA_TYPE.TINYINT, ): expected_dtype = "Int64" - elif mysql_type in (MYSQL_DATA_TYPE.BOOL, MYSQL_DATA_TYPE.BOOLEAN): + elif column.type in (MYSQL_DATA_TYPE.BOOL, MYSQL_DATA_TYPE.BOOLEAN): expected_dtype = "boolean" - serieses.append(pd.Series([row[i] for row in result], dtype=expected_dtype, name=description[i][0])) + serieses.append(pd.Series([row[i] for row in result], dtype=expected_dtype, name=column.name)) df = pd.concat(serieses, axis=1, copy=False) - # endregion - - return Response(RESPONSE_TYPE.TABLE, data_frame=df, mysql_types=mysql_types) + return df class OracleHandler(MetaDatabaseHandler): @@ -163,14 +180,15 @@ class OracleHandler(MetaDatabaseHandler): """ name = "oracle" + stream_response = True - def __init__(self, name: Text, connection_data: Optional[Dict], **kwargs) -> None: + def __init__(self, name: str, connection_data: dict | None, **kwargs) -> None: """ Initializes the handler. Args: - name (Text): The name of the handler instance. - connection_data (Dict): The connection data required to connect to OracleDB. + name (str): The name of the handler instance. + connection_data (dict | None): The connection data required to connect to OracleDB. kwargs: Arbitrary keyword arguments. """ super().__init__(name) @@ -304,78 +322,99 @@ def check_connection(self) -> StatusResponse: return response - @profiler.profile() - def native_query(self, query: Text) -> Response: - """ - Executes a SQL query on the Oracle database and returns the result. + def native_query(self, query: str, stream: bool = True, **kwargs) -> TableResponse | OkResponse | ErrorResponse: + """Executes a SQL query on the Oracle database and returns the result. Args: - query (Text): The SQL query to be executed. + query (str): The SQL query to be executed. + stream (bool): Whether to execute the query on the server side (streaming). + **kwargs: Additional keyword arguments. Returns: - Response: A response object containing the result of the query or an error message. + TableResponse | OkResponse | ErrorResponse: A response object containing the result of the query or an error message. """ - need_to_close = self.is_connected is False + if stream is False: + response = self._execute_fetchall(query, **kwargs) + else: + generator = self._execute_fetchmany(query, **kwargs) + try: + response: TableResponse = next(generator) + response.data_generator = generator + except StopIteration as e: + response = e.value + if isinstance(response, DataHandlerResponse) is False: + raise + return response + def _execute_fetchmany(self, query: str) -> Generator[pd.DataFrame, None, OkResponse | ErrorResponse]: connection = self.connect() - with connection.cursor() as cur: + with connection.cursor() as cursor: try: - cur.execute(query) - if cur.description is None: - response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount) - else: - result = cur.fetchall() - response = _make_table_response(result, cur) - connection.commit() - except DatabaseError as database_error: - logger.error(f"Error running query: {query} on Oracle, {database_error}!") - response = Response( - RESPONSE_TYPE.ERROR, - error_message=str(database_error), - ) - connection.rollback() + # Configure cursor for optimal server-side streaming + fetch_size = mindsdb_config["data_stream"]["fetch_size"] + cursor.arraysize = fetch_size - except Exception as unknown_error: - logger.error(f"Unknwon error running query: {query} on Oracle, {unknown_error}!") - response = Response( - RESPONSE_TYPE.ERROR, - error_message=str(unknown_error), - ) - connection.rollback() + cursor.execute(query) - if need_to_close is True: - self.disconnect() - return response + if cursor.description is None: + connection.commit() + return OkResponse(affected_rows=cursor.rowcount) - def query_stream(self, query: ASTNode, fetch_size: int = 1000): - """ - Executes a SQL query represented by an ASTNode and retrieves the data in a streaming fashion. + columns = _get_colums(cursor) + yield TableResponse(affected_rows=cursor.rowcount, columns=columns) + # Stream data in batches + while result := cursor.fetchmany(cursor.arraysize): + yield _make_df(result, columns) + connection.commit() + except Exception as e: + return self._handle_query_exception(e, query, connection) + + def _execute_fetchall(self, query: str) -> DataHandlerResponse: + """Executes a SQL query and fetches all results at once (client-side). Args: - query (ASTNode): An ASTNode representing the SQL query to be executed. - fetch_size (int): The number of rows to fetch in each batch. - Yields: - pd.DataFrame: A DataFrame containing a batch of rows from the query result. - Response: In case of an error, yields a Response object with the error details. - """ - query_str = SqlalchemyRender("oracle").get_string(query, with_failback=True) - need_to_close = self.is_connected is False + query (str): The SQL query to be executed. + Returns: + TableResponse | OkResponse | ErrorResponse: A response object containing the result of the query or an error message. + """ connection = self.connect() - with connection.cursor() as cur: + with connection.cursor() as cursor: try: - cur.execute(query_str) - while True: - result = cur.fetchmany(fetch_size) - if not result: - break - df = pd.DataFrame(result, columns=[col[0] for col in cur.description]) - yield df + cursor.execute(query) + if cursor.description is None: + response = OkResponse(affected_rows=cursor.rowcount) + else: + # Fetch all results at once + result = cursor.fetchall() + columns = _get_colums(cursor) + df = _make_df(result, columns) + response = TableResponse(data=df, affected_rows=cursor.rowcount, columns=columns) connection.commit() - finally: - connect - if need_to_close is True: - self.disconnect() + except Exception as e: + response = self._handle_query_exception(e, query, connection) + + return response + + def _handle_query_exception(self, e: Exception, query: str, connection) -> ErrorResponse: + """Handle query execution errors with appropriate logging and rollback. + + Args: + e: The exception that was raised + query: The SQL query that failed + connection: The database connection to rollback + + Returns: + ErrorResponse with appropriate error details + """ + if isinstance(e, DatabaseError): + logger.error(f"Error running query: {query} on Oracle, {e}!") + connection.rollback() + return ErrorResponse(error_code=0, error_message=str(e)) + + logger.error(f"Unknown error running query: {query} on Oracle, {e}!") + connection.rollback() + return ErrorResponse(error_code=0, error_message=str(e)) def insert(self, table_name: str, df: pd.DataFrame) -> Response: """ @@ -454,12 +493,12 @@ def get_tables(self) -> Response: """ return self.native_query(query) - def get_columns(self, table_name: Text) -> Response: + def get_columns(self, table_name: str) -> Response: """ Retrieves column details for a specified table in the Oracle database. Args: - table_name (Text): The name of the table for which to retrieve column information. + table_name (str): The name of the table for which to retrieve column information. Returns: Response: A response object containing the column details, formatted as per the `Response` class. @@ -485,11 +524,11 @@ def get_columns(self, table_name: Text) -> Response: ORDER BY TABLE_NAME, COLUMN_ID """ result = self.native_query(query) - if result.resp_type is RESPONSE_TYPE.TABLE: + if result.type is RESPONSE_TYPE.TABLE: result.to_columns_table_response(map_type_fn=_map_type) return result - def meta_get_tables(self, table_names: Optional[List[str]]) -> Response: + def meta_get_tables(self, table_names: list[str] | None) -> Response: """ Retrieves metadata about all non-system tables and views accessible to the current user. @@ -524,11 +563,11 @@ def meta_get_tables(self, table_names: Optional[List[str]]) -> Response: result = self.native_query(query) return result - def meta_get_columns(self, table_names: Optional[List[str]]) -> Response: + def meta_get_columns(self, table_names: list[str] | None) -> Response: """Retrieves metadata about the columns of specified tables accessible to the current user. Args: - table_names (list[str]): A list of table names for which to retrieve column metadata. + table_names (list[str] | None): A list of table names for which to retrieve column metadata. Returns: Response: A response object containing column metadata. @@ -564,11 +603,11 @@ def meta_get_columns(self, table_names: Optional[List[str]]) -> Response: result = self.native_query(query) return result - def meta_get_column_statistics(self, table_names: Optional[List[str]]) -> Response: + def meta_get_column_statistics(self, table_names: list[str] | None) -> Response: """Retrieves statistics about the columns of specified tables accessible to the current user. Args: - table_names (list[str]): A list of table names for which to retrieve column statistics. + table_names (list[str] | None): A list of table names for which to retrieve column statistics. Returns: Response: A response object containing column statistics. @@ -623,12 +662,12 @@ def meta_get_column_statistics(self, table_names: Optional[List[str]]) -> Respon result = self.native_query(query) - if result.resp_type is RESPONSE_TYPE.TABLE and result.data_frame is not None: + if result.type is RESPONSE_TYPE.TABLE and result.data_frame is not None: df = result.data_frame def extract_min_max( histogram_str: str, - ) -> tuple[Optional[float], Optional[float]]: + ) -> tuple[float | None, float | None]: if histogram_str and str(histogram_str).lower() not in ["nan", "none"]: values = str(histogram_str).split(",") if values: @@ -643,12 +682,12 @@ def extract_min_max( df.drop(columns=["HISTOGRAM_BOUNDS"], inplace=True) return result - def meta_get_primary_keys(self, table_names: Optional[List[str]]) -> Response: + def meta_get_primary_keys(self, table_names: list[str] | None) -> Response: """ Retrieves the primary keys for the specified tables accessible to the current user. Args: - table_names (list[str]): A list of table names for which to retrieve primary keys. + table_names (list[str] | None): A list of table names for which to retrieve primary keys. Returns: Response: A response object containing primary key information. @@ -681,12 +720,12 @@ def meta_get_primary_keys(self, table_names: Optional[List[str]]) -> Response: result = self.native_query(query) return result - def meta_get_foreign_keys(self, table_names: Optional[List[str]]) -> Response: + def meta_get_foreign_keys(self, table_names: list[str] | None) -> Response: """ Retrieves the foreign keys for the specified tables accessible to the current user. Args: - table_names (list[str]): A list of table names for which to retrieve foreign keys. + table_names (list[str] | None): A list of table names for which to retrieve foreign keys. Returns: Response: A response object containing foreign key information. diff --git a/mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py b/mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py index ab9aac0b340..00318123794 100644 --- a/mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py +++ b/mindsdb/integrations/handlers/pgvector_handler/pgvector_handler.py @@ -131,7 +131,7 @@ def query(self, query: ASTNode) -> Response: if isinstance(query, DropTables): query.tables = [self._check_table(table.parts[-1]) for table in query.tables] query_str, params = self.renderer.get_exec_params(query, with_failback=True) - return self.native_query(query_str, params, no_restrict=True) + return self.native_query(query_str, params, no_restrict=True, stream=False) return super().query(query) def native_query(self, query, params=None, no_restrict=False) -> Response: @@ -146,7 +146,7 @@ def native_query(self, query, params=None, no_restrict=False) -> Response: return super().native_query(query, params=params) def raw_query(self, query, params=None) -> Response: - resp = super().native_query(query, params) + resp = super().native_query(query, params, stream=False) if resp.resp_type == RESPONSE_TYPE.ERROR: raise RuntimeError(resp.error_message) if resp.resp_type == RESPONSE_TYPE.TABLE: diff --git a/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py b/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py index 9e8330c19e9..64afe0913aa 100644 --- a/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +++ b/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py @@ -1,7 +1,7 @@ import time import json import logging -from typing import Optional, Any +from typing import Optional, Any, Generator import pandas as pd from pandas import DataFrame @@ -10,19 +10,25 @@ from psycopg.postgres import TypeInfo, types as pg_types from psycopg.pq import ExecStatus -from mindsdb_sql_parser import parse_sql -from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender +from mindsdb_sql_parser import parse_sql, Select from mindsdb_sql_parser.ast.base import ASTNode -from mindsdb.integrations.libs.base import MetaDatabaseHandler +import mindsdb.utilities.profiler as profiler +from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender +from mindsdb.utilities.types.column import Column from mindsdb.utilities import log +from mindsdb.integrations.libs.base import MetaDatabaseHandler from mindsdb.integrations.libs.response import ( HandlerStatusResponse as StatusResponse, HandlerResponse as Response, RESPONSE_TYPE, + TableResponse, + OkResponse, + ErrorResponse, + DataHandlerResponse, ) -import mindsdb.utilities.profiler as profiler from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE +from mindsdb.utilities.config import config as mindsdb_config logger = log.getLogger(__name__) @@ -70,15 +76,14 @@ def _map_type(internal_type_name: str | None) -> MYSQL_DATA_TYPE: return fallback_type -def _make_table_response(result: list[tuple[Any]], cursor: Cursor) -> Response: - """Build response from result and cursor. +def _get_columns(cursor: Cursor) -> list[Column]: + """Get columns from cursor. Args: - result (list[tuple[Any]]): result of the query. cursor (psycopg.Cursor): cursor object. Returns: - Response: response object. + List of columns """ description: list[PGColumn] = cursor.description mysql_types: list[MYSQL_DATA_TYPE] = [] @@ -108,11 +113,9 @@ def _make_table_response(result: list[tuple[Any]], cursor: Cursor) -> Response: mysql_type = _map_type(regtype) mysql_types.append(mysql_type) - # region cast int and bool to nullable types - serieses = [] - for i, mysql_type in enumerate(mysql_types): - expected_dtype = None - if mysql_type in ( + result = [] + for i, column in enumerate(cursor.description): + if mysql_types[i] in ( MYSQL_DATA_TYPE.SMALLINT, MYSQL_DATA_TYPE.INT, MYSQL_DATA_TYPE.MEDIUMINT, @@ -120,13 +123,30 @@ def _make_table_response(result: list[tuple[Any]], cursor: Cursor) -> Response: MYSQL_DATA_TYPE.TINYINT, ): expected_dtype = "Int64" - elif mysql_type in (MYSQL_DATA_TYPE.BOOL, MYSQL_DATA_TYPE.BOOLEAN): + elif mysql_types[i] in (MYSQL_DATA_TYPE.BOOL, MYSQL_DATA_TYPE.BOOLEAN): expected_dtype = "boolean" - serieses.append(pd.Series([row[i] for row in result], dtype=expected_dtype, name=description[i].name)) - df = pd.concat(serieses, axis=1, copy=False) - # endregion + else: + expected_dtype = None + result.append( + Column(name=column.name, type=mysql_types[i], original_type=column.type_display, dtype=expected_dtype) + ) + return result + + +def _make_df(result: list[tuple[Any]], columns: list[Column]) -> pd.DataFrame: + """Make pandas DataFrame from result and columns. + + Args: + result (list[tuple[Any]]): result of the query. + columns (list[Column]): list of columns. - return Response(RESPONSE_TYPE.TABLE, data_frame=df, affected_rows=cursor.rowcount, mysql_types=mysql_types) + Returns: + pd.DataFrame: pandas DataFrame. + """ + serieses = [] + for i, column in enumerate(columns): + serieses.append(pd.Series([row[i] for row in result], dtype=column.dtype, name=column.name)) + return pd.concat(serieses, axis=1, copy=False) class PostgresHandler(MetaDatabaseHandler): @@ -135,6 +155,7 @@ class PostgresHandler(MetaDatabaseHandler): """ name = "postgres" + stream_response = True @profiler.profile("init_pg_handler") def __init__(self, name=None, **kwargs): @@ -282,19 +303,47 @@ def _cast_dtypes(self, df: DataFrame, description: list) -> DataFrame: logger.error(f"Error casting column {col.name} to {types_map[pg_type_info.name]}: {e}") df.columns = columns - @profiler.profile() - def native_query(self, query: str, params=None, **kwargs) -> Response: - """ - Executes a SQL query on the PostgreSQL database and returns the result. + def native_query(self, query: str, params=None, stream: bool = True, **kwargs) -> DataHandlerResponse: + """Executes a SQL query on the PostgreSQL database and returns the result. + NOTE: 'INSERT' (and may be some else) queries can not be executed on the server side, + but there are fallbackto client side execution. Args: query (str): The SQL query to be executed. + params (list): The parameters to be passed to the query. + stream (bool): Whether to stream the results of the query. + **kwargs: Additional keyword arguments. Returns: - Response: A response object containing the result of the query or an error message. + DataHandlerResponse: A response object containing the result of the query or an error message. """ - need_to_close = not self.is_connected + if stream is False: + response = self._execute_client_side(query, params, **kwargs) + elif params is not None: + logger.info("Server side cursor does not support 'fetchmany', executing with client side cursor") + response = self._execute_client_side(query, params, **kwargs) + else: + generator = self._execute_server_side(query, **kwargs) + try: + response: TableResponse = next(generator) + response.data_generator = generator + except StopIteration as e: + response = e.value + if isinstance(response, DataHandlerResponse) is False: + raise + return response + + def _execute_client_side(self, query: str, params=None, **kwargs) -> TableResponse | OkResponse | ErrorResponse: + """Executes a SQL query on the PostgreSQL database and returns the result. + + Args: + query (str): The SQL query to be executed. + params (list): The parameters to be passed to the query. + **kwargs: Additional keyword arguments. + Returns: + TableResponse | OkResponse | ErrorResponse: A response object containing the result of the query or an error message. + """ connection = self.connect() with connection.cursor() as cur: try: @@ -303,66 +352,86 @@ def native_query(self, query: str, params=None, **kwargs) -> Response: else: cur.execute(query) if cur.pgresult is None or ExecStatus(cur.pgresult.status) == ExecStatus.COMMAND_OK: - response = Response(RESPONSE_TYPE.OK, affected_rows=cur.rowcount) + response = OkResponse(affected_rows=cur.rowcount) else: result = cur.fetchall() - response = _make_table_response(result, cur) + columns: list[Column] = _get_columns(cur) + response = TableResponse( + affected_rows=cur.rowcount, columns=columns, data=_make_df(result, columns) + ) connection.commit() - except (psycopg.ProgrammingError, psycopg.DataError) as e: - # These is 'expected' exceptions, they should not be treated as mindsdb's errors - # ProgrammingError: table not found or already exists, syntax error, etc - # DataError: division by zero, numeric value out of range, etc. - # https://www.psycopg.org/psycopg3/docs/api/errors.html - log_message = "Database query failed with error, likely due to invalid SQL query" - if logger.isEnabledFor(logging.DEBUG): - log_message += f". Executed query:\n{query}" - logger.info(log_message) - response = Response(RESPONSE_TYPE.ERROR, error_code=0, error_message=str(e), is_expected_error=True) - connection.rollback() except Exception as e: - logger.error(f"Error running query:\n{query}\non {self.database}, {e}") - response = Response(RESPONSE_TYPE.ERROR, error_code=0, error_message=str(e)) - connection.rollback() - - if need_to_close: - self.disconnect() + response = self._handle_query_exception(e, query, connection) return response - def query_stream(self, query: ASTNode, fetch_size: int = 1000): - """ - Executes a SQL query and stream results outside by batches - - :param query: An ASTNode representing the SQL query to be executed. - :param fetch_size: size of the batch - :return: generator with query results - """ - query_str, params = self.renderer.get_exec_params(query, with_failback=True) + def _execute_server_side( + self, query: str, **kwargs + ) -> Generator[TableResponse | pd.DataFrame, None, OkResponse | ErrorResponse]: + """Execute a SQL query on the PostgreSQL database and return a generator of data frames. - need_to_close = not self.is_connected + Args: + query (str): The SQL query to be executed. + params (list): The parameters to be passed to the query. + **kwargs: Additional keyword arguments. + Returns: + Generator[TableResponse | pd.DataFrame, None, OkResponse | ErrorResponse]: Generator of data frames. + """ connection = self.connect() - with connection.cursor() as cur: + with connection.cursor(name=f"mindsdb_{id(self)}") as cursor: try: - if params is not None: - cur.executemany(query_str, params) - else: - cur.execute(query_str) - - if cur.pgresult is not None and ExecStatus(cur.pgresult.status) != ExecStatus.COMMAND_OK: - while True: - result = cur.fetchmany(fetch_size) - if not result: - break - df = DataFrame(result, columns=[x.name for x in cur.description]) - self._cast_dtypes(df, cur.description) - yield df + try: + cursor.execute(query) + except psycopg.errors.SyntaxError as e: + # NOTE: INSERT queries cannot be executed server-side. When they fail, they produce a syntax error + # that always starts with the text below, regardless of the INSERT query format. + lower_e = str(e).lower() + if not lower_e.startswith('syntax error at or near "insert"') and not lower_e.startswith( + 'syntax error at or near "drop"' + ): + raise + connection.rollback() + return self._execute_client_side(query=query) + + if cursor.description is None: + connection.commit() + return OkResponse(affected_rows=cursor.rowcount) + + columns: list[Column] = _get_columns(cursor) + yield TableResponse(affected_rows=cursor.rowcount, columns=columns) + while result := cursor.fetchmany(size=mindsdb_config["data_stream"]["fetch_size"]): + yield _make_df(result, columns) connection.commit() - finally: - connection.rollback() + except Exception as e: + return self._handle_query_exception(e, query, connection) - if need_to_close: - self.disconnect() + def _handle_query_exception(self, e: Exception, query: str, connection) -> ErrorResponse: + """Handle query execution errors with appropriate logging and rollback. + + Args: + e: The exception that was raised + query: The SQL query that failed + connection: The database connection to rollback + + Returns: + ErrorResponse with appropriate error details + """ + if isinstance(e, (psycopg.ProgrammingError, psycopg.DataError)): + # These are 'expected' exceptions, they should not be treated as mindsdb's errors + # ProgrammingError: table not found or already exists, syntax error, etc + # DataError: division by zero, numeric value out of range, etc. + # https://www.psycopg.org/psycopg3/docs/api/errors.html + log_message = "Database query failed with error, likely due to invalid SQL query" + if logger.isEnabledFor(logging.DEBUG): + log_message += f". Executed query:\n{query}" + logger.info(log_message) + connection.rollback() + return ErrorResponse(error_code=0, error_message=str(e), is_expected_error=True) + else: + logger.error(f"Error running query:\n{query}\non {self.database}, {e}") + connection.rollback() + return ErrorResponse(error_code=0, error_message=str(e)) def insert(self, table_name: str, df: pd.DataFrame) -> Response: need_to_close = not self.is_connected @@ -401,7 +470,7 @@ def insert(self, table_name: str, df: pd.DataFrame) -> Response: return Response(RESPONSE_TYPE.OK, affected_rows=rowcount) @profiler.profile() - def query(self, query: ASTNode) -> Response: + def query(self, query: ASTNode) -> DataHandlerResponse: """ Executes a SQL query represented by an ASTNode and retrieves the data. @@ -409,11 +478,13 @@ def query(self, query: ASTNode) -> Response: query (ASTNode): An ASTNode representing the SQL query to be executed. Returns: - Response: The response from the `native_query` method, containing the result of the SQL query execution. + DataHandlerResponse: The response from the `native_query` method, + containing the result of the SQL query execution. """ query_str, params = self.renderer.get_exec_params(query, with_failback=True) logger.debug(f"Executing SQL query: {query_str}") - return self.native_query(query_str, params) + support_stream = isinstance(query, Select) + return self.native_query(query_str, params, stream=support_stream) def get_tables(self, all: bool = False) -> Response: """ diff --git a/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py b/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py index 91e20c74e50..1853b6f3447 100644 --- a/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +++ b/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py @@ -1,24 +1,28 @@ -import psutil +from typing import Any, Optional, List, Generator + import pandas from pandas import DataFrame from pandas.api import types as pd_types from snowflake.sqlalchemy import snowdialect from snowflake import connector from snowflake.connector.errors import NotSupportedError -from snowflake.connector.cursor import SnowflakeCursor, ResultMetadata -from typing import Any, Optional, List +from snowflake.connector.cursor import ResultMetadata from mindsdb_sql_parser.ast.base import ASTNode from mindsdb_sql_parser.ast import Select, Identifier -from mindsdb.utilities import log from mindsdb.integrations.libs.base import MetaDatabaseHandler +from mindsdb.utilities import log from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender +from mindsdb.utilities.types.column import Column from mindsdb.integrations.libs.response import ( HandlerStatusResponse as StatusResponse, - HandlerResponse as Response, - RESPONSE_TYPE, + TableResponse, + OkResponse, + ErrorResponse, + DataHandlerResponse, ) + from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE from .auth_types import ( @@ -50,9 +54,9 @@ def _map_type(internal_type_name: str) -> MYSQL_DATA_TYPE: types_map = { ("NUMBER", "DECIMAL", "DEC", "NUMERIC"): MYSQL_DATA_TYPE.DECIMAL, ("INT , INTEGER , BIGINT , SMALLINT , TINYINT , BYTEINT"): MYSQL_DATA_TYPE.INT, - ("FLOAT", "FLOAT4", "FLOAT8"): MYSQL_DATA_TYPE.FLOAT, + ("FLOAT", "FLOAT4", "FLOAT8", "FIXED"): MYSQL_DATA_TYPE.FLOAT, ("DOUBLE", "DOUBLE PRECISION", "REAL"): MYSQL_DATA_TYPE.DOUBLE, - ("VARCHAR"): MYSQL_DATA_TYPE.VARCHAR, + ("VARCHAR",): MYSQL_DATA_TYPE.VARCHAR, ("CHAR", "CHARACTER", "NCHAR"): MYSQL_DATA_TYPE.CHAR, ("STRING", "TEXT", "NVARCHAR"): MYSQL_DATA_TYPE.TEXT, ("NVARCHAR2", "CHAR VARYING", "NCHAR VARYING"): MYSQL_DATA_TYPE.VARCHAR, @@ -61,9 +65,11 @@ def _map_type(internal_type_name: str) -> MYSQL_DATA_TYPE: ("TIMESTAMP_NTZ", "DATETIME"): MYSQL_DATA_TYPE.DATETIME, ("DATE",): MYSQL_DATA_TYPE.DATE, ("TIME",): MYSQL_DATA_TYPE.TIME, - ("TIMESTAMP_LTZ"): MYSQL_DATA_TYPE.DATETIME, - ("TIMESTAMP_TZ"): MYSQL_DATA_TYPE.DATETIME, - ("VARIANT", "OBJECT", "ARRAY", "MAP", "GEOGRAPHY", "GEOMETRY", "VECTOR"): MYSQL_DATA_TYPE.VARCHAR, + ("TIMESTAMP_LTZ",): MYSQL_DATA_TYPE.DATETIME, + ("TIMESTAMP_TZ",): MYSQL_DATA_TYPE.DATETIME, + ("OBJECT", "ARRAY"): MYSQL_DATA_TYPE.JSON, + ("VECTOR",): MYSQL_DATA_TYPE.VECTOR, + ("VARIANT", "MAP", "GEOGRAPHY", "GEOMETRY", "VECTOR"): MYSQL_DATA_TYPE.VARCHAR, } for db_types_list, mysql_data_type in types_map.items(): @@ -74,100 +80,85 @@ def _map_type(internal_type_name: str) -> MYSQL_DATA_TYPE: return MYSQL_DATA_TYPE.VARCHAR -def _make_table_response(result: DataFrame, cursor: SnowflakeCursor) -> Response: - """Build response from result and cursor. - NOTE: Snowflake return only 'general' type in description, so look on result's - DF types and use types from description only if DF type is 'object' +def _get_columns(description: list[ResultMetadata], sample: pandas.DataFrame = None) -> list[Column]: + """Get columns from Snowflake cursor description. Args: - result (DataFrame): result of the query. - cursor (SnowflakeCursor): cursor object. + description (list[ResultMetadata]): cursor description metadata. + sample (pandas.DataFrame): data sample Returns: - Response: response object. + list[Column]: list of columns with mapped MySQL types. """ - description: list[ResultMetadata] = cursor.description - mysql_types: list[MYSQL_DATA_TYPE] = [] + result = [] for column in description: - column_dtype = result[column.name].dtype - description_column_type = connector.constants.FIELD_ID_TO_NAME.get(column.type_code) - if description_column_type in ("OBJECT", "ARRAY"): - mysql_types.append(MYSQL_DATA_TYPE.JSON) - continue - if description_column_type == "VECTOR": - mysql_types.append(MYSQL_DATA_TYPE.VECTOR) - continue - if pd_types.is_integer_dtype(column_dtype): - column_dtype_name = column_dtype.name - if column_dtype_name in ("int8", "Int8"): - mysql_types.append(MYSQL_DATA_TYPE.TINYINT) - elif column_dtype in ("int16", "Int16"): - mysql_types.append(MYSQL_DATA_TYPE.SMALLINT) - elif column_dtype in ("int32", "Int32"): - mysql_types.append(MYSQL_DATA_TYPE.MEDIUMINT) - elif column_dtype in ("int64", "Int64"): - mysql_types.append(MYSQL_DATA_TYPE.BIGINT) - else: - mysql_types.append(MYSQL_DATA_TYPE.INT) - continue - if pd_types.is_float_dtype(column_dtype): - column_dtype_name = column_dtype.name - if column_dtype_name in ("float16", "Float16"): # Float16 does not exists so far - mysql_types.append(MYSQL_DATA_TYPE.FLOAT) - elif column_dtype_name in ("float32", "Float32"): - mysql_types.append(MYSQL_DATA_TYPE.FLOAT) - elif column_dtype_name in ("float64", "Float64"): - mysql_types.append(MYSQL_DATA_TYPE.DOUBLE) - else: - mysql_types.append(MYSQL_DATA_TYPE.FLOAT) - continue - if pd_types.is_bool_dtype(column_dtype): - mysql_types.append(MYSQL_DATA_TYPE.BOOLEAN) - continue - if pd_types.is_datetime64_any_dtype(column_dtype): - mysql_types.append(MYSQL_DATA_TYPE.DATETIME) - series = result[column.name] - # snowflake use pytz.timezone - if series.dt.tz is not None and getattr(series.dt.tz, "zone", "UTC") != "UTC": - series = series.dt.tz_convert("UTC") - result[column.name] = series.dt.tz_localize(None) - continue - - if pd_types.is_object_dtype(column_dtype): - if description_column_type == "TEXT": - # we can also check column.internal_size, if == 16777216 then it is TEXT, else VARCHAR(internal_size) - mysql_types.append(MYSQL_DATA_TYPE.TEXT) - continue - elif description_column_type == "BINARY": - # if column.internal_size == 8388608 then BINARY, else VARBINARY(internal_size) - mysql_types.append(MYSQL_DATA_TYPE.BINARY) - continue - elif description_column_type == "DATE": - mysql_types.append(MYSQL_DATA_TYPE.DATE) - continue - elif description_column_type == "TIME": - mysql_types.append(MYSQL_DATA_TYPE.TIME) - continue - - if description_column_type == "FIXED": - if column.scale == 0: - mysql_types.append(MYSQL_DATA_TYPE.INT) - else: - # It is NUMBER, DECIMAL or NUMERIC with scale > 0 - mysql_types.append(MYSQL_DATA_TYPE.FLOAT) - continue - elif description_column_type == "REAL": - mysql_types.append(MYSQL_DATA_TYPE.FLOAT) - continue - - mysql_types.append(MYSQL_DATA_TYPE.TEXT) - - df = DataFrame( - result, - columns=[column.name for column in description], - ) - - return Response(RESPONSE_TYPE.TABLE, data_frame=df, affected_rows=None, mysql_types=mysql_types) + mysql_type = None + sf_type_name = connector.constants.FIELD_ID_TO_NAME.get(column.type_code) + if sf_type_name is None: + logger.warning(f"Snowflake handler: unknown type code: {column.type_code}") + mysql_type = MYSQL_DATA_TYPE.VARCHAR + + if sample is not None: + column_dtype = sample[column.name].dtype + + if pd_types.is_integer_dtype(column_dtype): + column_dtype_name = column_dtype.name + if column_dtype_name in ("int8", "Int8"): + mysql_type = MYSQL_DATA_TYPE.TINYINT + elif column_dtype in ("int16", "Int16"): + mysql_type = MYSQL_DATA_TYPE.SMALLINT + elif column_dtype in ("int32", "Int32"): + mysql_type = MYSQL_DATA_TYPE.MEDIUMINT + elif column_dtype in ("int64", "Int64"): + mysql_type = MYSQL_DATA_TYPE.BIGINT + else: + mysql_type = MYSQL_DATA_TYPE.INT + + elif pd_types.is_float_dtype(column_dtype): + column_dtype_name = column_dtype.name + if column_dtype_name in ("float16", "Float16"): # Float16 does not exists so far + mysql_type = MYSQL_DATA_TYPE.FLOAT + elif column_dtype_name in ("float32", "Float32"): + mysql_type = MYSQL_DATA_TYPE.FLOAT + elif column_dtype_name in ("float64", "Float64"): + mysql_type = MYSQL_DATA_TYPE.DOUBLE + else: + mysql_type = MYSQL_DATA_TYPE.FLOAT + + elif pd_types.is_bool_dtype(column_dtype): + mysql_type = MYSQL_DATA_TYPE.BOOLEAN + + elif pd_types.is_datetime64_any_dtype(column_dtype): + mysql_type = MYSQL_DATA_TYPE.DATETIME + series = sample[column.name] + # snowflake use pytz.timezone + if series.dt.tz is not None and getattr(series.dt.tz, "zone", "UTC") != "UTC": + series = series.dt.tz_convert("UTC") + sample[column.name] = series.dt.tz_localize(None) + + elif pd_types.is_object_dtype(column_dtype): + if sf_type_name == "TEXT": + # we can also check column.internal_size, if == 16777216 then it is TEXT, else VARCHAR(internal_size) + mysql_type = MYSQL_DATA_TYPE.TEXT + elif sf_type_name == "BINARY": + # if column.internal_size == 8388608 then BINARY, else VARBINARY(internal_size) + mysql_type = MYSQL_DATA_TYPE.BINARY + elif sf_type_name == "DATE": + mysql_type = MYSQL_DATA_TYPE.DATE + elif sf_type_name == "TIME": + mysql_type = MYSQL_DATA_TYPE.TIME + elif sf_type_name == "FIXED": + if getattr(column, "scale", None) == 0: + mysql_type = MYSQL_DATA_TYPE.INT + else: + # It is NUMBER, DECIMAL or NUMERIC with scale > 0 + mysql_type = MYSQL_DATA_TYPE.FLOAT + + if mysql_type is None: + mysql_type = _map_type(sf_type_name) + + result.append(Column(name=column.name, type=mysql_type, original_type=sf_type_name)) + return result class SnowflakeHandler(MetaDatabaseHandler): @@ -176,6 +167,7 @@ class SnowflakeHandler(MetaDatabaseHandler): """ name = "snowflake" + stream_response = True _auth_types = { "key_pair": KeyPairAuthType(), @@ -269,92 +261,84 @@ def check_connection(self) -> StatusResponse: return response - def native_query(self, query: str) -> Response: - """ - Executes a SQL query on the Snowflake account and returns the result. + def native_query(self, query: str, stream: bool = True, **kwargs) -> TableResponse | OkResponse | ErrorResponse: + """Executes a SQL query on the Snowflake account and returns the result. Args: query (str): The SQL query to be executed. + stream (bool): If True - return TableResponse with generator inside. Returns: - Response: A response object containing the result of the query or an error message. + DataHandlerResponse: A response object containing the result of the query or an error message. """ + generator = self._execute_fetch_batches(query) + try: + response: TableResponse = next(generator) + response.data_generator = generator + if stream is False: + response.fetchall() + except StopIteration as e: + response = e.value + if isinstance(response, DataHandlerResponse) is False: + raise - need_to_close = self.is_connected is False + return response + + def _execute_fetch_batches( + self, query: str + ) -> Generator[TableResponse | pandas.DataFrame, None, OkResponse | ErrorResponse]: + """Execute a SQL query and yield results in batches. + Args: + query (str): The SQL query to execute. + + Yields: + TableResponse: First yield — response with column metadata and affected row count. + pandas.DataFrame: Subsequent yields — batches of query results. + + Returns: + OkResponse: For DML statements (INSERT/DELETE/UPDATE) with affected row count. + ErrorResponse: If an exception occurs during query execution. + """ connection = self.connect() - with connection.cursor(connector.DictCursor) as cur: + with connection.cursor(connector.DictCursor) as cursor: try: - cur.execute(query) + cursor.execute(query) try: try: - batches_iter = cur.fetch_pandas_batches() + batches_iter = cursor.fetch_pandas_batches() except ValueError: # duplicated columns raises ValueError raise NotSupportedError() - - batches = [] - memory_estimation_check_done = False - batches_rowcount = 0 - total_rowcount = cur.rowcount or 0 + try: + sample_df = next(batches_iter) + except StopIteration: + sample_df = None + columns = _get_columns(cursor.description, sample=sample_df) + yield TableResponse(data=sample_df, affected_rows=cursor.rowcount, columns=columns) for batch_df in batches_iter: - batches.append(batch_df) - # region check the size of first batch (if it is big enough) to get an estimate of the full - # dataset size. If it does not fit in memory - raise an error. - # NOTE batch size cannot be set on client side. Also, Snowflake will download - # 'CLIENT_PREFETCH_THREADS' count of chunks in parallel (by default 4), therefore this check - # can not work in some cases. - batches_rowcount += len(batch_df) - if memory_estimation_check_done is False and batches_rowcount > 1000: - memory_estimation_check_done = True - available_memory_kb = psutil.virtual_memory().available >> 10 - batches_size_kb = sum( - [(x.memory_usage(index=True, deep=True).sum() >> 10) for x in batches] - ) - rest_rowcount = total_rowcount - batches_rowcount - rest_estimated_size_kb = int((rest_rowcount / batches_rowcount) * batches_size_kb) - # for pd.concat required at least x2 memory - max_allowed_memory_kb = available_memory_kb / 2.4 - if max_allowed_memory_kb < rest_estimated_size_kb: - error_message = ( - "The query result is too large to fit into available memory. " - f"The dataset contains {total_rowcount} rows with an estimated size " - f"of {rest_estimated_size_kb} KB, but only {max_allowed_memory_kb:.0f} KB " - f"of memory is allowed fot the dataset. Please narrow down the query by adding filters " - f"or a LIMIT clause to reduce the result set size." - ) - logger.error(error_message) - raise MemoryError(error_message) - # endregion - if len(batches) > 0: - response = _make_table_response(result=pandas.concat(batches, ignore_index=True), cursor=cur) - else: - response = Response(RESPONSE_TYPE.TABLE, DataFrame([], columns=[x[0] for x in cur.description])) + yield batch_df except NotSupportedError: # Fallback for CREATE/DELETE/UPDATE. These commands returns table with single column, # but it cannot be retrieved as pandas DataFrame. - result = cur.fetchall() + result = cursor.fetchall() match result: case ( [{"number of rows inserted": affected_rows}] | [{"number of rows deleted": affected_rows}] | [{"number of rows updated": affected_rows, "number of multi-joined rows updated": _}] ): - response = Response(RESPONSE_TYPE.OK, affected_rows=affected_rows) + response = OkResponse(affected_rows=affected_rows) case list(): - response = Response( - RESPONSE_TYPE.TABLE, DataFrame(result, columns=[x[0] for x in cur.description]) - ) + response = TableResponse(data=DataFrame(result, columns=[x[0] for x in cursor.description])) case _: # Looks like SnowFlake always returns something in response, so this is suspicious logger.warning("Snowflake did not return any data in response.") - response = Response(RESPONSE_TYPE.OK) + response = OkResponse() + return response except Exception as e: logger.error(f"Error running query: {query} on {self.connection_data.get('database')}, {e}!") - response = Response(RESPONSE_TYPE.ERROR, error_code=0, error_message=str(e)) - - if need_to_close is True: - self.disconnect() + return ErrorResponse(error_code=0, error_message=str(e)) if memory_pool is not None and memory_pool.backend_name == "jemalloc": # This reduce memory consumption, but will slow down next query slightly. @@ -362,9 +346,7 @@ def native_query(self, query: str) -> Response: # and next query processing time may be even lower. memory_pool.release_unused() - return response - - def query(self, query: ASTNode) -> Response: + def query(self, query: ASTNode) -> DataHandlerResponse: """ Executes a SQL query represented by an ASTNode and retrieves the data. @@ -372,7 +354,7 @@ def query(self, query: ASTNode) -> Response: query (ASTNode): An ASTNode representing the SQL query to be executed. Returns: - Response: The response from the `native_query` method, containing the result of the SQL query execution. + DataHandlerResponse: The response from the `native_query` method, containing the result of the SQL query execution. """ query_str = self.renderer.get_string(query, with_failback=True) @@ -402,12 +384,12 @@ def lowercase_columns(self, result, query): result.data_frame = result.data_frame.rename(columns=rename_columns) return result - def get_tables(self) -> Response: + def get_tables(self) -> DataHandlerResponse: """ Retrieves a list of all non-system tables and views in the current schema of the Snowflake account. Returns: - Response: A response object containing the list of tables and views, formatted as per the `Response` class. + DataHandlerResponse: A response object containing the list of tables and views. """ query = """ @@ -418,7 +400,7 @@ def get_tables(self) -> Response: """ return self.native_query(query) - def get_columns(self, table_name) -> Response: + def get_columns(self, table_name) -> DataHandlerResponse: """ Retrieves column details for a specified table in the Snowflake account. @@ -426,7 +408,7 @@ def get_columns(self, table_name) -> Response: table_name (str): The name of the table for which to retrieve column information. Returns: - Response: A response object containing the column details, formatted as per the `Response` class. + DataHandlerResponse: A response object containing the column details. Raises: ValueError: If the 'table_name' is not a valid string. @@ -458,7 +440,7 @@ def get_columns(self, table_name) -> Response: return result - def meta_get_tables(self, table_names: Optional[List[str]] = None) -> Response: + def meta_get_tables(self, table_names: Optional[List[str]] = None) -> DataHandlerResponse: """ Retrieves metadata information about the tables in the Snowflake database to be stored in the data catalog. @@ -466,7 +448,7 @@ def meta_get_tables(self, table_names: Optional[List[str]] = None) -> Response: table_names (list): A list of table names for which to retrieve metadata information. Returns: - Response: A response object containing the metadata information, formatted as per the `Response` class. + DataHandlerResponse: A response object containing the metadata information. """ query = """ SELECT @@ -493,7 +475,7 @@ def meta_get_tables(self, table_names: Optional[List[str]] = None) -> Response: result.data_frame["ROW_COUNT"] = result.data_frame["ROW_COUNT"].astype("Int64") return result - def meta_get_columns(self, table_names: Optional[List[str]] = None) -> Response: + def meta_get_columns(self, table_names: Optional[List[str]] = None) -> DataHandlerResponse: """ Retrieves column metadata for the specified tables (or all tables if no list is provided). @@ -501,7 +483,7 @@ def meta_get_columns(self, table_names: Optional[List[str]] = None) -> Response: table_names (list): A list of table names for which to retrieve column metadata. Returns: - Response: A response object containing the column metadata. + DataHandlerResponse: A response object containing the column metadata. """ query = """ SELECT @@ -529,7 +511,7 @@ def meta_get_columns(self, table_names: Optional[List[str]] = None) -> Response: result = self.native_query(query) return result - def meta_get_column_statistics(self, table_names: Optional[List[str]] = None) -> Response: + def meta_get_column_statistics(self, table_names: Optional[List[str]] = None) -> DataHandlerResponse: """ Retrieves basic column statistics: null %, distinct count. Due to Snowflake limitations, this runs per-table not per-column. @@ -546,11 +528,11 @@ def meta_get_column_statistics(self, table_names: Optional[List[str]] = None) -> columns_result = self.native_query(columns_query) if ( - columns_result.type == RESPONSE_TYPE.ERROR + isinstance(columns_result, ErrorResponse) or columns_result.data_frame is None or columns_result.data_frame.empty ): - return Response(RESPONSE_TYPE.ERROR, error_message="No columns found.") + return ErrorResponse(error_message="No columns found.") columns_df = columns_result.data_frame grouped = columns_df.groupby(["TABLE_SCHEMA", "TABLE_NAME"]) @@ -585,9 +567,13 @@ def meta_get_column_statistics(self, table_names: Optional[List[str]] = None) -> """ try: stats_res = self.native_query(stats_query) - if stats_res.type != RESPONSE_TYPE.TABLE or stats_res.data_frame is None or stats_res.data_frame.empty: + if ( + not isinstance(stats_res, TableResponse) + or stats_res.data_frame is None + or stats_res.data_frame.empty + ): logger.warning( - f"Could not retrieve stats for table {table_name}. Query returned no data or an error: {stats_res.error_message if stats_res.type == RESPONSE_TYPE.ERROR else 'No data'}" + f"Could not retrieve stats for table {table_name}. Query returned no data or an error: {stats_res.error_message if isinstance(stats_res, ErrorResponse) else 'No data'}" ) # Add placeholder stats if query fails or returns empty for _, row in group.iterrows(): @@ -646,11 +632,11 @@ def meta_get_column_statistics(self, table_names: Optional[List[str]] = None) -> ) if not all_stats: - return Response(RESPONSE_TYPE.TABLE, data_frame=pandas.DataFrame()) + return TableResponse(data=pandas.DataFrame()) - return Response(RESPONSE_TYPE.TABLE, data_frame=pandas.DataFrame(all_stats)) + return TableResponse(data=pandas.DataFrame(all_stats)) - def meta_get_primary_keys(self, table_names: Optional[List[str]] = None) -> Response: + def meta_get_primary_keys(self, table_names: Optional[List[str]] = None) -> DataHandlerResponse: """ Retrieves primary key information for the specified tables (or all tables if no list is provided). @@ -658,7 +644,7 @@ def meta_get_primary_keys(self, table_names: Optional[List[str]] = None) -> Resp table_names (list): A list of table names for which to retrieve primary key information. Returns: - Response: A response object containing the primary key information. + DataHandlerResponse: A response object containing the primary key information. """ try: query = """ @@ -666,7 +652,7 @@ def meta_get_primary_keys(self, table_names: Optional[List[str]] = None) -> Resp """ response = self.native_query(query) - if response.type == RESPONSE_TYPE.ERROR and response.error_message: + if isinstance(response, ErrorResponse): logger.error(f"Query error in meta_get_primary_keys: {response.error_message}\nQuery:\n{query}") df = response.data_frame @@ -683,9 +669,9 @@ def meta_get_primary_keys(self, table_names: Optional[List[str]] = None) -> Resp except Exception as e: logger.error(f"Exception in meta_get_primary_keys: {e!r}") - return Response(RESPONSE_TYPE.ERROR, error_message=f"Exception querying primary keys: {e!r}") + return ErrorResponse(error_message=f"Exception querying primary keys: {e!r}") - def meta_get_foreign_keys(self, table_names: Optional[List[str]] = None) -> Response: + def meta_get_foreign_keys(self, table_names: Optional[List[str]] = None) -> DataHandlerResponse: """ Retrieves foreign key information for the specified tables (or all tables if no list is provided). @@ -693,7 +679,7 @@ def meta_get_foreign_keys(self, table_names: Optional[List[str]] = None) -> Resp table_names (list): A list of table names for which to retrieve foreign key information. Returns: - Response: A response object containing the foreign key information. + DataHandlerResponse: A response object containing the foreign key information. """ try: query = """ @@ -701,7 +687,7 @@ def meta_get_foreign_keys(self, table_names: Optional[List[str]] = None) -> Resp """ response = self.native_query(query) - if response.type == RESPONSE_TYPE.ERROR and response.error_message: + if isinstance(response, ErrorResponse): logger.error(f"Query error in meta_get_primary_keys: {response.error_message}\nQuery:\n{query}") df = response.data_frame @@ -725,7 +711,7 @@ def meta_get_foreign_keys(self, table_names: Optional[List[str]] = None) -> Resp except Exception as e: logger.error(f"Exception in meta_get_primary_keys: {e!r}") - return Response(RESPONSE_TYPE.ERROR, error_message=f"Exception querying primary keys: {e!r}") + return ErrorResponse(error_message=f"Exception querying primary keys: {e!r}") def meta_get_handler_info(self, **kwargs: Any) -> str: """ diff --git a/mindsdb/integrations/libs/base.py b/mindsdb/integrations/libs/base.py index 9f7dbe618ff..2757b7ba594 100644 --- a/mindsdb/integrations/libs/base.py +++ b/mindsdb/integrations/libs/base.py @@ -1,15 +1,23 @@ import ast import concurrent.futures +import functools import inspect import textwrap from _ast import AnnAssign, AugAssign -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, get_type_hints, get_args, Union, get_origin import pandas as pd from mindsdb_sql_parser.ast.base import ASTNode from mindsdb.utilities import log -from mindsdb.integrations.libs.response import HandlerResponse, HandlerStatusResponse, RESPONSE_TYPE +from mindsdb.integrations.libs.response import ( + HandlerStatusResponse, + RESPONSE_TYPE, + DataHandlerResponse, + normalize_response, + ErrorResponse, + TableResponse, +) logger = log.getLogger(__name__) @@ -21,6 +29,59 @@ class BaseHandler: broader MindsDB ecosystem via SQL commands. """ + stream_response = False + + def __init_subclass__(cls, **kwargs): + """Automatically wrap handler methods to normalize their responses. + + When a subclass is defined, this method checks if any of the methods + in _methods_to_normalize are overridden and wraps them to convert + legacy HandlerResponse to new response types (TableResponse, OkResponse, + ErrorResponse). + """ + super().__init_subclass__(**kwargs) + + # Methods whose return values should be normalized to new response types + _methods_to_normalize = ( + "native_query", + "query", + "insert", + "get_tables", + "get_columns", + "meta_get_tables", + "meta_get_columns", + "meta_get_column_statistics", + "meta_get_column_statistics_for_table", + "meta_get_primary_keys", + "meta_get_foreign_keys", + ) + for method_name in _methods_to_normalize: + # Only wrap if method is defined directly in this class (not inherited) + if method_name not in cls.__dict__: + continue + + original_method = cls.__dict__[method_name] + + return_type = get_type_hints(original_method).get("return") + if return_type is DataHandlerResponse or ( + get_origin(return_type) is Union and issubclass(get_args(return_type)[0], DataHandlerResponse) + ): + # this is already new style response + continue + + # Skip if already wrapped + if getattr(original_method, "_response_normalized", False): + continue + + # Create wrapper that normalizes response + @functools.wraps(original_method) + def wrapper(self, *args, _orig=original_method, **kwargs): + result = _orig(self, *args, **kwargs) + return normalize_response(result) + + wrapper._response_normalized = True + setattr(cls, method_name, wrapper) + def __init__(self, name: str): """constructor Args: @@ -53,19 +114,19 @@ def check_connection(self) -> HandlerStatusResponse: """ raise NotImplementedError() - def native_query(self, query: Any) -> HandlerResponse: + def native_query(self, query: Any, stream: bool = False, **kwargs) -> DataHandlerResponse: """Receive raw query and act upon it somehow. Args: - query (Any): query in native format (str for sql databases, - etc) - + query (Any): query in native format (str for sql databases, etc) + stream (bool): Whether to stream the results of the query + **kwargs: Additional keyword arguments. Returns: - HandlerResponse + DataHandlerResponse """ raise NotImplementedError() - def query(self, query: ASTNode) -> HandlerResponse: + def query(self, query: ASTNode) -> DataHandlerResponse: """Receive query as AST (abstract syntax tree) and act upon it somehow. Args: @@ -73,30 +134,30 @@ def query(self, query: ASTNode) -> HandlerResponse: of query: SELECT, INSERT, DELETE, etc Returns: - HandlerResponse + DataHandlerResponse """ raise NotImplementedError() - def get_tables(self) -> HandlerResponse: + def get_tables(self) -> DataHandlerResponse: """Return list of entities Return list of entities that will be accesible as tables. Returns: - HandlerResponse: shoud have same columns as information_schema.tables + DataHandlerResponse: shoud have same columns as information_schema.tables (https://dev.mysql.com/doc/refman/8.0/en/information-schema-tables-table.html) Column 'TABLE_NAME' is mandatory, other is optional. """ raise NotImplementedError() - def get_columns(self, table_name: str) -> HandlerResponse: + def get_columns(self, table_name: str) -> DataHandlerResponse: """Returns a list of entity columns Args: table_name (str): name of one of tables returned by self.get_tables() Returns: - HandlerResponse: shoud have same columns as information_schema.columns + DataHandlerResponse: shoud have same columns as information_schema.columns (https://dev.mysql.com/doc/refman/8.0/en/information-schema-columns-table.html) Column 'COLUMN_NAME' is mandatory, other is optional. Hightly recomended to define also 'DATA_TYPE': it should be one of @@ -125,12 +186,12 @@ class MetaDatabaseHandler(DatabaseHandler): def __init__(self, name: str): super().__init__(name) - def meta_get_tables(self, table_names: Optional[List[str]]) -> HandlerResponse: + def meta_get_tables(self, table_names: Optional[List[str]]) -> DataHandlerResponse: """ Returns metadata information about the tables to be stored in the data catalog. Returns: - HandlerResponse: The response should consist of the following columns: + DataHandlerResponse: The response should consist of the following columns: - TABLE_NAME (str): Name of the table. - TABLE_TYPE (str): Type of the table, e.g. 'BASE TABLE', 'VIEW', etc. (optional). - TABLE_SCHEMA (str): Schema of the table (optional). @@ -139,12 +200,12 @@ def meta_get_tables(self, table_names: Optional[List[str]]) -> HandlerResponse: """ raise NotImplementedError() - def meta_get_columns(self, table_names: Optional[List[str]]) -> HandlerResponse: + def meta_get_columns(self, table_names: Optional[List[str]]) -> DataHandlerResponse: """ Returns metadata information about the columns in the tables to be stored in the data catalog. Returns: - HandlerResponse: The response should consist of the following columns: + DataHandlerResponse: The response should consist of the following columns: - TABLE_NAME (str): Name of the table. - COLUMN_NAME (str): Name of the column. - DATA_TYPE (str): Data type of the column, e.g. 'VARCHAR', 'INT', etc. @@ -154,13 +215,13 @@ def meta_get_columns(self, table_names: Optional[List[str]]) -> HandlerResponse: """ raise NotImplementedError() - def meta_get_column_statistics(self, table_names: Optional[List[str]]) -> HandlerResponse: + def meta_get_column_statistics(self, table_names: Optional[List[str]]) -> DataHandlerResponse: """ Returns metadata statisical information about the columns in the tables to be stored in the data catalog. Either this method should be overridden in the handler or `meta_get_column_statistics_for_table` should be implemented. Returns: - HandlerResponse: The response should consist of the following columns: + DataHandlerResponse: The response should consist of the following columns: - TABLE_NAME (str): Name of the table. - COLUMN_NAME (str): Name of the column. - MOST_COMMON_VALUES (List[str]): Most common values in the column (optional). @@ -207,17 +268,14 @@ def meta_get_column_statistics(self, table_names: Optional[List[str]]) -> Handle if not results: logger.warning("No column statistics could be retrieved for the specified tables.") - return HandlerResponse(RESPONSE_TYPE.ERROR, error_message="No column statistics could be retrieved.") - return HandlerResponse( - RESPONSE_TYPE.TABLE, pd.concat(results, ignore_index=True) if results else pd.DataFrame() - ) - + return ErrorResponse(error_message="No column statistics could be retrieved.") + return TableResponse(data=pd.concat(results, ignore_index=True) if results else pd.DataFrame()) else: raise NotImplementedError() def meta_get_column_statistics_for_table( self, table_name: str, column_names: Optional[List[str]] = None - ) -> HandlerResponse: + ) -> DataHandlerResponse: """ Returns metadata statistical information about the columns in a specific table to be stored in the data catalog. Either this method should be implemented in the handler or `meta_get_column_statistics` should be overridden. @@ -227,7 +285,7 @@ def meta_get_column_statistics_for_table( column_names (Optional[List[str]]): List of column names to retrieve statistics for. If None, statistics for all columns will be returned. Returns: - HandlerResponse: The response should consist of the following columns: + DataHandlerResponse: The response should consist of the following columns: - TABLE_NAME (str): Name of the table. - COLUMN_NAME (str): Name of the column. - MOST_COMMON_VALUES (List[str]): Most common values in the column (optional). @@ -239,12 +297,12 @@ def meta_get_column_statistics_for_table( """ pass - def meta_get_primary_keys(self, table_names: Optional[List[str]]) -> HandlerResponse: + def meta_get_primary_keys(self, table_names: Optional[List[str]]) -> DataHandlerResponse: """ Returns metadata information about the primary keys in the tables to be stored in the data catalog. Returns: - HandlerResponse: The response should consist of the following columns: + DataHandlerResponse: The response should consist of the following columns: - TABLE_NAME (str): Name of the table. - COLUMN_NAME (str): Name of the column that is part of the primary key. - ORDINAL_POSITION (int): Position of the column in the primary key (optional). @@ -252,12 +310,12 @@ def meta_get_primary_keys(self, table_names: Optional[List[str]]) -> HandlerResp """ raise NotImplementedError() - def meta_get_foreign_keys(self, table_names: Optional[List[str]]) -> HandlerResponse: + def meta_get_foreign_keys(self, table_names: Optional[List[str]]) -> DataHandlerResponse: """ Returns metadata information about the foreign keys in the tables to be stored in the data catalog. Returns: - HandlerResponse: The response should consist of the following columns: + DataHandlerResponse: The response should consist of the following columns: - PARENT_TABLE_NAME (str): Name of the parent table. - PARENT_COLUMN_NAME (str): Name of the parent column that is part of the foreign key. - CHILD_TABLE_NAME (str): Name of the child table. diff --git a/mindsdb/integrations/libs/keyword_search_base.py b/mindsdb/integrations/libs/keyword_search_base.py index 6a1cfdd9b80..d515764ba2a 100644 --- a/mindsdb/integrations/libs/keyword_search_base.py +++ b/mindsdb/integrations/libs/keyword_search_base.py @@ -36,6 +36,6 @@ def keyword_select( conditions (List[FilterCondition]): conditions to select Returns: - HandlerResponse + pd.DataFrame """ raise NotImplementedError() diff --git a/mindsdb/integrations/libs/ml_exec_base.py b/mindsdb/integrations/libs/ml_exec_base.py index 96eca4a033a..abac27d75de 100644 --- a/mindsdb/integrations/libs/ml_exec_base.py +++ b/mindsdb/integrations/libs/ml_exec_base.py @@ -7,7 +7,7 @@ normally associated with a DB handler (e.g. `native_query`, `get_tables`), as well as other ML-specific behaviors, like `learn()` or `predict()`. Note that while these still have to be implemented at the engine level, the burden on that class is lesser given that it only needs to return a pandas DataFrame. It's this class that will take said - output and format it into the HandlerResponse instance that MindsDB core expects. + output and format it into the DataHandlerResponse instance that MindsDB core expects. - `learn_process` method: handles async dispatch of the `learn` method in an engine, as well as registering all models inside of the internal MindsDB registry. diff --git a/mindsdb/integrations/libs/response.py b/mindsdb/integrations/libs/response.py index aa39ce4c2c6..c559e875869 100644 --- a/mindsdb/integrations/libs/response.py +++ b/mindsdb/integrations/libs/response.py @@ -1,14 +1,17 @@ import sys -from typing import Callable +from abc import ABC +from typing import Callable, Generator, ClassVar from dataclasses import dataclass, fields import numpy import pandas +import psutil from mindsdb.utilities import log from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE from mindsdb_sql_parser.ast import ASTNode +from mindsdb.utilities.types.column import Column logger = log.getLogger(__name__) @@ -40,7 +43,456 @@ class _INFORMATION_SCHEMA_COLUMNS_NAMES: INF_SCHEMA_COLUMNS_NAMES_SET = set(f.name for f in fields(INF_SCHEMA_COLUMNS_NAMES)) +class HandlerStatusResponse: + def __init__( + self, + success: bool = True, + error_message: str = None, + redirect_url: str = None, + copy_storage: str = None, + ) -> None: + self.success = success + self.error_message = error_message + self.redirect_url = redirect_url + self.copy_storage = copy_storage + + def to_json(self): + data = {"success": self.success, "error": self.error_message} + if self.redirect_url is not None: + data["redirect_url"] = self.redirect_url + if self.copy_storage is not None: + data["copy_storage"] = self.copy_storage + return data + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"success={self.success}, " + f"error={self.error_message}, " + f"redirect_url={self.redirect_url}, " + f"copy_storage={self.copy_storage})" + ) + + +class DataHandlerResponse(ABC): + """Base class for all data handler responses.""" + + type: ClassVar[str] + + @property + def resp_type(self): + # For back compatibility with old code, use the type attribute instead of resp_type + return self.type + + +class ErrorResponse(DataHandlerResponse): + """Response for error cases. + + Attributes: + type: RESPONSE_TYPE.ERROR + error_code: int + error_message: str | None + is_expected_error: bool + exception: Exception | None + """ + + type: ClassVar[str] = RESPONSE_TYPE.ERROR + error_code: int + error_message: str | None + is_expected_error: bool + exception: Exception | None + + def __init__(self, error_code: int = 0, error_message: str | None = None, is_expected_error: bool = False): + self.error_code = error_code + self.error_message = error_message + self.is_expected_error = is_expected_error + self.exception = None + current_exception = sys.exc_info() + if current_exception[0] is not None: + self.exception = current_exception[1] + + def to_columns_table_response(self, map_type_fn: Callable) -> None: + raise ValueError( + f"Cannot convert {self.type} to {RESPONSE_TYPE.COLUMNS_TABLE}, the error is: {self.error_message}" + ) + + +class OkResponse(DataHandlerResponse): + """Response for successful cases without data (e.g. CREATE TABLE, DROP TABLE, etc.). + + Attributes: + type: RESPONSE_TYPE.OK + affected_rows: int - how many rows were affected by the query + """ + + type: ClassVar[str] = RESPONSE_TYPE.OK + affected_rows: int + + def __init__(self, affected_rows: int = None): + self.affected_rows = affected_rows + + +def _safe_pandas_concat(pieces: list[pandas.DataFrame]) -> pandas.DataFrame: + """Safely concatenates multiple pandas DataFrames while checking available memory. + If the estimated memory required for concatenation (with a safety multiplier of 2.5x) + exceeds the available memory, it raises a MemoryError. + + Args: + pieces (list[pandas.DataFrame]): A list of pandas DataFrames to concatenate. + + Returns: + pandas.DataFrame: The concatenated DataFrame. + + Raises: + MemoryError: If there is insufficient memory to perform the concatenation safely. + """ + if len(pieces) == 1: + return pieces[0] + available_memory_kb = psutil.virtual_memory().available >> 10 + pieces_size_kb = sum([(x.memory_usage(index=True, deep=True).sum() >> 10) for x in pieces]) + if (pieces_size_kb * 2.5) > available_memory_kb: + raise MemoryError() + return pandas.concat(pieces) + + +class TableResponse(DataHandlerResponse): + """Response for successful cases with data (e.g. SELECT, SHOW, etc.). + + Attributes: + type: RESPONSE_TYPE.TABLE | RESPONSE_TYPE.COLUMNS_TABLE - type of data in the response + affected_rows: int | None - how many rows were affected by the query + data_generator: Generator[pandas.DataFrame, None, None] | None - generator of data for lazy loading + _columns: list[Column] | None - list of columns + _data: pandas.DataFrame | None - loaded data + _fetched: bool - if data was already fetched (data_generator is consumed) + _invalid: bool - if data has already been fetched and cannot be iterated over + _last_data_piece: pandas.DataFrame | None - last data piece fetched + rows_fetched: int - how many rows were fetched + """ + + type: str + affected_rows: int | None + _data_generator: Generator[pandas.DataFrame, None, None] | None + _columns: list[Column] | None + _data: pandas.DataFrame | None + _fetched: bool + _invalid: bool + _last_data_piece: pandas.DataFrame | None + rows_fetched: int + + def __init__( + self, + data: pandas.DataFrame | None = None, + data_generator: Generator[pandas.DataFrame, None, None] | None = None, + affected_rows: int | None = None, + columns: list[Column] = None, + ): + """ + Either data and/or data_generator must be provided. + Args: + data (pandas.DataFrame): initial data + data_generator (Generator[pandas.DataFrame, None, None]): generator of data + affected_rows (int): total data rowcount - can be None depending on the handler + NOTE: name affected_rows for compatibility with OKResponse + columns (list[Column]): list of columns + """ + self.type = RESPONSE_TYPE.TABLE + self._data_generator = data_generator + self._columns = columns + self.affected_rows = affected_rows + self._data = data + self._fetched = False if data_generator else True + self._invalid = False + self._last_data_piece = None + self.rows_fetched = len(data) if data is not None else 0 + + @property + def data_generator(self) -> Generator[pandas.DataFrame, None, None]: + return self._data_generator + + @data_generator.setter + def data_generator(self, value): + self._fetched = False if value else True + self._data_generator = value + + def fetchall(self) -> pandas.DataFrame: + """Fetch all data and store it in the _data attribute. + + Returns: + pandas.DataFrame: Data frame. + """ + self._raise_if_invalid() + if self._data_generator is None or self._fetched: + return self._data + + pieces = list(self._iterate_with_memory_check()) + if self._data is None: + if len(pieces) == 1: + self._data = pieces[0] + elif len(pieces) == 0: + self._data = pandas.DataFrame([], columns=[column.name for column in self._columns]) + else: + self._data = _safe_pandas_concat(pieces) + elif len(pieces) > 0: + self._data = _safe_pandas_concat([self._data, *pieces]) + + self._fetched = True + self._data_generator = None + + return self._data + + def _raise_if_low_memory(self) -> None: + """Check if there is enough available memory to load the next data chunk. + + Estimates the memory required for the next chunk based on the size of the last + fetched chunk. If `affected_rows` (fetched rows) is known, the estimate is capped at the + number of remaining rows (but no more than one chunk). Otherwise, assumes the next chunk will + be the same size as the previous one. + + Does nothing when no data has been fetched yet. + + Raises: + MemoryError: If estimated memory for the next chunk exceeds available memory. + """ + if self._last_data_piece is None or len(self._last_data_piece) == 0: + return + + data_piece_size_kb = self._last_data_piece.memory_usage(index=True, deep=True).sum() >> 10 + if isinstance(self.affected_rows, int) and self.affected_rows > 0: + row_size_kb = data_piece_size_kb / len(self._last_data_piece) + rows_expected = min(self.affected_rows - self.rows_fetched, len(self._last_data_piece)) + if rows_expected > 0: + available_memory_kb = psutil.virtual_memory().available >> 10 + if available_memory_kb < (row_size_kb * rows_expected * 1.1): + raise MemoryError( + f"Not enough memory to load remaining data. " + f"Available: {available_memory_kb}KB, estimated need: {int(row_size_kb * rows_expected * 1.1)}KB" + ) + else: + # assume that next piece is the same size + available_memory_kb = psutil.virtual_memory().available >> 10 + if available_memory_kb < (data_piece_size_kb * 1.1): + raise MemoryError( + f"Not enough memory to load remaining data. " + f"Available: {available_memory_kb}KB, estimated need: {int(data_piece_size_kb * 1.1)}KB" + ) + + def _iterate_with_memory_check(self) -> Generator[pandas.DataFrame, None, None]: + """Iterate over `_data_generator` with memory safety checks. + + Yields: + pandas.DataFrame: The next chunk from the underlying data generator. + + Raises: + MemoryError: Propagated from `_raise_if_low_memory` if available + memory is insufficient for the next chunk. + """ + if self._data_generator is None: + return + + self._raise_if_low_memory() + + for piece in self._data_generator: + self._last_data_piece = piece + self.rows_fetched += len(piece) + yield piece + self._raise_if_low_memory() + + def fetchmany(self) -> pandas.DataFrame | None: + """Fetch one piece of data and store it in the _data attribute. + + Returns: + pandas.DataFrame: Data frame, piece of data. + """ + self._raise_if_invalid() + try: + piece = next(self._iterate_with_memory_check()) + if self._data is None: + self._data = piece + else: + self._data = _safe_pandas_concat([self._data, piece]) + except StopIteration: + self._fetched = True + self._data_generator = None + return None + return piece + + def iterate_no_save(self) -> Generator[pandas.DataFrame, None, None]: + """Iterate over the data and yield each piece of data. Do not save the data to the _data attribute. + NOTE: do it only once, before return result to the user + + Returns: + Generator[pandas.DataFrame, None, None]: Generator of data frames. + """ + self._raise_if_invalid() + if self._data is not None: + yield self._data + if self._data_generator: + self._invalid = True + for piece in self._iterate_with_memory_check(): + yield piece + + def _raise_if_invalid(self): + if self._invalid: + raise ValueError("Data has already been fetched and cannot be iterated over.") + + @property + def data_frame(self) -> pandas.DataFrame: + """Get the data frame. Represents the entire dataset. + + Returns: + pandas.DataFrame: Data frame. + """ + self.fetchall() + return self._data + + @data_frame.setter + def data_frame(self, value): + """for back compatibility""" + self._data = value + + @property + def columns(self) -> list[Column]: + """Get the columns. + + Returns: + list[Column]: List of columns. + """ + self._resolve_columns() + return self._columns + + def _resolve_columns(self): + if self._columns is not None: + return + self.fetchall() + self._columns = [Column(name=c) for c in self._data.columns] + + def set_columns_attrs(self, table_name: str | None, table_alias: str | None, database: str | None): + """Set the attributes of the columns. + + Args: + table_name (str | None): Table name. + table_alias (str | None): Table alias. + database (str | None): Database name. + """ + self._resolve_columns() + for column in self._columns: + if table_name: + column.table_name = table_name + if table_alias: + column.table_alias = table_alias + if database: + column.database = database + + def to_columns_table_response(self, map_type_fn: Callable) -> None: + """Transform the response to a `columns table` response. + NOTE: original dataframe will be mutated + + Args: + map_type_fn (Callable): Function to map the data type to the MySQL data type. + """ + if self.type == RESPONSE_TYPE.COLUMNS_TABLE: + return + if self.type != RESPONSE_TYPE.TABLE: + raise ValueError( + f"Cannot convert handler response with type '{self.type}' to '{RESPONSE_TYPE.COLUMNS_TABLE}'" + ) + + self.fetchall() + self._resolve_columns() + self.type = RESPONSE_TYPE.COLUMNS_TABLE + + if self._data is None: + return + self._data.columns = [name.upper() for name in self._data.columns] + self._data[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE] = self._data[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE].apply( + map_type_fn + ) + + # region validate df + current_columns_set = set(self._data.columns) + if INF_SCHEMA_COLUMNS_NAMES_SET != current_columns_set: + raise ValueError(f"Columns set for INFORMATION_SCHEMA.COLUMNS is wrong: {list(current_columns_set)}") + # endregion + + self._data = self._data.astype( + { + INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME: "string", + INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE: "string", + INF_SCHEMA_COLUMNS_NAMES.ORDINAL_POSITION: "Int32", + INF_SCHEMA_COLUMNS_NAMES.COLUMN_DEFAULT: "string", + INF_SCHEMA_COLUMNS_NAMES.IS_NULLABLE: "string", + INF_SCHEMA_COLUMNS_NAMES.CHARACTER_MAXIMUM_LENGTH: "Int32", + INF_SCHEMA_COLUMNS_NAMES.CHARACTER_OCTET_LENGTH: "Int32", + INF_SCHEMA_COLUMNS_NAMES.NUMERIC_PRECISION: "Int32", + INF_SCHEMA_COLUMNS_NAMES.NUMERIC_SCALE: "Int32", + INF_SCHEMA_COLUMNS_NAMES.DATETIME_PRECISION: "Int32", + INF_SCHEMA_COLUMNS_NAMES.CHARACTER_SET_NAME: "string", + INF_SCHEMA_COLUMNS_NAMES.COLLATION_NAME: "string", + } + ) + self._data.replace([numpy.nan, pandas.NA], None, inplace=True) + + +def normalize_response(response) -> TableResponse | OkResponse | ErrorResponse: + """Convert legacy HandlerResponse to new response types. + + If response is already a new type (TableResponse, OkResponse, ErrorResponse), + return it as-is. If response is a legacy HandlerResponse, convert it based + on its resp_type. + + Args: + response: Either a new response type or legacy HandlerResponse + + Returns: + TableResponse | OkResponse | ErrorResponse: Normalized response + """ + # Already new format - return as-is + if isinstance(response, (TableResponse, OkResponse, ErrorResponse)): + return response + + # Legacy HandlerResponse - convert based on type + if isinstance(response, HandlerResponse): + if response.resp_type == RESPONSE_TYPE.ERROR: + err = ErrorResponse( + error_code=response.error_code, + error_message=response.error_message, + is_expected_error=response.is_expected_error, + ) + err.exception = response.exception + return err + + if response.resp_type == RESPONSE_TYPE.OK: + return OkResponse(affected_rows=response.affected_rows) + + # TABLE or COLUMNS_TABLE + if response.data_frame is not None: + columns = list(response.data_frame.columns) + else: + columns = [] + + mysql_types = response.mysql_types + if mysql_types is None: + mysql_types = [None] * len(columns) + + return TableResponse( + data=response.data_frame, + columns=[ + Column(name=column_name, type=mysql_type) for column_name, mysql_type in zip(columns, mysql_types) + ], + data_generator=iter([]), # empty generator for legacy responses + ) + + # Unknown type - return as-is (shouldn't happen normally) + return response + + +# ! deprecated class HandlerResponse: + """Legacy response class for compatibility with old code. + NOTE: do not use this class directly, use DataHandlerResponse instead + """ + def __init__( self, resp_type: RESPONSE_TYPE, @@ -142,28 +594,3 @@ def __repr__(self): self.error_message, self.affected_rows, ) - - -class HandlerStatusResponse: - def __init__( - self, - success: bool = True, - error_message: str = None, - redirect_url: str = None, - copy_storage: str = None, - ) -> None: - self.success = success - self.error_message = error_message - self.redirect_url = redirect_url - self.copy_storage = copy_storage - - def to_json(self): - data = {"success": self.success, "error": self.error_message} - if self.redirect_url is not None: - data["redirect_url"] = self.redirect_url - return data - - def __repr__(self): - return f"{self.__class__.__name__}: success={self.success},\ - error={self.error_message},\ - redirect_url={self.redirect_url}" diff --git a/mindsdb/integrations/libs/vectordatabase_handler.py b/mindsdb/integrations/libs/vectordatabase_handler.py index 4f332e53028..088ae9844ad 100644 --- a/mindsdb/integrations/libs/vectordatabase_handler.py +++ b/mindsdb/integrations/libs/vectordatabase_handler.py @@ -22,11 +22,10 @@ ) from mindsdb_sql_parser.ast.base import ASTNode -from mindsdb.integrations.libs.response import RESPONSE_TYPE, HandlerResponse -from mindsdb.utilities import log +from mindsdb.integrations.libs.response import DataHandlerResponse, OkResponse, TableResponse from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs - from mindsdb.integrations.utilities.query_traversal import query_traversal +from mindsdb.utilities import log from .base import BaseHandler LOG = log.getLogger(__name__) @@ -521,7 +520,7 @@ def dispatch_select( handler_engine = self.__class__.name raise VectorHandlerException(f"Error in {handler_engine} database: {e}") - def _dispatch(self, query: ASTNode) -> HandlerResponse: + def _dispatch(self, query: ASTNode) -> DataHandlerResponse: """ Parse and Dispatch query to the appropriate method. """ @@ -536,14 +535,14 @@ def _dispatch(self, query: ASTNode) -> HandlerResponse: if type(query) in dispatch_router: resp = dispatch_router[type(query)](query) if resp is not None: - return HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=resp) + return TableResponse(data=resp) else: - return HandlerResponse(resp_type=RESPONSE_TYPE.OK) + return OkResponse() else: raise NotImplementedError(f"Query type {type(query)} not implemented.") - def query(self, query: ASTNode) -> HandlerResponse: + def query(self, query: ASTNode) -> DataHandlerResponse: """ Receive query as AST (abstract syntax tree) and act upon it somehow. @@ -552,11 +551,11 @@ def query(self, query: ASTNode) -> HandlerResponse: of query: SELECT, INSERT, DELETE, etc Returns: - HandlerResponse + DataHandlerResponse """ return self._dispatch(query) - def create_table(self, table_name: str, if_not_exists=True) -> HandlerResponse: + def create_table(self, table_name: str, if_not_exists=True) -> DataHandlerResponse: """Create table Args: @@ -564,11 +563,11 @@ def create_table(self, table_name: str, if_not_exists=True) -> HandlerResponse: if_not_exists (bool): if True, do nothing if table exists Returns: - HandlerResponse + DataHandlerResponse """ raise NotImplementedError() - def drop_table(self, table_name: str, if_exists=True) -> HandlerResponse: + def drop_table(self, table_name: str, if_exists=True) -> DataHandlerResponse: """Drop table Args: @@ -576,11 +575,11 @@ def drop_table(self, table_name: str, if_exists=True) -> HandlerResponse: if_exists (bool): if True, do nothing if table does not exist Returns: - HandlerResponse + DataHandlerResponse """ raise NotImplementedError() - def insert(self, table_name: str, data: pd.DataFrame) -> HandlerResponse: + def insert(self, table_name: str, data: pd.DataFrame) -> DataHandlerResponse: """Insert data into table Args: @@ -589,11 +588,11 @@ def insert(self, table_name: str, data: pd.DataFrame) -> HandlerResponse: columns (List[str]): columns to insert Returns: - HandlerResponse + DataHandlerResponse """ raise NotImplementedError() - def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> HandlerResponse: + def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> DataHandlerResponse: """Delete data from table Args: @@ -601,7 +600,7 @@ def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> H conditions (List[FilterCondition]): conditions to delete Returns: - HandlerResponse + DataHandlerResponse """ raise NotImplementedError() @@ -612,7 +611,7 @@ def select( conditions: List[FilterCondition] = None, offset: int = None, limit: int = None, - ) -> pd.DataFrame: + ) -> DataHandlerResponse: """Select data from table Args: @@ -621,18 +620,15 @@ def select( conditions (List[FilterCondition]): conditions to select Returns: - HandlerResponse + DataHandlerResponse """ raise NotImplementedError() - def get_columns(self, table_name: str) -> HandlerResponse: + def get_columns(self, table_name: str) -> TableResponse: # return a fixed set of columns data = pd.DataFrame(self.SCHEMA) data.columns = ["COLUMN_NAME", "DATA_TYPE"] - return HandlerResponse( - resp_type=RESPONSE_TYPE.TABLE, - data_frame=data, - ) + return TableResponse(data=data) def hybrid_search( self, diff --git a/mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py b/mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py index 34d7a1e0b89..18a7a6c6e0c 100644 --- a/mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py +++ b/mindsdb/integrations/utilities/rag/retrievers/sql_retriever.py @@ -11,7 +11,7 @@ from mindsdb.interfaces.knowledge_base.preprocessing.document_types import SimpleDocument from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE -from mindsdb.integrations.libs.response import HandlerResponse +from mindsdb.integrations.libs.response import ErrorResponse, DataHandlerResponse from mindsdb.integrations.libs.vectordatabase_handler import ( DistanceFunction, VectorStoreHandler, @@ -746,7 +746,7 @@ def _generate_filter(self, prompt: str, query: str) -> MetadataFilter: def _generate_metadata_filters( self, query: str, ranked_database_schema - ) -> Union[List[AblativeMetadataFilter], HandlerResponse]: + ) -> Union[List[AblativeMetadataFilter], DataHandlerResponse]: """Generate metadata filters using LLM""" metadata_filter_list = [] @@ -813,7 +813,7 @@ def _generate_metadata_filters( f"LLM failed to generate structured metadata filters: {e}", exc_info=logger.isEnabledFor(logging.DEBUG), ) - return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e)) + return ErrorResponse(error_message=str(e)) else: metadata_filter = AblativeMetadataFilter( attribute=column_schema.column, @@ -832,7 +832,7 @@ def _prepare_and_execute_query( ranked_database_schema: DatabaseSchema, metadata_filters: List[AblativeMetadataFilter], embeddings_str: str, - ) -> HandlerResponse: + ) -> DataHandlerResponse: try: checked_sql_query = self._prepare_pgvector_query(ranked_database_schema, metadata_filters) checked_sql_query_with_embeddings = checked_sql_query.format(embeddings=embeddings_str) @@ -842,7 +842,7 @@ def _prepare_and_execute_query( f"Failed to prepare and execute SQL query from structured metadata: {e}", exc_info=logger.isEnabledFor(logging.DEBUG), ) - return HandlerResponse(RESPONSE_TYPE.ERROR, error_message=str(e)) + return ErrorResponse(error_message=str(e)) def _get_relevant_documents(self, query: str, *, run_manager: Optional[Any] = None) -> List[Any]: # Rewrite query to be suitable for retrieval. diff --git a/mindsdb/interfaces/database/log.py b/mindsdb/interfaces/database/log.py index bda24fa9f6b..4a3b9a15af9 100644 --- a/mindsdb/interfaces/database/log.py +++ b/mindsdb/interfaces/database/log.py @@ -9,13 +9,14 @@ from mindsdb_sql_parser.utils import JoinType from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender -from mindsdb.integrations.utilities.query_traversal import query_traversal from mindsdb.utilities.functions import resolve_table_identifier -from mindsdb.api.executor.utilities.sql import get_query_tables from mindsdb.utilities.exception import EntityNotExistsError -import mindsdb.interfaces.storage.db as db from mindsdb.utilities.context import context as ctx -from mindsdb.api.executor.datahub.classes.response import DataHubResponse +from mindsdb.utilities.types.column import Column +from mindsdb.integrations.utilities.query_traversal import query_traversal +from mindsdb.integrations.libs.response import TableResponse +import mindsdb.interfaces.storage.db as db +from mindsdb.api.executor.utilities.sql import get_query_tables from mindsdb.api.executor.datahub.classes.tables_row import ( TABLES_ROW_TYPE, TablesRow, @@ -228,7 +229,7 @@ def get_tables_rows(self) -> List[TablesRow]: for table_name in self._tables.keys() ] - def query(self, query: Select = None, native_query: str = None, session=None) -> DataHubResponse: + def query(self, query: Select = None, native_query: str = None, session=None) -> TableResponse: if native_query is not None: if query is not None: raise Exception("'query' and 'native_query' arguments can not be used together") @@ -290,6 +291,5 @@ def check_columns(node, is_table, **kwargs): df[df_column_name] = df[df_column_name].astype(column_type) # endregion - columns_info = [{"name": k, "type": v} for k, v in df.dtypes.items()] - - return DataHubResponse(data_frame=df, columns=columns_info) + columns = [Column(name=k, dtype=v) for k, v in df.dtypes.items()] + return TableResponse(data=df, columns=columns, affected_rows=0) diff --git a/mindsdb/interfaces/jobs/jobs_controller.py b/mindsdb/interfaces/jobs/jobs_controller.py index 31382daedf1..5c85372ffb1 100644 --- a/mindsdb/interfaces/jobs/jobs_controller.py +++ b/mindsdb/interfaces/jobs/jobs_controller.py @@ -16,6 +16,7 @@ from mindsdb.interfaces.database.projects import ProjectController from mindsdb.interfaces.query_context.context_controller import query_context_controller from mindsdb.interfaces.database.log import LogDBController +from mindsdb.integrations.libs.response import TableResponse from mindsdb.utilities import log @@ -346,9 +347,9 @@ def get_history(self, name: str, project_name: str) -> List[dict]: ], ), ) - response = logs_db_controller.query(query) + response: TableResponse = logs_db_controller.query(query) - names = [i["name"] for i in response.columns] + names = [i.name for i in response.columns] return response.data_frame[names].to_dict(orient="records") diff --git a/mindsdb/interfaces/query_context/context_controller.py b/mindsdb/interfaces/query_context/context_controller.py index 97a1ec83189..08188c8ed66 100644 --- a/mindsdb/interfaces/query_context/context_controller.py +++ b/mindsdb/interfaces/query_context/context_controller.py @@ -1,9 +1,9 @@ -from typing import List, Optional, Iterable import pickle import datetime as dt +from typing import List, Optional, Iterable -from sqlalchemy.orm.attributes import flag_modified import pandas as pd +from sqlalchemy.orm.attributes import flag_modified from mindsdb_sql_parser import Select, Star, OrderBy @@ -17,7 +17,6 @@ ) from mindsdb.integrations.utilities.query_traversal import query_traversal from mindsdb.utilities.cache import get_cache - from mindsdb.interfaces.storage import db from mindsdb.utilities.context import context as ctx from mindsdb.utilities.config import config @@ -70,14 +69,14 @@ def get_partitions(self, dn, step_call, query: Select) -> Iterable: :param query: AST query to execute :return: generator with query results """ - if hasattr(dn, "has_support_stream") and dn.has_support_stream(): + if dn.has_support_stream(): query2 = self.get_partition_query(step_call.current_step_num, query, stream=True) - for df in dn.query_stream(query2, fetch_size=self.batch_size): + response = dn.query(query=query2, session=step_call.session) + for df in response.iterate_no_save(): max_track_value = self.get_max_track_value(df) yield df self.set_progress(max_track_value=max_track_value) - else: while True: query2 = self.get_partition_query(step_call.current_step_num, query, stream=False) @@ -457,7 +456,7 @@ def _get_init_last_values(self, l_query: LastQuery, dn, session) -> dict: idx = None for i, col in enumerate(columns_info): - if col["name"].upper() == info["column_name"].upper(): + if col.name.upper() == info["column_name"].upper(): idx = i break diff --git a/mindsdb/utilities/config.py b/mindsdb/utilities/config.py index b660b31cb79..7aa74f2b1e7 100644 --- a/mindsdb/utilities/config.py +++ b/mindsdb/utilities/config.py @@ -215,6 +215,9 @@ def __new__(cls, *args, **kwargs) -> "Config": "data_catalog": { "enabled": False, }, + "data_stream": { + "fetch_size": 10000, + }, "byom": { "enabled": False, }, diff --git a/mindsdb/utilities/types/__init__.py b/mindsdb/utilities/types/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindsdb/utilities/types/column.py b/mindsdb/utilities/types/column.py new file mode 100644 index 00000000000..e8d258468d3 --- /dev/null +++ b/mindsdb/utilities/types/column.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass, field, MISSING + +from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE + + +@dataclass(kw_only=True, slots=True) +class Column: + name: str = field(default=MISSING) + alias: str | None = None + table_name: str | None = None + table_alias: str | None = None + type: MYSQL_DATA_TYPE | None = None + database: str | None = None + flags: dict = None + charset: str | None = None + original_type: str | None = None + dtype: str | None = None + + def __post_init__(self): + if self.alias is None: + self.alias = self.name + if self.table_alias is None: + self.table_alias = self.table_name + + def get_hash_name(self, prefix): + table_name = self.table_name if self.table_alias is None else self.table_alias + name = self.name if self.alias is None else self.alias + + name = f"{prefix}_{table_name}_{name}" + return name diff --git a/tests/unit/api/http/test_sql_query.py b/tests/unit/api/http/test_sql_query.py new file mode 100644 index 00000000000..b40096ecdcc --- /dev/null +++ b/tests/unit/api/http/test_sql_query.py @@ -0,0 +1,145 @@ +""" +Tests for POST /sql/query endpoint with different response_format values: +1. DEFAULT (None) - returns JSON response +2. SSE ("sse") - returns Server-Sent Events stream +3. JSONLINES ("jsonlines") - returns JSON Lines stream +""" + +import json +from http import HTTPStatus +from unittest.mock import patch, MagicMock + +import pandas as pd + +from mindsdb.api.executor.data_types.sql_answer import SQLAnswer +from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE +from mindsdb.api.executor.sql_query.result_set import ResultSet +from mindsdb.utilities.types.column import Column + + +def create_mock_sql_answer(): + """Create a mock SQLAnswer with table data for testing.""" + columns = [ + Column(name="id", alias="id"), + Column(name="name", alias="name"), + Column(name="value", alias="value"), + ] + + df = pd.DataFrame( + [ + [1, "test1", 100], + [2, "test2", 200], + [3, "test3", 300], + ] + ) + + result_set = ResultSet(columns=columns, df=df) + + return SQLAnswer( + resp_type=RESPONSE_TYPE.TABLE, + result_set=result_set, + ) + + +def check_response(response_data: dict): + # Check response structure for default format + assert response_data["type"] == "table" + assert "data" in response_data + assert "column_names" in response_data + assert "context" in response_data + + # Check data content + assert response_data["column_names"] == ["id", "name", "value"] + assert len(response_data["data"]) == 3 + assert response_data["data"][0] == [1, "test1", 100] + assert response_data["data"][1] == [2, "test2", 200] + assert response_data["data"][2] == [3, "test3", 300] + + +def setup_mock_proxy(mock_proxy_class): + """Configure mock proxy with default behavior.""" + mock_proxy = MagicMock() + mock_proxy_class.return_value = mock_proxy + mock_proxy.process_query.return_value = create_mock_sql_answer() + mock_proxy.get_context.return_value = {} + return mock_proxy + + +class TestSQLQueryResponseFormat: + @patch("mindsdb.api.http.namespaces.sql.FakeMysqlProxy") + def test_query_default_format(self, mock_proxy_class, client): + """Test POST /sql/query with default response format (no response_format parameter).""" + setup_mock_proxy(mock_proxy_class) + + response = client.post( + "/api/sql/query", + json={"query": "SELECT * FROM table"}, + ) + + assert response.status_code == HTTPStatus.OK + response_data = response.json + check_response(response_data) + + @patch("mindsdb.api.http.namespaces.sql.FakeMysqlProxy") + def test_query_sse_format(self, mock_proxy_class, client): + """Test POST /sql/query with SSE response format (response_format="sse").""" + setup_mock_proxy(mock_proxy_class) + + response = client.post( + "/api/sql/query", + json={ + "query": "SELECT * FROM table", + "response_format": "sse", + }, + ) + + assert response.status_code == HTTPStatus.OK + assert "text/event-stream" in response.content_type + + # Parse SSE response and build unified response dict + response_text = response.get_data(as_text=True) + lines = [line.replace("data: ", "") for line in response_text.split("\n") if line.startswith("data: ")] + + assert len(lines) > 1 + header = json.loads(lines[0]) + data_rows = json.loads(lines[1]) + + response_data = { + "type": header["type"], + "column_names": header["column_names"], + "data": data_rows, + "context": {}, + } + check_response(response_data) + + @patch("mindsdb.api.http.namespaces.sql.FakeMysqlProxy") + def test_query_jsonlines_format(self, mock_proxy_class, client): + """Test POST /sql/query with JSONLINES response format (response_format="jsonlines").""" + setup_mock_proxy(mock_proxy_class) + + response = client.post( + "/api/sql/query", + json={ + "query": "SELECT * FROM table", + "response_format": "jsonlines", + }, + ) + + assert response.status_code == HTTPStatus.OK + assert response.content_type == "application/jsonlines" + + # Parse JSONLINES response and build unified response dict + response_text = response.get_data(as_text=True) + lines = [line for line in response_text.split("\n") if line.strip()] + + assert len(lines) > 1 + header = json.loads(lines[0]) + data_rows = json.loads(lines[1]) + + response_data = { + "type": header["type"], + "column_names": header["column_names"], + "data": data_rows, + "context": {}, + } + check_response(response_data) diff --git a/tests/unit/executor/test_api_handler.py b/tests/unit/executor/test_api_handler.py index cbc6a8ff862..beb696d2f3c 100644 --- a/tests/unit/executor/test_api_handler.py +++ b/tests/unit/executor/test_api_handler.py @@ -1,15 +1,14 @@ import sys import types -from unittest.mock import patch import datetime as dt +from unittest.mock import patch +from dataclasses import dataclass import pandas as pd from tests.unit.executor_test_base import BaseExecutorDummyML -from dataclasses import dataclass - # import modules virtually if it is not installed try: diff --git a/tests/unit/executor/test_knowledge_base.py b/tests/unit/executor/test_knowledge_base.py index 485e9bb9e20..51fa41efe74 100644 --- a/tests/unit/executor/test_knowledge_base.py +++ b/tests/unit/executor/test_knowledge_base.py @@ -555,7 +555,15 @@ def stream_f(*args, **kwargs): yield df[chunk_size * i : chunk_size * (i + 1) :] # --- stream mode --- - mock_handler().query_stream.side_effect = stream_f + # Mock native_query to return TableResponse with generator + mock_handler().stream_response = True + + def native_query_with_generator(*args, **kwargs): + from mindsdb.integrations.libs.response import TableResponse + + return TableResponse(data_generator=stream_f()) + + mock_handler().native_query.side_effect = native_query_with_generator # test iterate check_partition( @@ -590,7 +598,15 @@ def stream_f(*args, **kwargs): ) # --- general mode --- - mock_handler().query_stream = None + # Mock native_query to return TableResponse with full data + mock_handler().stream_response = False + + def native_query_without_generator(*args, **kwargs): + from mindsdb.integrations.libs.response import TableResponse + + return TableResponse(data=df) + + mock_handler().native_query.side_effect = native_query_without_generator # test iterate check_partition( diff --git a/tests/unit/executor_test_base.py b/tests/unit/executor_test_base.py index d305cd9d90f..a2ebbed7ba6 100644 --- a/tests/unit/executor_test_base.py +++ b/tests/unit/executor_test_base.py @@ -59,6 +59,8 @@ def setup_class(cls): with open(cfg_file, "w") as fd: json.dump(config, fd) + cls._original_storage_dir_env = os.environ.get("MINDSDB_STORAGE_DIR") + cls._original_config_path_env = os.environ.get("MINDSDB_CONFIG_PATH") os.environ["MINDSDB_STORAGE_DIR"] = cls.storage_dir os.environ["MINDSDB_CONFIG_PATH"] = cfg_file @@ -83,6 +85,11 @@ def teardown_class(cls): if env_var_name in os.environ: del os.environ[env_var_name] + if cls._original_storage_dir_env is not None: + os.environ["MINDSDB_STORAGE_DIR"] = cls._original_storage_dir_env + if cls._original_config_path_env is not None: + os.environ["MINDSDB_CONFIG_PATH"] = cls._original_config_path_env + # remove import of mindsdb for next tests unload_module("mindsdb") @@ -339,11 +346,10 @@ def set_handler(self, mock_handler, name, tables, engine="postgres", schema=None self.db.session.add(r) self.db.session.commit() - from mindsdb.integrations.libs.response import RESPONSE_TYPE - from mindsdb.integrations.libs.response import HandlerResponse as Response + from mindsdb.integrations.libs.response import TableResponse def handler_response(df, affected_rows: None | int = None): - response = Response(RESPONSE_TYPE.TABLE, df, affected_rows=affected_rows) + response = TableResponse(data=df, affected_rows=affected_rows) return response def get_tables_f(): diff --git a/tests/unit/handlers/base_handler_test.py b/tests/unit/handlers/base_handler_test.py index 85e4133fbfc..be54f494402 100644 --- a/tests/unit/handlers/base_handler_test.py +++ b/tests/unit/handlers/base_handler_test.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, Mock from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, + DataHandlerResponse as Response, HandlerStatusResponse as StatusResponse, ) @@ -167,22 +167,6 @@ def get_columns_query(self): """ pass - def test_native_query(self): - """ - Tests the `native_query` method to ensure it executes a SQL query using a mock cursor and returns a Response object. - """ - mock_conn = MagicMock() - mock_cursor = MockCursorContextManager() - - self.handler.connect = MagicMock(return_value=mock_conn) - mock_conn.cursor = MagicMock(return_value=mock_cursor) - - query_str = f"SELECT * FROM {self.mock_table}" - data = self.handler.native_query(query_str) - - assert isinstance(data, Response) - self.assertFalse(data.error_code) - def test_get_columns(self): """ Tests if the `get_tables` method calls `native_query` with the correct SQL query. diff --git a/tests/unit/handlers/test_bigquery.py b/tests/unit/handlers/test_bigquery.py index 1bb69de1a11..2c48c87428a 100644 --- a/tests/unit/handlers/test_bigquery.py +++ b/tests/unit/handlers/test_bigquery.py @@ -6,9 +6,11 @@ from google.api_core.exceptions import BadRequest from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, + DataHandlerResponse, HandlerStatusResponse as StatusResponse, RESPONSE_TYPE, + TableResponse, + ErrorResponse, ) try: @@ -87,7 +89,7 @@ def test_native_query(self): mock_query_job_config_instance = mock_query_job_config.return_value data = self.handler.native_query(query_str) mock_conn.query.assert_called_once_with(query_str, job_config=mock_query_job_config_instance) - assert isinstance(data, Response) + assert isinstance(data, DataHandlerResponse) self.assertFalse(data.error_code) def test_get_tables(self): @@ -124,7 +126,7 @@ def test_get_columns(self): self.handler.native_query.assert_called_once_with(expected_query) def test_meta_get_tables_filters(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.meta_get_tables(table_names=["orders"]) @@ -132,7 +134,7 @@ def test_meta_get_tables_filters(self): self.assertIn("AND t.table_name IN ('orders')", query) def test_meta_get_columns_filters(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.meta_get_columns(table_names=["orders"]) @@ -176,9 +178,9 @@ def test_meta_get_column_statistics_batches_results(self): self.handler.native_query = MagicMock( side_effect=[ - Response(RESPONSE_TYPE.TABLE, data_frame=column_types_result), - Response(RESPONSE_TYPE.TABLE, data_frame=first_batch_result), - Response(RESPONSE_TYPE.TABLE, data_frame=second_batch_result), + TableResponse(data=column_types_result), + TableResponse(data=first_batch_result), + TableResponse(data=second_batch_result), ] ) @@ -189,20 +191,20 @@ def test_meta_get_column_statistics_batches_results(self): self.assertEqual(self.handler.native_query.call_count, 3) # 1 for column types + 2 for batches def test_meta_get_column_statistics_returns_error_when_empty(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.ERROR, error_message="boom")) + self.handler.native_query = MagicMock(return_value=ErrorResponse(error_message="boom")) response = self.handler.meta_get_column_statistics_for_table("table", ["col"]) self.assertEqual(response.resp_type, RESPONSE_TYPE.ERROR) def test_meta_get_primary_keys_filters(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.meta_get_primary_keys(table_names=["orders"]) query = self.handler.native_query.call_args[0][0] self.assertIn("AND tc.table_name IN ('orders')", query) def test_meta_get_foreign_keys_filters(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.meta_get_foreign_keys(table_names=["orders"]) query = self.handler.native_query.call_args[0][0] self.assertIn("AND tc.table_name IN ('orders')", query) diff --git a/tests/unit/handlers/test_clickhouse.py b/tests/unit/handlers/test_clickhouse.py index 404c888a4d7..68ec1d895fd 100644 --- a/tests/unit/handlers/test_clickhouse.py +++ b/tests/unit/handlers/test_clickhouse.py @@ -6,7 +6,8 @@ from sqlalchemy.exc import SQLAlchemyError from mindsdb_sql_parser import parse_sql -from base_handler_test import BaseDatabaseHandlerTest +from base_handler_test import BaseDatabaseHandlerTest, MockCursorContextManager +from mindsdb.integrations.libs.response import TableResponse try: from mindsdb.integrations.handlers.clickhouse_handler.clickhouse_handler import ClickHouseHandler @@ -67,6 +68,21 @@ def test_connect_success(self): f"clickhouse+{self.dummy_connection_data['protocol']}://{self.dummy_connection_data['user']}:{self.dummy_connection_data['password']}@{self.dummy_connection_data['host']}:{self.dummy_connection_data['port']}/{self.dummy_connection_data['database']}" ) + def test_native_query(self): + """ + Tests the `native_query` method to ensure it executes a SQL query using a mock cursor and returns a Response object. + """ + mock_conn = MagicMock() + mock_cursor = MockCursorContextManager() + + self.handler.connect = MagicMock(return_value=mock_conn) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + + query_str = f"SELECT * FROM {self.mock_table}" + data = self.handler.native_query(query_str) + + assert isinstance(data, TableResponse) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/handlers/test_confluence.py b/tests/unit/handlers/test_confluence.py index f5af306caff..7a33a36e3d3 100644 --- a/tests/unit/handlers/test_confluence.py +++ b/tests/unit/handlers/test_confluence.py @@ -15,16 +15,8 @@ ConfluenceWhiteboardsTable, ConfluenceTasksTable, ) -from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, - HandlerStatusResponse as StatusResponse, - RESPONSE_TYPE, -) -from mindsdb.integrations.utilities.sql_utils import ( - FilterCondition, - FilterOperator, - SortColumn, -) +from mindsdb.integrations.libs.response import TableResponse, HandlerStatusResponse as StatusResponse, RESPONSE_TYPE +from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, SortColumn class TestConfluenceHandler(BaseHandlerTestSetup, unittest.TestCase): @@ -103,21 +95,21 @@ def test_check_connection_failure(self): def test_get_tables(self): """ - Test that the `get_tables` method returns a list of table names. + Test that the `get_tables` method returns a TableResponse with a list of table names. """ response = self.handler.get_tables() - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) self.assertEqual(response.data_frame.columns.tolist(), ["table_name", "table_type"]) def test_get_columns(self): """ - Test that the `get_columns` method returns a list of columns for a table. + Test that the `get_columns` method returns a TableResponse with a list of columns for a table. """ response = self.handler.get_columns("spaces") - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) self.assertEqual(response.data_frame.columns.tolist(), ["Field", "Type"]) diff --git a/tests/unit/handlers/test_databricks.py b/tests/unit/handlers/test_databricks.py index df976cc4ce6..9dc2282e90f 100644 --- a/tests/unit/handlers/test_databricks.py +++ b/tests/unit/handlers/test_databricks.py @@ -19,7 +19,10 @@ DATABRICKS_AVAILABLE = False from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, + TableResponse, + ErrorResponse, + OkResponse, + DataHandlerResponse, RESPONSE_TYPE, HandlerStatusResponse as StatusResponse, ) @@ -171,7 +174,7 @@ def tearDown(self): def test_native_query(self): """ - Tests the `native_query` method to ensure it executes a SQL query using a mock cursor and returns a Response object. + Tests the `native_query` method to ensure it executes a SQL query using a mock cursor and returns a TableResponse object. """ self.mock_cursor.set_results([], []) @@ -179,8 +182,8 @@ def test_native_query(self): data = self.handler.native_query(query_str) self.mock_cursor.execute.assert_called_once_with(query_str) - self.assertIsInstance(data, Response) - self.assertFalse(data.error_code) + self.assertIsInstance(data, DataHandlerResponse) + self.assertNotIsInstance(data, ErrorResponse) def test_get_tables(self): """ @@ -241,14 +244,12 @@ def test_native_query_server_error(self): result = self.handler.native_query("SELECT * FROM test_table") - self.assertEqual(result.type, RESPONSE_TYPE.ERROR) + self.assertIsInstance(result, ErrorResponse) self.assertIn("Server error", result.error_message) def test_get_tables_all_schemas(self): """Test get_tables with all=True.""" - self.handler.native_query = MagicMock( - return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame([{"table_name": "t1"}])) - ) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame([{"table_name": "t1"}]))) self.handler.get_tables(all=True) @@ -276,7 +277,7 @@ def test_get_columns_with_schema(self): ] ) - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=mock_df)) + self.handler.native_query = MagicMock(return_value=TableResponse(data=mock_df)) self.handler.get_columns("test_table", schema_name="my_schema") @@ -415,7 +416,7 @@ def test_query_transforms_date_add_day_interval(self): """Test DATE_ADD with INTERVAL DAY is transformed to integer argument.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL '30' DAY) AS due_date FROM orders LIMIT 1") # breakpoint() - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -426,7 +427,7 @@ def test_query_transforms_date_add_day_interval(self): def test_query_transforms_date_add_days_plural(self): """Test DATE_ADD with INTERVAL DAYS (plural) is transformed correctly.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL 7 DAYS) AS due_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -437,7 +438,7 @@ def test_query_transforms_date_add_days_plural(self): def test_query_transforms_date_sub_day_interval(self): """Test DATE_SUB with INTERVAL DAY is transformed to integer argument.""" query = parse_sql("SELECT DATE_SUB(o_orderdate, INTERVAL '5' DAY) AS past_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -448,7 +449,7 @@ def test_query_transforms_date_sub_day_interval(self): def test_query_transforms_date_add_week_interval(self): """Test DATE_ADD with INTERVAL WEEK is converted to days.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL '2' WEEK) AS future_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -459,7 +460,7 @@ def test_query_transforms_date_add_week_interval(self): def test_query_transforms_date_sub_week_interval(self): """Test DATE_SUB with INTERVAL WEEK is converted to days.""" query = parse_sql("SELECT DATE_SUB(o_orderdate, INTERVAL '2' WEEK) AS past_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -470,7 +471,7 @@ def test_query_transforms_date_sub_week_interval(self): def test_query_transforms_date_add_month_interval(self): """Test DATE_ADD with INTERVAL MONTH uses ADD_MONTHS function.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL '2' MONTH) AS future_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -481,7 +482,7 @@ def test_query_transforms_date_add_month_interval(self): def test_query_transforms_date_sub_month_interval(self): """Test DATE_SUB with INTERVAL MONTH uses ADD_MONTHS with negative value.""" query = parse_sql("SELECT DATE_SUB(o_orderdate, INTERVAL '3' MONTH) AS past_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -492,7 +493,7 @@ def test_query_transforms_date_sub_month_interval(self): def test_query_transforms_date_add_year_interval(self): """Test DATE_ADD with INTERVAL YEAR uses ADD_MONTHS with 12x multiplier.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL '1' YEAR) AS future_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -503,7 +504,7 @@ def test_query_transforms_date_add_year_interval(self): def test_query_transforms_date_sub_year_interval(self): """Test DATE_SUB with INTERVAL YEAR uses ADD_MONTHS with negative 12x value.""" query = parse_sql("SELECT DATE_SUB(o_orderdate, INTERVAL '2' YEAR) AS past_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -514,7 +515,7 @@ def test_query_transforms_date_sub_year_interval(self): def test_query_transforms_date_add_hour_interval(self): """Test DATE_ADD with INTERVAL HOUR uses TIMESTAMPADD function.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL '6' HOUR) AS future_time FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -525,7 +526,7 @@ def test_query_transforms_date_add_hour_interval(self): def test_query_transforms_date_sub_hour_interval(self): """Test DATE_SUB with INTERVAL HOUR uses TIMESTAMPADD with negative value.""" query = parse_sql("SELECT DATE_SUB(o_orderdate, INTERVAL '3' HOUR) AS past_time FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -536,7 +537,7 @@ def test_query_transforms_date_sub_hour_interval(self): def test_query_transforms_date_add_minute_interval(self): """Test DATE_ADD with INTERVAL MINUTE uses TIMESTAMPADD function.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL '30' MINUTE) AS future_time FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -547,7 +548,7 @@ def test_query_transforms_date_add_minute_interval(self): def test_query_transforms_date_add_second_interval(self): """Test DATE_ADD with INTERVAL SECOND uses TIMESTAMPADD function.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL '45' SECOND) AS future_time FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -558,7 +559,7 @@ def test_query_transforms_date_add_second_interval(self): def test_query_without_interval_unchanged(self): """Test that queries without INTERVAL pass through unchanged.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, 10) AS future_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -569,7 +570,7 @@ def test_query_without_interval_unchanged(self): def test_query_transforms_date_add_quarter_interval(self): """Test DATE_ADD with INTERVAL QUARTER uses ADD_MONTHS with 3x multiplier.""" query = parse_sql("SELECT DATE_ADD(o_orderdate, INTERVAL '2' QUARTER) AS future_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -580,7 +581,7 @@ def test_query_transforms_date_add_quarter_interval(self): def test_query_transforms_date_sub_quarter_interval(self): """Test DATE_SUB with INTERVAL QUARTER uses ADD_MONTHS with negative 3x value.""" query = parse_sql("SELECT DATE_SUB(o_orderdate, INTERVAL '1' QUARTER) AS past_date FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -591,7 +592,7 @@ def test_query_transforms_date_sub_quarter_interval(self): def test_query_transforms_date_sub_minute_interval(self): """Test DATE_SUB with INTERVAL MINUTE uses TIMESTAMPADD with negative value.""" query = parse_sql("SELECT DATE_SUB(o_orderdate, INTERVAL '15' MINUTE) AS past_time FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) @@ -602,7 +603,7 @@ def test_query_transforms_date_sub_minute_interval(self): def test_query_transforms_date_sub_second_interval(self): """Test DATE_SUB with INTERVAL SECOND uses TIMESTAMPADD with negative value.""" query = parse_sql("SELECT DATE_SUB(o_orderdate, INTERVAL '30' SECOND) AS past_time FROM orders") - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.OK)) + self.handler.native_query = MagicMock(return_value=OkResponse()) self.handler.query(query) diff --git a/tests/unit/handlers/test_dynamodb.py b/tests/unit/handlers/test_dynamodb.py index f1aef2481b1..6811afa40e7 100644 --- a/tests/unit/handlers/test_dynamodb.py +++ b/tests/unit/handlers/test_dynamodb.py @@ -8,29 +8,24 @@ from mindsdb_sql_parser.ast.select.identifier import Identifier from base_handler_test import BaseHandlerTestSetup -from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, - HandlerStatusResponse as StatusResponse, - RESPONSE_TYPE -) +from mindsdb.integrations.libs.response import TableResponse, HandlerStatusResponse as StatusResponse, RESPONSE_TYPE from mindsdb.integrations.handlers.dynamodb_handler.dynamodb_handler import DynamoDBHandler class TestDynamoDBHandler(BaseHandlerTestSetup, unittest.TestCase): - @property def dummy_connection_data(self): return OrderedDict( - aws_access_key_id='AQAXEQK89OX07YS34OP', - aws_secret_access_key='wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', - region_name='us-east-2', + aws_access_key_id="AQAXEQK89OX07YS34OP", + aws_secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + region_name="us-east-2", ) def create_handler(self): - return DynamoDBHandler('dynamodb', connection_data=self.dummy_connection_data) + return DynamoDBHandler("dynamodb", connection_data=self.dummy_connection_data) def create_patcher(self): - return patch('boto3.client') + return patch("boto3.client") def test_connect_failure_with_missing_connection_data(self): """ @@ -58,8 +53,8 @@ def test_check_connection_failure_with_incorrect_credentials(self): Test if the `check_connection` method returns a StatusResponse object and accurately reflects the connection status on failed connection due to incorrect credentials. """ self.mock_connect.return_value.list_tables.side_effect = ClientError( - error_response={'Error': {'Code': 'AccessDeniedException', 'Message': 'Access Denied'}}, - operation_name='list_tables' + error_response={"Error": {"Code": "AccessDeniedException", "Message": "Access Denied"}}, + operation_name="list_tables", ) response = self.handler.check_connection() @@ -72,7 +67,7 @@ def test_check_connection_success(self): """ Test if the `check_connection` method returns a StatusResponse object and accurately reflects the connection status on a successful connection. """ - self.mock_connect.return_value.list_tables.return_value = {'TableNames': ['table1', 'table2']} + self.mock_connect.return_value.list_tables.return_value = {"TableNames": ["table1", "table2"]} response = self.handler.check_connection() self.assertTrue(response.success) @@ -81,15 +76,12 @@ def test_check_connection_success(self): def test_query_select_success(self): """ - Test if the `query` method returns a response object with a data frame containing the query result. + Test if the `query` method returns a TableResponse object with a data frame containing the query result. `native_query` cannot be tested directly because it depends on some pre-processing steps handled by the `query` method. """ mock_boto3_client = Mock() mock_boto3_client.execute_statement.return_value = { - 'Items': [ - {'id': {'N': '1'}, 'name': {'S': 'Alice'}}, - {'id': {'N': '2'}, 'name': {'S': 'Bob'}} - ] + "Items": [{"id": {"N": "1"}, "name": {"S": "Alice"}}, {"id": {"N": "2"}, "name": {"S": "Bob"}}] } self.handler.connect = MagicMock(return_value=mock_boto3_client) @@ -97,18 +89,18 @@ def test_query_select_success(self): targets=[ Star(), ], - from_table=ast.Identifier('table1') + from_table=ast.Identifier("table1"), ) response = self.handler.query(query) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame self.assertEqual(len(df), 2) - self.assertEqual(df.columns.tolist(), ['id', 'name']) - self.assertEqual(df['id'].tolist(), [1, 2]) - self.assertEqual(df['name'].tolist(), ['Alice', 'Bob']) + self.assertEqual(df.columns.tolist(), ["id", "name"]) + self.assertEqual(df["id"].tolist(), [1, 2]) + self.assertEqual(df["name"].tolist(), ["Alice", "Bob"]) def test_query_select_failure_with_unsupported_clause(self): """ @@ -118,8 +110,8 @@ def test_query_select_failure_with_unsupported_clause(self): targets=[ Star(), ], - from_table=ast.Identifier('table1'), - limit=10 + from_table=ast.Identifier("table1"), + limit=10, ) with self.assertRaises(ValueError): self.handler.query(query) @@ -132,62 +124,58 @@ def test_query_insert_failure(self): mock_boto3_client.execute_statement.return_value = {} self.handler.connect = MagicMock(return_value=mock_boto3_client) - query = ast.Insert( - table=Identifier('table1'), - columns=['id', 'name'], - values=[[1, 'Alice']] - ) + query = ast.Insert(table=Identifier("table1"), columns=["id", "name"], values=[[1, "Alice"]]) with self.assertRaises(ValueError): self.handler.query(query) def test_get_tables(self): """ - Test if the `get_tables` method returns a response object with a list of tables. + Test if the `get_tables` method returns a TableResponse object with a list of tables. """ mock_boto3_client = Mock() - mock_boto3_client.list_tables.return_value = {'TableNames': ['table1', 'table2']} + mock_boto3_client.list_tables.return_value = {"TableNames": ["table1", "table2"]} self.handler.connection = mock_boto3_client response = self.handler.get_tables() - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame self.assertEqual(len(df), 2) - self.assertEqual(df.columns.tolist(), ['table_name']) - self.assertEqual(df['table_name'].tolist(), ['table1', 'table2']) + self.assertEqual(df.columns.tolist(), ["table_name"]) + self.assertEqual(df["table_name"].tolist(), ["table1", "table2"]) def test_get_columns(self): """ - Test if the `get_columns` method returns a response object with a list of columns for a given table. + Test if the `get_columns` method returns a TableResponse object with a list of columns for a given table. """ mock_boto3_client = Mock() mock_boto3_client.describe_table.return_value = { - 'Table': { - 'KeySchema': [ - {'AttributeName': 'id', 'KeyType': 'HASH'}, - {'AttributeName': 'name', 'KeyType': 'RANGE'} + "Table": { + "KeySchema": [ + {"AttributeName": "id", "KeyType": "HASH"}, + {"AttributeName": "name", "KeyType": "RANGE"}, + ], + "AttributeDefinitions": [ + {"AttributeName": "id", "AttributeType": "N"}, + {"AttributeName": "name", "AttributeType": "S"}, ], - 'AttributeDefinitions': [ - {'AttributeName': 'id', 'AttributeType': 'N'}, - {'AttributeName': 'name', 'AttributeType': 'S'} - ] } } self.handler.connection = mock_boto3_client - response = self.handler.get_columns('table1') + response = self.handler.get_columns("table1") - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame self.assertEqual(len(df), 2) - self.assertEqual(df.columns.tolist(), ['column_name', 'data_type']) - self.assertEqual(df['column_name'].tolist(), ['id', 'name']) - self.assertEqual(df['data_type'].tolist(), ['N', 'S']) + self.assertEqual(df.columns.tolist(), ["column_name", "data_type"]) + self.assertEqual(df["column_name"].tolist(), ["id", "name"]) + self.assertEqual(df["data_type"].tolist(), ["N", "S"]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/handlers/test_mariadb.py b/tests/unit/handlers/test_mariadb.py index be2cc4f6120..ecc5fdbfd46 100644 --- a/tests/unit/handlers/test_mariadb.py +++ b/tests/unit/handlers/test_mariadb.py @@ -6,19 +6,18 @@ from base_handler_test import BaseDatabaseHandlerTest, MockCursorContextManager from mindsdb.integrations.handlers.mariadb_handler.mariadb_handler import MariaDBHandler -from mindsdb.integrations.libs.response import HandlerResponse as Response +from mindsdb.integrations.libs.response import TableResponse class TestMariaDBHandler(BaseDatabaseHandlerTest, unittest.TestCase): - @property def dummy_connection_data(self): return OrderedDict( - host='127.0.0.1', + host="127.0.0.1", port=3307, - user='example_user', - password='example_pass', - database='example_db', + user="example_user", + password="example_pass", + database="example_db", ) @property @@ -64,18 +63,16 @@ def get_columns_query(self): """ def create_handler(self): - return MariaDBHandler('mariadb', connection_data=self.dummy_connection_data) + return MariaDBHandler("mariadb", connection_data=self.dummy_connection_data) def create_patcher(self): - return patch('mysql.connector.connect') + return patch("mysql.connector.connect") def test_native_query(self): - """Test that native_query returns a Response object with no error - """ + """Test that native_query returns a TableResponse object with no error""" mock_conn = MagicMock() mock_cursor = MockCursorContextManager( - data=[{'id': 1}], - description=[('id', 3, None, None, None, None, 1, 0, 45)] + data=[{"id": 1}], description=[("id", 3, None, None, None, None, 1, 0, 45)] ) self.handler.connect = MagicMock(return_value=mock_conn) @@ -84,9 +81,8 @@ def test_native_query(self): query_str = f"SELECT * FROM {self.mock_table}" data = self.handler.native_query(query_str) - self.assertIsInstance(data, Response) - self.assertFalse(data.error_code) + self.assertIsInstance(data, TableResponse) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/handlers/test_mongodb.py b/tests/unit/handlers/test_mongodb.py index 36ae12c8479..1939fdb6342 100644 --- a/tests/unit/handlers/test_mongodb.py +++ b/tests/unit/handlers/test_mongodb.py @@ -11,7 +11,9 @@ from base_handler_test import BaseHandlerTestSetup from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, + TableResponse, + OkResponse, + ErrorResponse, HandlerStatusResponse as StatusResponse, RESPONSE_TYPE, ) @@ -88,7 +90,7 @@ def test_check_connection_success(self): def test_query_failure_with_non_existent_collection(self): """ - Test if the `query` method returns a response object with an error message on failed query due to non-existent collection. + Test if the `query` method returns an ErrorResponse object with an error message on failed query due to non-existent collection. """ self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [ "movies" @@ -103,7 +105,7 @@ def test_query_failure_with_non_existent_collection(self): response = self.handler.query(query) - assert isinstance(response, Response) + assert isinstance(response, ErrorResponse) self.assertEqual(response.type, RESPONSE_TYPE.ERROR) self.assertTrue(response.error_message) @@ -139,7 +141,7 @@ def test_query_failure_with_unsupported_operation(self): def test_query_select_success(self): """ - Test if the `query` method returns a response object with a data frame containing the query result. + Test if the `query` method returns a TableResponse object with a data frame containing the query result. `native_query` cannot be tested directly because it depends on some pre-processing steps handled by the `query` method. """ self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [ @@ -164,7 +166,7 @@ def test_query_select_success(self): response = self.handler.query(query) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -174,7 +176,7 @@ def test_query_select_success(self): def test_query_update_success(self): """ - Test if the `query` method returns a response object with a 'OK' status. + Test if the `query` method returns an OkResponse object with a 'OK' status. `native_query` cannot be tested directly because it depends on some pre-processing steps handled by the `query` method. """ self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [ @@ -201,12 +203,12 @@ def test_query_update_success(self): response = self.handler.query(query) - assert isinstance(response, Response) + assert isinstance(response, OkResponse) self.assertEqual(response.type, RESPONSE_TYPE.OK) def test_get_tables(self): """ - Tests the `get_tables` method returns a response object with a list of tables (collections) in the database. + Tests the `get_tables` method returns a TableResponse object with a list of tables (collections) in the database. """ self.mock_connect.return_value[self.dummy_connection_data["database"]].list_collection_names.return_value = [ "theaters", @@ -219,7 +221,7 @@ def test_get_tables(self): response = self.handler.get_tables() - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -232,7 +234,7 @@ def test_get_tables(self): def test_get_columns(self): """ - Tests the `get_columns` method returns a response object with a list of columns (fields) for a given table (collection). + Tests the `get_columns` method returns a TableResponse object with a list of columns (fields) for a given table (collection). """ self.mock_connect.return_value[self.dummy_connection_data["database"]]["movies"].find_one.return_value = { "_id": ObjectId("5f5b3f3b3f3b3f3b3f3b3f3b"), @@ -243,7 +245,7 @@ def test_get_columns(self): response = self.handler.get_columns("movies") - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -287,7 +289,7 @@ def test_query_select_with_subquery_success(self): response = self.handler.query(main_query) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -353,7 +355,7 @@ def test_query_select_with_complex_subquery_success(self): response = self.handler.query(main_query) - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -388,7 +390,7 @@ def test_query_select_with_where_operators(self): response = self.handler.query(query) - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -431,7 +433,7 @@ def test_query_select_with_and_or_conditions(self): response = self.handler.query(query) - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -494,7 +496,7 @@ def test_select_with_match_and_projection(self): response = self.handler.query(query) - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -525,7 +527,7 @@ def test_select_constant_with_alias(self): response = self.handler.query(query) - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -557,7 +559,7 @@ def test_select_with_constant_no_alias(self): response = self.handler.query(query) - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -604,7 +606,7 @@ def test_query_select_with_subquery_and_where(self): response = self.handler.query(main_query) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -643,7 +645,7 @@ def test_query_select_nested_field_projection(self): response = self.handler.query(query) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -689,7 +691,7 @@ def test_query_select_nested_field_with_where(self): response = self.handler.query(query) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -725,7 +727,7 @@ def test_query_aggregation_on_nested_field(self): response = self.handler.query(query) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -772,7 +774,7 @@ def test_query_group_by_with_nested_aggregation(self): response = self.handler.query(query) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame diff --git a/tests/unit/handlers/test_mssql.py b/tests/unit/handlers/test_mssql.py index dbb097754f9..37e4d06c8f7 100644 --- a/tests/unit/handlers/test_mssql.py +++ b/tests/unit/handlers/test_mssql.py @@ -17,7 +17,13 @@ from pandas import DataFrame from base_handler_test import BaseDatabaseHandlerTest -from mindsdb.integrations.libs.response import HandlerResponse as Response, INF_SCHEMA_COLUMNS_NAMES_SET, RESPONSE_TYPE +from mindsdb.integrations.libs.response import ( + OkResponse, + TableResponse, + ErrorResponse, + INF_SCHEMA_COLUMNS_NAMES_SET, + RESPONSE_TYPE, +) from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE @@ -91,8 +97,7 @@ def test_native_query_with_results(self): mock_conn.cursor.assert_called_once_with(as_dict=True) mock_cursor.execute.assert_called_once_with(query_str) - assert isinstance(data, Response) - self.assertFalse(data.error_code) + assert isinstance(data, TableResponse) self.assertEqual(data.type, RESPONSE_TYPE.TABLE) self.assertIsInstance(data.data_frame, DataFrame) expected_columns = ["id", "name"] @@ -121,8 +126,7 @@ def test_native_query_no_results(self): mock_conn.cursor.assert_called_once_with(as_dict=True) mock_cursor.execute.assert_called_once_with(query_str) - assert isinstance(data, Response) - self.assertFalse(data.error_code) + assert isinstance(data, OkResponse) self.assertEqual(data.type, RESPONSE_TYPE.OK) mock_conn.commit.assert_called_once() @@ -149,7 +153,7 @@ def test_native_query_error(self): mock_conn.cursor.assert_called_once_with(as_dict=True) mock_cursor.execute.assert_called_once_with(query_str) - assert isinstance(data, Response) + assert isinstance(data, ErrorResponse) self.assertEqual(data.type, RESPONSE_TYPE.ERROR) self.assertEqual(data.error_message, str(error)) @@ -166,7 +170,7 @@ def test_query_method(self): try: self.handler.renderer = renderer_mock self.handler.native_query = MagicMock() - self.handler.native_query.return_value = Response(RESPONSE_TYPE.OK) + self.handler.native_query.return_value = OkResponse() mock_ast = MagicMock() result = self.handler.query(mock_ast) @@ -180,7 +184,7 @@ def test_get_tables(self): """ Tests that get_tables calls native_query with the correct SQL """ - expected_response = Response(RESPONSE_TYPE.OK) + expected_response = OkResponse() self.handler.native_query = MagicMock(return_value=expected_response) response = self.handler.get_tables() @@ -199,9 +203,7 @@ def test_get_columns(self): """ Tests that get_columns calls native_query with the correct SQL """ - expected_response = Response( - RESPONSE_TYPE.TABLE, data_frame=DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET)) - ) + expected_response = TableResponse(data=DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET))) self.handler.native_query = MagicMock(return_value=expected_response) table_name = "test_table" @@ -259,7 +261,7 @@ def test_meta_get_tables_returns_response(self): }, ] ) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + expected_response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=expected_response) # without filter @@ -271,7 +273,7 @@ def test_meta_get_tables_returns_response(self): self.handler.native_query.reset_mock() tables = ["customers", "orders"] filtered_df = df[df["table_name"].isin(tables)].reset_index(drop=True) - filtered_response = Response(RESPONSE_TYPE.TABLE, data_frame=filtered_df) + filtered_response = TableResponse(data=filtered_df) self.handler.native_query = MagicMock(return_value=filtered_response) response = self.handler.meta_get_tables(table_names=tables) self.handler.native_query.assert_called_once() @@ -307,7 +309,7 @@ def test_meta_get_columns_returns_response(self): }, ] ) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + expected_response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=expected_response) # without filter @@ -319,7 +321,7 @@ def test_meta_get_columns_returns_response(self): self.handler.native_query.reset_mock() tables = ["customers"] filtered_df = df[df["table_name"].isin(tables)].reset_index(drop=True) - filtered_response = Response(RESPONSE_TYPE.TABLE, data_frame=filtered_df) + filtered_response = TableResponse(data=filtered_df) self.handler.native_query = MagicMock(return_value=filtered_response) response = self.handler.meta_get_columns(table_names=tables) self.handler.native_query.assert_called_once() @@ -351,7 +353,7 @@ def test_meta_get_column_statistics_returns_response(self): }, ] ) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + expected_response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=expected_response) # without filter @@ -363,7 +365,7 @@ def test_meta_get_column_statistics_returns_response(self): self.handler.native_query.reset_mock() tables = ["customers"] filtered_df = df[df["TABLE_NAME"].isin(tables)].reset_index(drop=True) - filtered_response = Response(RESPONSE_TYPE.TABLE, data_frame=filtered_df) + filtered_response = TableResponse(data=filtered_df) self.handler.native_query = MagicMock(return_value=filtered_response) response = self.handler.meta_get_column_statistics(table_names=tables) self.handler.native_query.assert_called_once() @@ -382,7 +384,7 @@ def test_meta_get_primary_keys_returns_response(self): {"table_name": "orders", "column_name": "id", "ordinal_position": 1, "constraint_name": "pk_orders"}, ] ) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + expected_response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=expected_response) # without filter @@ -394,7 +396,7 @@ def test_meta_get_primary_keys_returns_response(self): self.handler.native_query.reset_mock() tables = ["customers"] filtered_df = df[df["table_name"].isin(tables)].reset_index(drop=True) - filtered_response = Response(RESPONSE_TYPE.TABLE, data_frame=filtered_df) + filtered_response = TableResponse(data=filtered_df) self.handler.native_query = MagicMock(return_value=filtered_response) response = self.handler.meta_get_primary_keys(table_names=tables) self.handler.native_query.assert_called_once() @@ -420,7 +422,7 @@ def test_meta_get_foreign_keys_returns_response(self): }, ] ) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + expected_response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=expected_response) # without filter @@ -432,7 +434,7 @@ def test_meta_get_foreign_keys_returns_response(self): self.handler.native_query.reset_mock() tables = ["orders"] filtered_df = df[df["child_table_name"].isin(tables)].reset_index(drop=True) - filtered_response = Response(RESPONSE_TYPE.TABLE, data_frame=filtered_df) + filtered_response = TableResponse(data=filtered_df) self.handler.native_query = MagicMock(return_value=filtered_response) response = self.handler.meta_get_foreign_keys(table_names=tables) self.handler.native_query.assert_called_once() @@ -521,7 +523,7 @@ def test_meta_methods_result_shape_and_exceptions(self): for name, df_factory, method in methods: with self.subTest(method=name, case="no_filter"): df = df_factory() - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + expected_response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=expected_response) res = method() self.handler.native_query.assert_called_once() @@ -533,7 +535,7 @@ def test_meta_methods_result_shape_and_exceptions(self): with self.subTest(method=name, case="with_filter"): df = df_factory() - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + expected_response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=expected_response) res = ( method(table_names=["A", "B"]) @@ -726,7 +728,7 @@ def test_types_casting(self): ("n_real", 3, None, None, None, None, None), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str) excepted_mysql_types = [ MYSQL_DATA_TYPE.TINYINT, MYSQL_DATA_TYPE.INT, @@ -741,7 +743,7 @@ def test_types_casting(self): MYSQL_DATA_TYPE.FLOAT, MYSQL_DATA_TYPE.FLOAT, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + self.assertEqual([col.type for col in response.columns], excepted_mysql_types) for columns_name, input_value in input_row.items(): result_value = response.data_frame[columns_name][0] self.assertEqual(result_value, input_value) @@ -818,7 +820,7 @@ def test_types_casting(self): ("t_uniqueidentifier", 2, None, None, None, None, None), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str) excepted_mysql_types = [ MYSQL_DATA_TYPE.TEXT, MYSQL_DATA_TYPE.TEXT, @@ -832,7 +834,7 @@ def test_types_casting(self): MYSQL_DATA_TYPE.TEXT, MYSQL_DATA_TYPE.BINARY, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + self.assertEqual([col.type for col in response.columns], excepted_mysql_types) for columns_name, input_value in input_row.items(): result_value = response.data_frame[columns_name][0] self.assertEqual(result_value, input_value) @@ -901,7 +903,7 @@ def test_types_casting(self): ("d_datetimeoffset_p", 2, None, None, None, None, None), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str) excepted_mysql_types = [ # DATE and TIME is not possible to infer, so they are BINARY MYSQL_DATA_TYPE.BINARY, @@ -914,7 +916,7 @@ def test_types_casting(self): MYSQL_DATA_TYPE.DATETIME, MYSQL_DATA_TYPE.DATETIME, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + self.assertEqual([col.type for col in response.columns], excepted_mysql_types) for columns_name, input_value in input_row.items(): result_value = response.data_frame[columns_name][0] if columns_name == "d_datetimeoffset_p": @@ -1099,7 +1101,7 @@ def __getitem__(self, idx): mock_conn.cursor.assert_called_once_with() mock_cursor.execute.assert_called_once_with(query_str) - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) self.assertIsInstance(response.data_frame, DataFrame) self.assertEqual(list(response.data_frame.columns), ["id", "name"]) @@ -1168,10 +1170,10 @@ def __getitem__(self, idx): response = handler.native_query("SELECT * FROM test") - self.assertIsInstance(response, Response) + self.assertIsInstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) - self.assertIsNotNone(response.mysql_types) - self.assertTrue(len(response.mysql_types) > 0) + self.assertIsNotNone(response.columns) + self.assertTrue(len(response.columns) > 0) finally: if "pyodbc" in sys.modules: del sys.modules["pyodbc"] diff --git a/tests/unit/handlers/test_mysql.py b/tests/unit/handlers/test_mysql.py index bbb3ab93e56..065cc66896f 100644 --- a/tests/unit/handlers/test_mysql.py +++ b/tests/unit/handlers/test_mysql.py @@ -12,7 +12,13 @@ from base_handler_test import BaseDatabaseHandlerTest, MockCursorContextManager from mindsdb.integrations.handlers.mysql_handler.mysql_handler import MySQLHandler -from mindsdb.integrations.libs.response import HandlerResponse as Response, INF_SCHEMA_COLUMNS_NAMES_SET, RESPONSE_TYPE +from mindsdb.integrations.libs.response import ( + OkResponse, + TableResponse, + DataHandlerResponse as Response, + INF_SCHEMA_COLUMNS_NAMES_SET, + RESPONSE_TYPE, +) from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE @@ -89,13 +95,12 @@ def test_native_query(self): query_str = f"SELECT * FROM {self.mock_table}" data = self.handler.native_query(query_str) - self.assertIsInstance(data, Response) - self.assertFalse(data.error_code) + self.assertIsInstance(data, TableResponse) def test_native_query_with_results(self): """ Tests the `native_query` method to ensure it executes a SQL query and handles the case - where the query returns a result set + where the query returns a result set, streaming data via fetchmany """ mock_conn = MagicMock() mock_cursor = MagicMock() @@ -106,7 +111,11 @@ def test_native_query_with_results(self): mock_conn.cursor = MagicMock(return_value=mock_cursor) mock_conn.is_connected = MagicMock(return_value=True) - mock_cursor.fetchall.return_value = [{"id": 1, "name": "test1"}, {"id": 2, "name": "test2"}] + # fetchmany returns tuples (non-dictionary cursor), then empty list to signal end + mock_cursor.fetchmany.side_effect = [ + [(1, "test1"), (2, "test2")], + [], + ] # MySQL cursor provides column info via description attribute mock_cursor.description = [ @@ -119,12 +128,10 @@ def test_native_query_with_results(self): query_str = "SELECT * FROM test_table" data = self.handler.native_query(query_str) - mock_conn.cursor.assert_called_once_with(dictionary=True, buffered=True) + mock_conn.cursor.assert_called_once_with(buffered=False) mock_cursor.execute.assert_called_once_with(query_str) - assert isinstance(data, Response) - self.assertFalse(data.error_code) - self.assertEqual(data.type, RESPONSE_TYPE.TABLE) + assert isinstance(data, TableResponse) self.assertIsInstance(data.data_frame, DataFrame) expected_columns = ["id", "name"] @@ -150,12 +157,10 @@ def test_native_query_no_results(self): query_str = "INSERT INTO test_table VALUES (1, 'test')" data = self.handler.native_query(query_str) - mock_conn.cursor.assert_called_once_with(dictionary=True, buffered=True) + mock_conn.cursor.assert_called_once_with(buffered=False) mock_cursor.execute.assert_called_once_with(query_str) - assert isinstance(data, Response) - self.assertFalse(data.error_code) - self.assertEqual(data.type, RESPONSE_TYPE.OK) + assert isinstance(data, OkResponse) self.assertEqual(data.affected_rows, 1) def test_native_query_error(self): @@ -178,7 +183,7 @@ def test_native_query_error(self): query_str = "INVALID SQL" data = self.handler.native_query(query_str) - mock_conn.cursor.assert_called_once_with(dictionary=True, buffered=True) + mock_conn.cursor.assert_called_once_with(buffered=False) mock_cursor.execute.assert_called_once_with(query_str) assert isinstance(data, Response) @@ -377,7 +382,7 @@ def test_query_method(self): mock_renderer_class.return_value = mock_renderer self.handler.native_query = MagicMock() - self.handler.native_query.return_value = Response(RESPONSE_TYPE.OK) + self.handler.native_query.return_value = OkResponse() mock_ast = MagicMock() @@ -406,7 +411,7 @@ def test_get_tables(self): """ Tests that get_tables calls native_query with the correct SQL """ - expected_response = Response(RESPONSE_TYPE.OK) + expected_response = OkResponse() self.handler.native_query = MagicMock(return_value=expected_response) response = self.handler.get_tables() @@ -425,9 +430,7 @@ def test_get_columns(self): """ Tests that get_columns calls native_query with the correct SQL """ - expected_response = Response( - RESPONSE_TYPE.TABLE, data_frame=DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET)) - ) + expected_response = TableResponse(data=DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET))) self.handler.native_query = MagicMock(return_value=expected_response) table_name = "test_table" @@ -473,19 +476,19 @@ def test_types_casting(self): mock_conn.is_connected = MagicMock(return_value=True) # region test TEXT/BLOB types and sub-types - input_row = { - "t_varchar": "v_varchar", - "t_tinytext": "v_tinytext", - "t_text": "v_text", - "t_mediumtext": "v_mediumtext", - "t_longtext": "v_longtext", - "t_tinyblon": "v_tinyblon", - "t_blob": "v_blob", - "t_mediumblob": "v_mediumblob", - "t_longblob": "v_longblob", - "t_json": '{"key": "value"}', - } - mock_cursor.fetchall.return_value = [input_row] + input_row = OrderedDict( + t_varchar="v_varchar", + t_tinytext="v_tinytext", + t_text="v_text", + t_mediumtext="v_mediumtext", + t_longtext="v_longtext", + t_tinyblon="v_tinyblon", + t_blob="v_blob", + t_mediumblob="v_mediumblob", + t_longblob="v_longblob", + t_json='{"key": "value"}', + ) + mock_cursor.fetchall.return_value = [list(input_row.values())] mock_cursor.description = [ ("t_varchar", 253, None, None, None, None, 1, 0, 45), @@ -500,7 +503,7 @@ def test_types_casting(self): ("t_json", 245, None, None, None, None, 1, 144, 63), ] - response: Response = self.handler.native_query(query_str) + response: Response = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [ MYSQL_DATA_TYPE.VARBINARY, MYSQL_DATA_TYPE.TEXT, @@ -513,7 +516,8 @@ def test_types_casting(self): MYSQL_DATA_TYPE.BLOB, MYSQL_DATA_TYPE.JSON, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) for key, input_value in input_row.items(): result_value = response.data_frame[key][0] self.assertEqual(type(result_value), type(input_value)) @@ -521,17 +525,18 @@ def test_types_casting(self): # endregion # region test TINYINT/BOOL/BOOLEAN types - input_row = {"t_tinyint": 1, "t_bool": 1, "t_boolean": 1} - mock_cursor.fetchall.return_value = [input_row] + input_row = OrderedDict(t_tinyint=1, t_bool=1, t_boolean=1) + mock_cursor.fetchall.return_value = [list(input_row.values())] mock_cursor.description = [ ("t_tinyint", 1, None, None, None, None, 1, 0, 63), ("t_bool", 1, None, None, None, None, 1, 0, 63), ("t_boolean", 1, None, None, None, None, 1, 0, 63), ] - response: Response = self.handler.native_query(query_str) + response: Response = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [MYSQL_DATA_TYPE.TINYINT, MYSQL_DATA_TYPE.TINYINT, MYSQL_DATA_TYPE.TINYINT] - self.assertEqual(response.mysql_types, excepted_mysql_types) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) for key, input_value in input_row.items(): result_value = response.data_frame[key][0] # without None values in result columns types will be one of pandas types @@ -540,19 +545,19 @@ def test_types_casting(self): # endregion # region test numeric types - input_row = { - "t_tinyint": 1, - "t_bool": 0, - "t_smallint": 2, - "t_year": 2025, - "t_mediumint": 3, - "t_int": 4, - "t_bigint": 5, - "t_float": 1.1, - "t_double": 2.2, - "t_decimal": Decimal("3.3"), - } - mock_cursor.fetchall.return_value = [input_row] + input_row = OrderedDict( + t_tinyint=1, + t_bool=0, + t_smallint=2, + t_year=2025, + t_mediumint=3, + t_int=4, + t_bigint=5, + t_float=1.1, + t_double=2.2, + t_decimal=Decimal("3.3"), + ) + mock_cursor.fetchall.return_value = [list(input_row.values())] mock_cursor.description = [ ("t_tinyint", 1, None, None, None, None, 1, 0, 63), ("t_bool", 1, None, None, None, None, 1, 0, 63), @@ -565,7 +570,7 @@ def test_types_casting(self): ("t_double", 5, None, None, None, None, 1, 0, 63), ("t_decimal", 246, None, None, None, None, 1, 0, 63), ] - response: Response = self.handler.native_query(query_str) + response: Response = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [ MYSQL_DATA_TYPE.TINYINT, MYSQL_DATA_TYPE.TINYINT, @@ -579,21 +584,22 @@ def test_types_casting(self): MYSQL_DATA_TYPE.DECIMAL, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) for key, input_value in input_row.items(): result_value = response.data_frame[key][0] self.assertEqual(result_value, input_value) # endregion # test date/time types - input_row = { - "t_date": datetime.date(2025, 4, 16), - "t_time": datetime.timedelta(seconds=45600), - "t_year": 2025, - "t_datetime": datetime.datetime(2025, 4, 16, 12, 30, 15), - "t_timestamp": datetime.datetime(2025, 4, 16, 12, 30, 15), - } - mock_cursor.fetchall.return_value = [input_row] + input_row = OrderedDict( + t_date=datetime.date(2025, 4, 16), + t_time=datetime.timedelta(seconds=45600), + t_year=2025, + t_datetime=datetime.datetime(2025, 4, 16, 12, 30, 15), + t_timestamp=datetime.datetime(2025, 4, 16, 12, 30, 15), + ) + mock_cursor.fetchall.return_value = [list(input_row.values())] mock_cursor.description = [ ("t_date", 10, None, None, None, None, 1, 128, 63), @@ -603,7 +609,7 @@ def test_types_casting(self): ("t_timestamp", 7, None, None, None, None, 1, 128, 63), ] - response: Response = self.handler.native_query(query_str) + response: Response = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [ MYSQL_DATA_TYPE.DATE, MYSQL_DATA_TYPE.TIME, @@ -611,7 +617,8 @@ def test_types_casting(self): MYSQL_DATA_TYPE.DATETIME, MYSQL_DATA_TYPE.TIMESTAMP, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) for key, input_value in input_row.items(): result_value = response.data_frame[key][0] self.assertEqual(result_value, input_value) @@ -619,14 +626,14 @@ def test_types_casting(self): # region test casting of nullable types bigint_val = 9223372036854775807 - input_rows = [{"t_bigint": bigint_val, "t_boolean": 1}, {"t_bigint": None, "t_boolean": None}] - mock_cursor.fetchall.return_value = input_rows + input_rows = [OrderedDict(t_bigint=bigint_val, t_boolean=1), OrderedDict(t_bigint=None, t_boolean=None)] + mock_cursor.fetchall.return_value = [list(row.values()) for row in input_rows] description = [ ("t_bigint", 8, None, None, None, None, 1, 0, 63), ("t_boolean", 1, None, None, None, None, 1, 0, 63), ] mock_cursor.description = description - response: Response = self.handler.native_query(query_str) + response: Response = self.handler.native_query(query_str, stream=False) self.assertEqual(response.data_frame.dtypes.iloc[0], "Int64") self.assertEqual(response.data_frame.dtypes.iloc[1], "Int64") self.assertEqual(response.data_frame.iloc[0, 0], bigint_val) @@ -636,16 +643,17 @@ def test_types_casting(self): # endregion # region test vector type - input_row = { - "t_vector": array("f", [1.1, 2.2, 3.3]), - } - mock_cursor.fetchall.return_value = [input_row] + input_row = OrderedDict( + t_vector=array("f", [1.1, 2.2, 3.3]), + ) + mock_cursor.fetchall.return_value = [list(input_row.values())] mock_cursor.description = [("t_vector", 242, None, None, None, None, 1, 144, 63)] - response: Response = self.handler.native_query(query_str) + response: Response = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [MYSQL_DATA_TYPE.VECTOR] - self.assertEqual(response.mysql_types, excepted_mysql_types) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) self.assertEqual(input_row["t_vector"], response.data_frame["t_vector"][0]) # endregion @@ -661,7 +669,7 @@ def _test_meta_method_with_filter(self, method, sample_data, filter_column, filt """ # Test without filter df = DataFrame(sample_data) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + expected_response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=expected_response) response = method() @@ -671,7 +679,7 @@ def _test_meta_method_with_filter(self, method, sample_data, filter_column, filt # Test with filter self.handler.native_query.reset_mock() filtered_df = df[df[filter_column].isin(filter_values)].reset_index(drop=True) - filtered_response = Response(RESPONSE_TYPE.TABLE, data_frame=filtered_df) + filtered_response = TableResponse(data=filtered_df) self.handler.native_query = MagicMock(return_value=filtered_response) response = method(table_names=filter_values) diff --git a/tests/unit/handlers/test_oracle.py b/tests/unit/handlers/test_oracle.py index cfd8dd7423f..fb18a57fcc6 100644 --- a/tests/unit/handlers/test_oracle.py +++ b/tests/unit/handlers/test_oracle.py @@ -18,9 +18,11 @@ import pandas as pd from pandas import DataFrame -from base_handler_test import BaseDatabaseHandlerTest +from base_handler_test import BaseDatabaseHandlerTest, MockCursorContextManager from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, + TableResponse, + OkResponse, + ErrorResponse, INF_SCHEMA_COLUMNS_NAMES_SET, RESPONSE_TYPE, ) @@ -165,9 +167,42 @@ def test_thick_mode_connection(self): handler.connect() mock_init.assert_called_once_with(lib_dir="/path/to/oracle/client/lib") - def test_native_query_with_results(self): + def test_native_query_with_results_streaming(self): """ - Tests the `native_query` method for a SELECT statement returning results. + Tests the `native_query` method for a SELECT statement returning results at server side execution. + """ + mock_conn = MagicMock() + mock_cursor = MockCursorContextManager() + + self.handler.connect = MagicMock(return_value=mock_conn) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + + # Server-side execution uses fetchmany, not fetchall + mock_cursor.fetchmany = MagicMock(side_effect=[[(1, "test1"), (2, "test2")], []]) + mock_cursor.description = [ + ("ID", None, None, None, None, None, None), + ("NAME", None, None, None, None, None, None), + ] + + query_str = "SELECT ID, NAME FROM test_table" + data = self.handler.native_query(query_str, stream=True) + + mock_conn.cursor.assert_called_once() + mock_cursor.execute.assert_called_once_with(query_str) + + # Verify the response + self.assertIsInstance(data, TableResponse) + self.assertEqual(data.type, RESPONSE_TYPE.TABLE) + self.assertIsNone(data._data) + data.fetchall() + self.assertIsInstance(data._data, DataFrame) + expected_columns = ["ID", "NAME"] + self.assertListEqual(list(data.data_frame.columns), expected_columns) + self.assertEqual(len(data.data_frame), 2) + + def test_native_query_with_no_streaming(self): + """ + Tests the `native_query` method for a SELECT statement returning results at client side execution. """ mock_conn = MagicMock() mock_cursor = MagicMock() @@ -177,22 +212,21 @@ def test_native_query_with_results(self): self.handler.connect = MagicMock(return_value=mock_conn) mock_conn.cursor = MagicMock(return_value=mock_cursor) - mock_cursor.fetchall.return_value = [(1, "test1"), (2, "test2")] + mock_cursor.fetchall = MagicMock(return_value=[(1, "test1"), (2, "test2")]) mock_cursor.description = [ ("ID", None, None, None, None, None, None), ("NAME", None, None, None, None, None, None), ] query_str = "SELECT ID, NAME FROM test_table" - data = self.handler.native_query(query_str) + data = self.handler.native_query(query_str, stream=False) mock_conn.cursor.assert_called_once() mock_cursor.execute.assert_called_once_with(query_str) mock_cursor.fetchall.assert_called_once() mock_conn.commit.assert_called_once() - self.assertIsInstance(data, Response) - self.assertFalse(data.error_code) + self.assertIsInstance(data, TableResponse) self.assertEqual(data.type, RESPONSE_TYPE.TABLE) self.assertIsInstance(data.data_frame, DataFrame) expected_columns = ["ID", "NAME"] @@ -222,8 +256,7 @@ def test_native_query_no_results(self): mock_cursor.fetchall.assert_not_called() mock_conn.commit.assert_called_once() - self.assertIsInstance(data, Response) - self.assertFalse(data.error_code) + self.assertIsInstance(data, OkResponse) self.assertEqual(data.type, RESPONSE_TYPE.OK) self.assertEqual(data.affected_rows, 1) @@ -252,7 +285,7 @@ def test_native_query_error(self): mock_conn.rollback.assert_called_once() mock_conn.commit.assert_not_called() - self.assertIsInstance(data, Response) + self.assertIsInstance(data, ErrorResponse) self.assertEqual(data.type, RESPONSE_TYPE.ERROR) self.assertEqual(data.error_message, error_msg) @@ -265,7 +298,7 @@ def test_query_method(self): orig_renderer = self.handler.renderer self.handler.native_query = MagicMock() - expected_response = Response(RESPONSE_TYPE.TABLE) + expected_response = TableResponse() self.handler.native_query.return_value = expected_response mock_ast = MagicMock() @@ -299,7 +332,7 @@ def test_get_tables(self): ], columns=["TABLE_SCHEMA", "TABLE_NAME", "TABLE_TYPE"], ) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + expected_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=expected_response) @@ -364,7 +397,7 @@ def test_get_tables_multiple_schemas(self): ], columns=["TABLE_SCHEMA", "TABLE_NAME", "TABLE_TYPE"], ) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + expected_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=expected_response) @@ -448,7 +481,7 @@ def test_get_columns(self): ] expected_df = DataFrame(expected_df_data, columns=query_columns) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + expected_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=expected_response) table_name = "test_table" @@ -573,7 +606,7 @@ def test_types_casting(self): ("N_BINARY_DOUBLE", oracledb.DB_TYPE_NUMBER, 127, None, None, None, True), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [ MYSQL_DATA_TYPE.FLOAT, MYSQL_DATA_TYPE.DECIMAL, @@ -590,7 +623,7 @@ def test_types_casting(self): MYSQL_DATA_TYPE.FLOAT, MYSQL_DATA_TYPE.FLOAT, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + self.assertEqual([col.type for col in response.columns], excepted_mysql_types) for i, input_value in enumerate(input_row): result_value = response.data_frame[response.data_frame.columns[i]][0] self.assertEqual(result_value, input_value) @@ -612,9 +645,9 @@ def test_types_casting(self): ("T_BOOLEAN", oracledb.DB_TYPE_BOOLEAN, None, None, None, None, True), ("T_BOOL", oracledb.DB_TYPE_BOOLEAN, None, None, None, None, True), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [MYSQL_DATA_TYPE.BOOLEAN, MYSQL_DATA_TYPE.BOOLEAN] - self.assertEqual(response.mysql_types, excepted_mysql_types) + self.assertEqual([col.type for col in response.columns], excepted_mysql_types) for i, input_value in enumerate(input_row): result_value = response.data_frame[response.data_frame.columns[i]][0] self.assertEqual(result_value, input_value) @@ -680,7 +713,7 @@ def test_types_casting(self): ("T_RAW", oracledb.DB_TYPE_RAW, 100, 100, None, None, True), ("T_BLOB", oracledb.DB_TYPE_LONG_RAW, None, None, None, None, True), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [ MYSQL_DATA_TYPE.TEXT, MYSQL_DATA_TYPE.TEXT, @@ -692,7 +725,7 @@ def test_types_casting(self): MYSQL_DATA_TYPE.BINARY, MYSQL_DATA_TYPE.BINARY, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + self.assertEqual([col.type for col in response.columns], excepted_mysql_types) for i, input_value in enumerate(input_row): result_value = response.data_frame[response.data_frame.columns[i]][0] self.assertEqual(result_value, input_value) @@ -739,13 +772,13 @@ def test_types_casting(self): ("D_TIMESTAMP", oracledb.DB_TYPE_TIMESTAMP, 23, None, 0, 6, True), ("D_TIMESTAMP_P", oracledb.DB_TYPE_TIMESTAMP, 23, None, 0, 9, True), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [ MYSQL_DATA_TYPE.DATE, MYSQL_DATA_TYPE.TIMESTAMP, MYSQL_DATA_TYPE.TIMESTAMP, ] - self.assertEqual(response.mysql_types, excepted_mysql_types) + self.assertEqual([col.type for col in response.columns], excepted_mysql_types) for i, input_value in enumerate(input_row): result_value = response.data_frame[response.data_frame.columns[i]][0] self.assertEqual(result_value, input_value) @@ -767,7 +800,7 @@ def test_types_casting(self): ), # set 17 just to force cast to Int64 ("T_BOOLEAN", oracledb.DB_TYPE_BOOLEAN, None, None, None, None, True), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str, stream=False) self.assertEqual(response.data_frame.dtypes[0], "Int64") self.assertEqual(response.data_frame.dtypes[1], "boolean") self.assertEqual(response.data_frame.iloc[0, 0], bigint_val) @@ -800,12 +833,13 @@ def test_types_casting(self): ("T_EMBEDDING", oracledb.DB_TYPE_VECTOR, None, None, None, None, True), ("T_JSON", oracledb.DB_TYPE_JSON, None, None, None, None, True), ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str, stream=False) excepted_mysql_types = [MYSQL_DATA_TYPE.VECTOR, MYSQL_DATA_TYPE.JSON] + self.assertEqual([col.type for col in response.columns], excepted_mysql_types) for i, input_value in enumerate(input_row): result_value = response.data_frame[response.data_frame.columns[i]][0] self.assertEqual(result_value, input_value) - # endreion + # endregion def test_insert(self): """ @@ -813,9 +847,7 @@ def test_insert(self): using insertmany for batch inserts. """ mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) - mock_cursor.__exit__ = MagicMock(return_value=None) + mock_cursor = MockCursorContextManager() self.handler.connect = MagicMock(return_value=mock_conn) mock_conn.cursor = MagicMock(return_value=mock_cursor) @@ -837,9 +869,7 @@ def test_insert_error(self): Tests the insert method to ensure it correctly handles errors """ mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) - mock_cursor.__exit__ = MagicMock(return_value=None) + mock_cursor = MockCursorContextManager() self.handler.connect = MagicMock(return_value=mock_conn) mock_conn.cursor = MagicMock(return_value=mock_cursor) @@ -869,7 +899,7 @@ def test_meta_get_tables(self, table_names=None): "row_count", ], ) - mock_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + mock_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=mock_response) response = self.handler.meta_get_tables(table_names=table_names) @@ -900,7 +930,7 @@ def test_meta_get_columns(self, table_names=None): ], ) - mock_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + mock_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=mock_response) table_name = "TABLE1" @@ -934,7 +964,7 @@ def test_meta_get_column_statistics(self, table_names=None): ], ) - mock_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + mock_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=mock_response) table_names = ["STATS_TABLE"] response = self.handler.meta_get_column_statistics(table_names=table_names) @@ -975,7 +1005,7 @@ def test_meta_get_primary_keys(self): ], ) - mock_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + mock_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=mock_response) table_names = ["USERS", "ORDERS"] @@ -1024,7 +1054,7 @@ def test_meta_get_foreign_keys(self, table_names=None): ], ) - mock_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + mock_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=mock_response) table_names = ["ORDERS", "ORDER_ITEMS"] diff --git a/tests/unit/handlers/test_postgres.py b/tests/unit/handlers/test_postgres.py index 8ad5be6d414..dc6d8c64569 100644 --- a/tests/unit/handlers/test_postgres.py +++ b/tests/unit/handlers/test_postgres.py @@ -17,7 +17,12 @@ from base_handler_test import BaseDatabaseHandlerTest, MockCursorContextManager from mindsdb.integrations.handlers.postgres_handler.postgres_handler import PostgresHandler, _map_type -from mindsdb.integrations.libs.response import HandlerResponse as Response, RESPONSE_TYPE +from mindsdb.integrations.libs.response import ( + RESPONSE_TYPE, + TableResponse, + OkResponse, + ErrorResponse, +) from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE @@ -96,35 +101,64 @@ def create_handler(self): def create_patcher(self): return patch("psycopg.connect") - def test_native_query_command_ok(self): + def test_native_query_command_ok_stream(self): """ Tests the `native_query` method to ensure it executes a SQL query and handles the case where the query doesn't return a result set (ExecStatus.COMMAND_OK) """ mock_conn = MagicMock() - # Use MockCursorContextManager for simplified mocking - mock_cursor = MockCursorContextManager() + mock_cursor_server = MockCursorContextManager() + mock_cursor_client = MockCursorContextManager() self.handler.connect = MagicMock(return_value=mock_conn) - mock_conn.cursor = MagicMock(return_value=mock_cursor) + mock_conn.cursor = MagicMock(side_effect=[mock_cursor_server, mock_cursor_client]) - mock_cursor.execute.return_value = None + syntax_error = psycopg.errors.SyntaxError('syntax error at or near "insert"') + mock_cursor_server.execute.side_effect = syntax_error + mock_cursor_client.execute.return_value = None # Setup pgresult mock_pgresult = MagicMock() mock_pgresult.status = ExecStatus.COMMAND_OK - mock_cursor.pgresult = mock_pgresult - mock_cursor.rowcount = 1 + mock_cursor_client.pgresult = mock_pgresult + mock_cursor_client.rowcount = 1 query_str = "INSERT INTO table VALUES (1, 2, 3)" - data = self.handler.native_query(query_str) - mock_cursor.execute.assert_called_once_with(query_str) - assert isinstance(data, Response) - self.assertFalse(data.error_code) - self.assertEqual(data.type, RESPONSE_TYPE.OK) + data = self.handler.native_query(query_str, stream=True) + mock_cursor_server.execute.assert_called_once_with(query_str) + mock_cursor_client.execute.assert_called_once_with(query_str) + assert isinstance(data, OkResponse) self.assertEqual(data.affected_rows, 1) - def test_native_query_with_results(self): + def test_native_query_command_ok_no_stream(self): + """ + Tests the `native_query` at client side execution + """ + mock_conn = MagicMock() + # mock_cursor_server = MockCursorContextManager() + mock_cursor_client = MockCursorContextManager() + + self.handler.connect = MagicMock(return_value=mock_conn) + mock_conn.cursor = MagicMock(side_effect=[mock_cursor_client]) + + # syntax_error = psycopg.errors.SyntaxError('syntax error at or near "insert"') + # mock_cursor_server.execute.side_effect = syntax_error + mock_cursor_client.execute.return_value = None + + # Setup pgresult + mock_pgresult = MagicMock() + mock_pgresult.status = ExecStatus.COMMAND_OK + mock_cursor_client.pgresult = mock_pgresult + mock_cursor_client.rowcount = 1 + + query_str = "INSERT INTO table VALUES (1, 2, 3)" + data = self.handler.native_query(query_str, stream=False) + # mock_cursor_server.execute.assert_called_once_with(query_str) + mock_cursor_client.execute.assert_called_once_with(query_str) + assert isinstance(data, OkResponse) + self.assertEqual(data.affected_rows, 1) + + def test_native_query_with_results_client_side(self): """ Tests the `native_query` method to ensure it executes a SQL query and handles the case where the query returns a result set @@ -135,7 +169,7 @@ def test_native_query_with_results(self): self.handler.connect = MagicMock(return_value=mock_conn) mock_conn.cursor = MagicMock(return_value=mock_cursor) - mock_cursor.fetchall = MagicMock(return_value=[[1, "name1"], [2, "name2"]]) + mock_cursor.fetchall = MagicMock(side_effect=[[[1, "name1"], [2, "name2"]], []]) # Create proper description objects with necessary type_code for _cast_dtypes mock_cursor.description = [ @@ -149,14 +183,51 @@ def test_native_query_with_results(self): mock_cursor.pgresult = mock_pgresult query_str = "SELECT * FROM table" - data = self.handler.native_query(query_str) + data = self.handler.native_query(query_str, stream=False) mock_cursor.execute.assert_called_once_with(query_str) - assert isinstance(data, Response) - self.assertFalse(data.error_code) + assert isinstance(data, TableResponse) + assert getattr(data, "error_code", None) is None self.assertEqual(data.type, RESPONSE_TYPE.TABLE) self.assertIsInstance(data.data_frame, DataFrame) self.assertEqual(list(data.data_frame.columns), ["id", "name"]) + def test_native_query_with_results_stream(self): + """ + Tests the `native_query` method to ensure it executes a SQL query and handles the case + where the query returns a result set at server side execution + """ + mock_conn = MagicMock() + mock_cursor = MockCursorContextManager() + + self.handler.connect = MagicMock(return_value=mock_conn) + mock_conn.cursor = MagicMock(return_value=mock_cursor) + + # Server-side execution uses fetchmany, not fetchall + mock_cursor.fetchmany = MagicMock(side_effect=[[[1, "name1"], [2, "name2"]], []]) + + mock_cursor.description = [ + ColumnDescription(name="id", type_code=regtype_to_oid["integer"]), # int4 type code + ColumnDescription(name="name", type_code=regtype_to_oid["text"]), # text type code + ] + + query_str = "SELECT * FROM table" + data = self.handler.native_query(query_str, stream=True) + mock_cursor.execute.assert_called_once_with(query_str) + + # Verify the response + assert isinstance(data, TableResponse) + assert getattr(data, "error_code", None) is None + self.assertEqual(data.type, RESPONSE_TYPE.TABLE) + self.assertIsNone(data._data) + data.fetchall() + self.assertIsInstance(data._data, DataFrame) + self.assertEqual(list(data.data_frame.columns), ["id", "name"]) + + # Verify DataFrame contains all expected rows + self.assertEqual(len(data.data_frame), 2) + self.assertEqual(data.data_frame["id"].tolist(), [1, 2]) + self.assertEqual(data.data_frame["name"].tolist(), ["name1", "name2"]) + def test_native_query_with_params(self): """ Tests the `native_query` method with parameters to ensure executemany is called correctly @@ -175,8 +246,7 @@ def test_native_query_with_params(self): params = [(1, "a"), (2, "b")] data = self.handler.native_query(query_str, params=params) mock_cursor.executemany.assert_called_once_with(query_str, params) - assert isinstance(data, Response) - self.assertFalse(data.error_code) + assert isinstance(data, OkResponse) def test_native_query_error(self): """ @@ -198,8 +268,7 @@ def test_native_query_error(self): mock_cursor.execute.assert_called_once_with(query_str) - assert isinstance(data, Response) - self.assertEqual(data.type, RESPONSE_TYPE.ERROR) + assert isinstance(data, ErrorResponse) # The handler implementation sets error_code to 0, check error_message instead self.assertEqual(data.error_code, 0) @@ -260,30 +329,7 @@ def test_query_method_uses_renderer_params(self): self.assertEqual(result, "ok") self.handler.renderer.get_exec_params.assert_called_once_with(query_node, with_failback=True) - self.handler.native_query.assert_called_once_with("SELECT 1", ["foo"]) - - def test_query_stream_yields_batches(self): - mock_conn = MagicMock() - mock_cursor = MockCursorContextManager() - mock_cursor.pgresult = MagicMock(status=ExecStatus.TUPLES_OK) - mock_cursor.fetchmany = MagicMock(side_effect=[[(1, "name")], []]) - mock_cursor.description = [ - ColumnDescription(name="id", type_code=regtype_to_oid["integer"]), - ColumnDescription(name="name", type_code=regtype_to_oid["text"]), - ] - - self.handler.connect = MagicMock(return_value=mock_conn) - mock_conn.cursor = MagicMock(return_value=mock_cursor) - self.handler.renderer.get_exec_params = MagicMock(return_value=("SELECT * FROM table", None)) - self.handler.disconnect = MagicMock() - - batches = list(self.handler.query_stream(MagicMock(), fetch_size=1)) - - self.assertEqual(len(batches), 1) - self.assertListEqual(list(batches[0].columns), ["id", "name"]) - mock_conn.commit.assert_called_once() - mock_conn.rollback.assert_called_once() - self.handler.disconnect.assert_called_once() + self.handler.native_query.assert_called_once_with("SELECT 1", ["foo"], stream=False) def test_insert_respects_existing_column_case(self): if getattr(self.handler, "name", None) != "postgres": @@ -299,9 +345,8 @@ def test_insert_respects_existing_column_case(self): mock_conn.cursor = MagicMock(return_value=mock_cursor) self.handler.disconnect = MagicMock() self.handler.get_columns = MagicMock( - return_value=Response( - RESPONSE_TYPE.TABLE, - data_frame=pd.DataFrame({"COLUMN_NAME": ["Id", "Amount"]}), + return_value=TableResponse( + data=pd.DataFrame({"COLUMN_NAME": ["Id", "Amount"]}), ) ) @@ -444,13 +489,13 @@ def test_insert(self): mock_pgresult.status = ExecStatus.TUPLES_OK mock_cursor.pgresult = mock_pgresult mock_cursor.rowcount = 1 - mock_cursor.fetchall = MagicMock( - return_value=[ - ["a", "int", 1, None, "YES", None, None, None, None, None, None, None], - ["b", "int", 2, None, "YES", None, None, None, None, None, None, None], - ["c", "int", 3, None, "YES", None, None, None, None, None, None, None], - ] - ) + + get_columns_result = [ + ["id", "int", 1, None, "YES", None, None, None, None, None, None, None], + ["name", "text", 2, None, "YES", None, None, None, None, None, None, None], + ] + mock_cursor.fetchmany = MagicMock(side_effect=[get_columns_result, []]) + information_schema_description = [ ColumnDescription(name="COLUMN_NAME", type_code=regtype_to_oid["text"]), ColumnDescription(name="DATA_TYPE", type_code=regtype_to_oid["text"]), @@ -474,19 +519,6 @@ def test_insert(self): copy_obj.__enter__ = MagicMock(return_value=copy_obj) copy_obj.__exit__ = MagicMock(return_value=None) - # region add result for 'get_columns' call - mock_pgresult = MagicMock() - mock_pgresult.status = ExecStatus.TUPLES_OK - mock_cursor.pgresult = mock_pgresult - mock_cursor.fetchall = MagicMock( - return_value=[ - ["id", "int", 1, None, "YES", None, None, None, None, None, None, None], - ["name", "text", 2, None, "YES", None, None, None, None, None, None, None], - ] - ) - mock_cursor.description = information_schema_description - # endregino - df = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]}) self.handler.insert("test_table", df) @@ -643,9 +675,11 @@ def test_types_casting(self): MYSQL_DATA_TYPE.VARCHAR, MYSQL_DATA_TYPE.VARCHAR, ] - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str, stream=False) + + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) - self.assertEqual(response.mysql_types, excepted_mysql_types) for i, input_value in enumerate(input_row): result_value = response.data_frame[description[i].name][0] self.assertEqual(type(result_value), type(input_value), f"type mismatch: {result_value} != {input_value}") @@ -657,8 +691,9 @@ def test_types_casting(self): mock_cursor.fetchall.return_value = input_rows mock_cursor.description = [ColumnDescription(name="t_boolean", type_code=16)] excepted_mysql_types = [MYSQL_DATA_TYPE.BOOL] - response: Response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + response: TableResponse = self.handler.native_query(query_str, stream=False) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) self.assertTrue(pd_types.is_bool_dtype(response.data_frame["t_boolean"][0])) self.assertTrue(bool(response.data_frame["t_boolean"][0]) is True) self.assertTrue(bool(response.data_frame["t_boolean"][1]) is False) @@ -774,8 +809,9 @@ def test_types_casting(self): MYSQL_DATA_TYPE.FLOAT, # n_float4 MYSQL_DATA_TYPE.DOUBLE, # n_float8 ] - response: Response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + response: TableResponse = self.handler.native_query(query_str, stream=False) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) for i, input_value in enumerate(input_row): result_value = response.data_frame[description[i].name][0] self.assertEqual(result_value, input_value, f"value mismatch: {result_value} != {input_value}") @@ -850,8 +886,9 @@ def test_types_casting(self): MYSQL_DATA_TYPE.TIME, # TIMETZ ] - response: Response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + response: TableResponse = self.handler.native_query(query_str, stream=False) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) for i, input_value in enumerate(input_row): result_value = response.data_frame[description[i].name][0] self.assertEqual(result_value, input_value, f"value mismatch: {result_value} != {input_value}") @@ -866,7 +903,7 @@ def test_types_casting(self): ColumnDescription(name="t_boolean", type_code=16), ] mock_cursor.description = description - response: Response = self.handler.native_query(query_str) + response: TableResponse = self.handler.native_query(query_str, stream=False) self.assertEqual(response.data_frame.dtypes[0], "Int64") self.assertEqual(response.data_frame.dtypes[1], "boolean") self.assertEqual(response.data_frame.iloc[0, 0], bigint_val) @@ -921,8 +958,9 @@ def test_types_casting(self): MYSQL_DATA_TYPE.VECTOR, ] - response: Response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + response: TableResponse = self.handler.native_query(query_str, stream=False) + for column, mysql_type in zip(response.columns, excepted_mysql_types): + self.assertEqual(column.type, mysql_type) for i, input_value in enumerate(input_row): result_value = response.data_frame[description[i].name][0] self.assertEqual(type(result_value), type(input_value), f"type mismatch: {result_value} != {input_value}") @@ -933,7 +971,7 @@ def test_types_casting(self): # endregion def test_get_tables_all_flag(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.get_tables(all=True) query = self.handler.native_query.call_args[0][0] self.assertNotIn("current_schema()", query.split("table_schema")[-1]) @@ -955,19 +993,19 @@ def test_get_columns_with_schema_name(self): "COLLATION_NAME": [None], } ) - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=df)) + self.handler.native_query = MagicMock(return_value=TableResponse(data=df)) self.handler.get_columns("customers", schema_name="analytics") query = self.handler.native_query.call_args[0][0] self.assertIn("table_schema = 'analytics'", query) def test_meta_get_tables_filters_by_list(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.meta_get_tables(table_names=["orders"]) query = self.handler.native_query.call_args[0][0] self.assertIn("IN ('orders')", query) def test_meta_get_columns_filters_by_list(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.meta_get_columns(table_names=["orders"]) query = self.handler.native_query.call_args[0][0] self.assertIn("IN ('orders')", query) @@ -984,7 +1022,7 @@ def test_meta_get_column_statistics_transforms_histogram(self): "histogram_bounds": ["{1,5,10}"], } ) - response = Response(RESPONSE_TYPE.TABLE, data_frame=df) + response = TableResponse(data=df) self.handler.native_query = MagicMock(return_value=response) result = self.handler.meta_get_column_statistics(table_names=["orders"]) @@ -995,13 +1033,13 @@ def test_meta_get_column_statistics_transforms_histogram(self): self.assertEqual(result.data_frame.loc[0, "MOST_COMMON_VALUES"], ["A", "B"]) def test_meta_get_primary_keys_with_filter(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.meta_get_primary_keys(table_names=["orders"]) query = self.handler.native_query.call_args[0][0] self.assertIn("AND tc.table_name IN ('orders')", query) def test_meta_get_foreign_keys_with_filter(self): - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame())) + self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) self.handler.meta_get_foreign_keys(table_names=["orders"]) query = self.handler.native_query.call_args[0][0] self.assertIn("AND tc.table_name IN ('orders')", query) diff --git a/tests/unit/handlers/test_redshift.py b/tests/unit/handlers/test_redshift.py index 8ee9a4f7e27..1d40b93fb4d 100644 --- a/tests/unit/handlers/test_redshift.py +++ b/tests/unit/handlers/test_redshift.py @@ -5,18 +5,21 @@ import pandas as pd import psycopg -from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, - RESPONSE_TYPE -) +from mindsdb.integrations.libs.response import OkResponse, ErrorResponse, RESPONSE_TYPE from mindsdb.integrations.handlers.redshift_handler.redshift_handler import RedshiftHandler from test_postgres import TestPostgresHandler class TestRedshiftHandler(TestPostgresHandler): - def create_handler(self): - return RedshiftHandler('redshift', connection_data=self.dummy_connection_data) + return RedshiftHandler("redshift", connection_data=self.dummy_connection_data) + + def test_native_query(self): + """ + This test is overridden to avoid issues with the generic MockCursorContextManager not being compatible with Postgres/Redshift cursor behavior. + More specific tests (test_native_query_with_results, test_native_query_command_ok, test_native_query_error) cover this functionality. + """ + pass def test_insert(self): """ @@ -32,20 +35,17 @@ def test_insert(self): mock_cursor.executemany.return_value = None - df = pd.DataFrame({ - 'column1': [1, 2, 3, np.nan], - 'column2': ['a', 'b', 'c', None] - }) + df = pd.DataFrame({"column1": [1, 2, 3, np.nan], "column2": ["a", "b", "c", None]}) - table_name = 'mock_table' + table_name = "mock_table" response = self.handler.insert(table_name, df) - columns = ', '.join([f'"{col}"' if ' ' in col else col for col in df.columns]) - values = ', '.join(['%s' for _ in range(len(df.columns))]) - expected_query = f'INSERT INTO {table_name} ({columns}) VALUES ({values})' + columns = ", ".join([f'"{col}"' if " " in col else col for col in df.columns]) + values = ", ".join(["%s" for _ in range(len(df.columns))]) + expected_query = f"INSERT INTO {table_name} ({columns}) VALUES ({values})" mock_cursor.executemany.assert_called_once_with(expected_query, df.replace({np.nan: None}).values.tolist()) - assert isinstance(response, Response) + assert isinstance(response, OkResponse) self.assertEqual(response.type, RESPONSE_TYPE.OK) mock_conn.commit.assert_called_once() @@ -65,17 +65,14 @@ def test_insert_error(self): error = psycopg.Error(error_msg) mock_cursor.executemany.side_effect = error - df = pd.DataFrame({ - 'column1': [1, 2, 3, np.nan], - 'column2': ['a', 'b', 'c', None] - }) + df = pd.DataFrame({"column1": [1, 2, 3, np.nan], "column2": ["a", "b", "c", None]}) - response = self.handler.insert('nonexistent_table', df) + response = self.handler.insert("nonexistent_table", df) mock_cursor.executemany.assert_called_once() mock_conn.rollback.assert_called_once() - assert isinstance(response, Response) + assert isinstance(response, ErrorResponse) self.assertEqual(response.type, RESPONSE_TYPE.ERROR) self.assertEqual(response.error_message, error_msg) @@ -91,21 +88,21 @@ def test_insert_with_empty_dataframe(self): self.handler.connect = MagicMock(return_value=mock_conn) mock_conn.cursor = MagicMock(return_value=mock_cursor) - df = pd.DataFrame(columns=['column1', 'column2']) + df = pd.DataFrame(columns=["column1", "column2"]) - table_name = 'mock_table' + table_name = "mock_table" response = self.handler.insert(table_name, df) - columns = ', '.join([f'"{col}"' if ' ' in col else col for col in df.columns]) - values = ', '.join(['%s' for _ in range(len(df.columns))]) - expected_query = f'INSERT INTO {table_name} ({columns}) VALUES ({values})' + columns = ", ".join([f'"{col}"' if " " in col else col for col in df.columns]) + values = ", ".join(["%s" for _ in range(len(df.columns))]) + expected_query = f"INSERT INTO {table_name} ({columns}) VALUES ({values})" mock_cursor.executemany.assert_called_once() call_args, call_kwargs = mock_cursor.executemany.call_args self.assertEqual(call_args[0], expected_query) self.assertEqual(len(call_args[1]), 0) - assert isinstance(response, Response) + assert isinstance(response, OkResponse) self.assertEqual(response.type, RESPONSE_TYPE.OK) mock_conn.commit.assert_called_once() @@ -123,25 +120,27 @@ def test_insert_with_special_column_names(self): self.handler.connect = MagicMock(return_value=mock_conn) mock_conn.cursor = MagicMock(return_value=mock_cursor) - df = pd.DataFrame({ - 'normal_column': [1, 2], - 'column with spaces': ['a', 'b'], - 'column-with-hyphens': [True, False], - 'mixed@column#123': [3.14, 2.71] - }) + df = pd.DataFrame( + { + "normal_column": [1, 2], + "column with spaces": ["a", "b"], + "column-with-hyphens": [True, False], + "mixed@column#123": [3.14, 2.71], + } + ) - table_name = 'mock_table' + table_name = "mock_table" response = self.handler.insert(table_name, df) call_args = mock_cursor.executemany.call_args[0][0] for col in df.columns: - if ' ' in col: + if " " in col: self.assertIn(f'"{col}"', call_args) else: self.assertTrue(col in call_args or f'"{col}"' in call_args) - assert isinstance(response, Response) + assert isinstance(response, OkResponse) self.assertEqual(response.type, RESPONSE_TYPE.OK) def test_insert_disconnect_when_needed(self): @@ -159,15 +158,15 @@ def test_insert_disconnect_when_needed(self): self.handler.disconnect = MagicMock() mock_conn.cursor = MagicMock(return_value=mock_cursor) - df = pd.DataFrame({'column1': [1, 2, 3]}) - self.handler.insert('mock_table', df) + df = pd.DataFrame({"column1": [1, 2, 3]}) + self.handler.insert("mock_table", df) self.handler.disconnect.assert_called_once() self.handler.connect.reset_mock() self.handler.disconnect.reset_mock() self.handler.is_connected = True - self.handler.insert('mock_table', df) + self.handler.insert("mock_table", df) self.handler.disconnect.assert_not_called() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/handlers/test_s3.py b/tests/unit/handlers/test_s3.py index 16f3e7e64b2..6911cef1cf7 100644 --- a/tests/unit/handlers/test_s3.py +++ b/tests/unit/handlers/test_s3.py @@ -11,32 +11,33 @@ from base_handler_test import BaseHandlerTestSetup from mindsdb.integrations.handlers.s3_handler.s3_handler import S3Handler from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, + OkResponse, + TableResponse, + DataHandlerResponse as Response, HandlerStatusResponse as StatusResponse, - RESPONSE_TYPE + RESPONSE_TYPE, ) class TestS3Handler(BaseHandlerTestSetup, unittest.TestCase): - @property def object_name(self): - return '`my-bucket/my-file.csv`' + return "`my-bucket/my-file.csv`" @property def dummy_connection_data(self): return OrderedDict( - aws_access_key_id='AQAXEQK89OX07YS34OP', - aws_secret_access_key='wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', - bucket='mindsdb-bucket', - region_name='us-east-2', + aws_access_key_id="AQAXEQK89OX07YS34OP", + aws_secret_access_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + bucket="mindsdb-bucket", + region_name="us-east-2", ) def create_handler(self): - return S3Handler('s3', connection_data=self.dummy_connection_data) + return S3Handler("s3", connection_data=self.dummy_connection_data) def create_patcher(self): - return patch('boto3.client') + return patch("boto3.client") def test_connect(self): """ @@ -51,7 +52,7 @@ def test_connect(self): self.assertTrue(self.handler.is_connected) self.mock_connect.assert_called_once() - @patch('boto3.client') + @patch("boto3.client") def test_check_connection_success(self, mock_boto3_client): """ Test that the `check_connection` method returns a StatusResponse object and accurately reflects the connection status on a successful connection. @@ -66,7 +67,7 @@ def test_check_connection_success(self, mock_boto3_client): assert isinstance(response, StatusResponse) self.assertFalse(response.error_message) - @patch('boto3.client') + @patch("boto3.client") def test_check_connection_failure_invalid_bucket_or_no_access(self, mock_boto3_client): """ Test that the `check_connection` method returns a StatusResponse object and accurately reflects the connection status on failed connection due to invalid bucket or lack of access permissions. @@ -76,12 +77,12 @@ def test_check_connection_failure_invalid_bucket_or_no_access(self, mock_boto3_c mock_boto3_client.return_value = mock_boto3_client_instance mock_boto3_client_instance.head_bucket.side_effect = ClientError( error_response={ - 'Error': { - 'Code': '404', - 'Message': 'Not Found', + "Error": { + "Code": "404", + "Message": "Not Found", } }, - operation_name='HeadBucket' + operation_name="HeadBucket", ) response = self.handler.check_connection() @@ -90,7 +91,7 @@ def test_check_connection_failure_invalid_bucket_or_no_access(self, mock_boto3_c assert isinstance(response, StatusResponse) self.assertTrue(response.error_message) - @patch('boto3.client') + @patch("boto3.client") def test_query_select(self, mock_boto3_client): """ Tests the `query` method to ensure it executes a SELECT SQL query using a mock cursor and returns a Response object. @@ -104,18 +105,11 @@ def test_query_select(self, mock_boto3_client): duckdb_connect = MagicMock() self.handler._connect_duckdb = duckdb_connect duckdb_execute = duckdb_connect().__enter__().execute - duckdb_execute().fetchdf.return_value = pd.DataFrame([], columns=['col_2']) + duckdb_execute().fetchdf.return_value = pd.DataFrame([], columns=["col_2"]) # Craft the SELECT query and execute it. - object_name = 'my-bucket/my-file.csv' - select = ast.Select( - targets=[ - Star() - ], - from_table=Identifier( - parts=[object_name] - ) - ) + object_name = "my-bucket/my-file.csv" + select = ast.Select(targets=[Star()], from_table=Identifier(parts=[object_name])) duckdb_execute.reset_mock() response = self.handler.query(select) @@ -124,10 +118,9 @@ def test_query_select(self, mock_boto3_client): f"SELECT * FROM 's3://{self.dummy_connection_data['bucket']}/{object_name.replace('`', '')}'" ) - assert isinstance(response, Response) - self.assertFalse(response.error_code) + assert isinstance(response, TableResponse) - @patch('boto3.client') + @patch("boto3.client") def test_query_insert(self, mock_boto3_client): """ Tests the `query` method to ensure it executes a INSERT SQL query using a mock cursor and returns a Response object. @@ -145,29 +138,25 @@ def test_query_insert(self, mock_boto3_client): duckdb_execute().fetchdf.return_value = None # Craft the INSERT query and execute it. - columns = ['col_1', 'col_2'] - values = [('val_1', 'val_2')] - insert = ast.Insert( - table=Identifier( - parts=[self.object_name] - ), - columns=columns, - values=values - ) + columns = ["col_1", "col_2"] + values = [("val_1", "val_2")] + insert = ast.Insert(table=Identifier(parts=[self.object_name]), columns=columns, values=values) duckdb_execute.reset_mock() response = self.handler.query(insert) sqls = [i[0][0] for i in duckdb_execute.call_args_list] - assert sqls[0] == f"CREATE TABLE tmp_table AS SELECT * FROM 's3://{self.dummy_connection_data['bucket']}/{self.object_name}'" + assert ( + sqls[0] + == f"CREATE TABLE tmp_table AS SELECT * FROM 's3://{self.dummy_connection_data['bucket']}/{self.object_name}'" + ) assert sqls[1] == "INSERT INTO tmp_table BY NAME SELECT * FROM df" assert sqls[2] == f"COPY tmp_table TO 's3://{self.dummy_connection_data['bucket']}/{self.object_name}'" - assert isinstance(response, Response) - self.assertFalse(response.error_code) + assert isinstance(response, OkResponse) - @patch('boto3.client') + @patch("boto3.client") def test_get_tables(self, mock_boto3_client): """ Test that the `get_tables` method correctly calls the `list_objects_v2` method and returns a Response object with the supported objects (files). @@ -176,12 +165,12 @@ def test_get_tables(self, mock_boto3_client): mock_boto3_client_instance = MagicMock() mock_boto3_client.return_value = mock_boto3_client_instance mock_boto3_client_instance.list_objects_v2.return_value = { - 'Contents': [ - {'Key': 'file1.csv'}, - {'Key': 'file2.tsv'}, - {'Key': 'file3.json'}, - {'Key': 'file4.parquet'}, - {'Key': 'file5.xlsx'}, + "Contents": [ + {"Key": "file1.csv"}, + {"Key": "file2.tsv"}, + {"Key": "file3.json"}, + {"Key": "file4.parquet"}, + {"Key": "file5.xlsx"}, ] } @@ -192,37 +181,32 @@ def test_get_tables(self, mock_boto3_client): df = response.data_frame self.assertEqual(len(df), 5) # +1 table is 'files' - self.assertNotIn('file5.xlsx', df['table_name'].values) + self.assertNotIn("file5.xlsx", df["table_name"].values) - @patch('mindsdb.integrations.handlers.s3_handler.s3_handler.S3Handler.query') + @patch("mindsdb.integrations.handlers.s3_handler.s3_handler.S3Handler.query") def test_get_columns(self, mock_query): """ Test that the `get_columns` method correctly constructs the SQL query and calls `native_query` with the correct query. """ - mock_query.return_value = Response( - RESPONSE_TYPE.TABLE, - data_frame=pd.DataFrame( + mock_query.return_value = TableResponse( + data=pd.DataFrame( data={ - 'col_1': ['row_1', 'row_2', 'row_3'], - 'col_2': [1, 2, 3], + "col_1": ["row_1", "row_2", "row_3"], + "col_2": [1, 2, 3], }, ) ) - table_name = 'mock_table' + table_name = "mock_table" response = self.handler.get_columns(table_name) - expected_query = Select( - targets=[Star()], - from_table=Identifier(parts=[table_name]), - limit=Constant(1) - ) + expected_query = Select(targets=[Star()], from_table=Identifier(parts=[table_name]), limit=Constant(1)) self.handler.query.assert_called_once_with(expected_query) df = response.data_frame - self.assertEqual(df.columns.tolist(), ['column_name', 'data_type']) - self.assertEqual(df['data_type'].values.tolist(), ['string', 'int64']) + self.assertEqual(df.columns.tolist(), ["column_name", "data_type"]) + self.assertEqual(df["data_type"].values.tolist(), ["string", "int64"]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/handlers/test_salesforce.py b/tests/unit/handlers/test_salesforce.py index 54253f3eef8..62be61a18f2 100644 --- a/tests/unit/handlers/test_salesforce.py +++ b/tests/unit/handlers/test_salesforce.py @@ -16,7 +16,7 @@ from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator from mindsdb.integrations.libs.response import ( - HandlerResponse as Response, + TableResponse, HandlerStatusResponse as StatusResponse, RESPONSE_TYPE, ) @@ -157,7 +157,7 @@ def test_check_connection_failure(self): def test_get_tables(self): """ - Test that the `get_tables` method returns a list of tables mapped from the Salesforce API. + Test that the `get_tables` method returns a TableResponse with a list of tables mapped from the Salesforce API. """ mock_tables = ["Account", "Contact"] self.mock_connect.return_value = MagicMock( @@ -168,7 +168,7 @@ def test_get_tables(self): self.handler.connect() response = self.handler.get_tables() - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -177,7 +177,7 @@ def test_get_tables(self): def test_get_columns(self): """ - Test that the `get_columns` method returns a list of columns for a given table. + Test that the `get_columns` method returns a TableResponse with a list of columns for a given table. """ mock_columns = ["Id", "Name", "Email"] mock_table = "Contact" @@ -203,7 +203,7 @@ def test_get_columns(self): self.handler.connect() response = self.handler.get_columns(mock_table) - assert isinstance(response, Response) + assert isinstance(response, TableResponse) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) df = response.data_frame @@ -435,7 +435,7 @@ def test_meta_get_tables_filters_requested_tables(self): with patch( "mindsdb.integrations.handlers.salesforce_handler.salesforce_handler.MetaAPIHandler.meta_get_tables", - return_value=Response(RESPONSE_TYPE.TABLE, None), + return_value=TableResponse(), ) as mock_meta: response = self.handler.meta_get_tables(table_names=["contact"]) diff --git a/tests/unit/handlers/test_slack.py b/tests/unit/handlers/test_slack.py index 59c64de18b0..62a9ada8bc8 100644 --- a/tests/unit/handlers/test_slack.py +++ b/tests/unit/handlers/test_slack.py @@ -12,7 +12,7 @@ import pandas as pd from base_handler_test import BaseAPIChatHandlerTest, BaseAPIResourceTestSetup -from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse, HandlerResponse as Response +from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse, TableResponse from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator try: @@ -431,7 +431,7 @@ def test_native_query(self): response = self.handler.native_query(query) self.mock_connect.return_value.conversations_info.assert_called_once_with(channel="C1234567890") - assert isinstance(response, Response) + assert isinstance(response, TableResponse) expected_df = pd.DataFrame([MOCK_RESPONSE_CONV_INFO_1["channel"]]) pd.testing.assert_frame_equal(response.data_frame, expected_df) @@ -451,7 +451,7 @@ def test_native_query_with_pagination(self): self.mock_connect.return_value.conversations_list.assert_any_call() self.mock_connect.return_value.conversations_list.assert_any_call(cursor="dGVhbTpDMDYxRkE1UEI=") - assert isinstance(response, Response) + assert isinstance(response, TableResponse) expected_df = pd.DataFrame(MOCK_RESPONSE_CONV_LIST_1["channels"] + MOCK_RESPONSE_CONV_LIST_2["channels"]) pd.testing.assert_frame_equal(response.data_frame, expected_df) diff --git a/tests/unit/handlers/test_snowflake.py b/tests/unit/handlers/test_snowflake.py index e43aec5eac5..7118f3f1602 100644 --- a/tests/unit/handlers/test_snowflake.py +++ b/tests/unit/handlers/test_snowflake.py @@ -16,11 +16,16 @@ import numpy as np import pandas as pd from pandas import DataFrame -from types import SimpleNamespace from base_handler_test import BaseDatabaseHandlerTest -from mindsdb.integrations.libs.response import HandlerResponse as Response, INF_SCHEMA_COLUMNS_NAMES_SET, RESPONSE_TYPE +from mindsdb.integrations.libs.response import ( + OkResponse, + TableResponse, + ErrorResponse, + INF_SCHEMA_COLUMNS_NAMES_SET, + RESPONSE_TYPE, +) from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE @@ -246,8 +251,7 @@ def test_native_query_with_results(self): mock_cursor.fetch_pandas_batches.assert_called_once() mock_cursor.fetchall.assert_not_called() - self.assertIsInstance(data, Response) - self.assertFalse(data.error_code) + self.assertIsInstance(data, TableResponse) self.assertEqual(data.type, RESPONSE_TYPE.TABLE) self.assertIsInstance(data.data_frame, DataFrame) self.assertListEqual(list(data.data_frame.columns), expected_columns) @@ -285,8 +289,7 @@ def test_native_query_no_results(self): mock_cursor.execute.assert_called_once_with(query_str) mock_cursor.fetch_pandas_batches.assert_called_once() - self.assertIsInstance(data, Response) - self.assertFalse(data.error_code) + self.assertIsInstance(data, OkResponse) self.assertEqual(data.type, RESPONSE_TYPE.OK) self.assertEqual(data.affected_rows, 1) @@ -350,7 +353,7 @@ def test_native_query_error(self): mock_conn.cursor.assert_called_once() mock_cursor.execute.assert_called_once_with(query_str) - self.assertIsInstance(data, Response) + self.assertIsInstance(data, ErrorResponse) self.assertEqual(data.type, RESPONSE_TYPE.ERROR) self.assertIn(error_msg, data.error_message) @@ -376,33 +379,11 @@ def test_native_query_releases_memory_pool_when_jemalloc(self): mock_pool.backend_name = "jemalloc" mock_pool.release_unused = MagicMock() - response = self.handler.native_query("SELECT 1") + response = self.handler.native_query("SELECT 1", stream=False) self.assertEqual(response.type, RESPONSE_TYPE.TABLE) mock_pool.release_unused.assert_called_once() - def test_native_query_memory_estimation_error(self): - mock_conn = MagicMock() - mock_cursor = MagicMock() - mock_cursor.__enter__.return_value = mock_cursor - mock_cursor.__exit__.return_value = None - large_df = DataFrame({"ID": range(1500)}) - mock_cursor.fetch_pandas_batches.return_value = iter([large_df]) - mock_cursor.description = [ColumnDescription(name="ID", type_code=0, scale=0)] - mock_cursor.rowcount = 10000 - - self.handler.connect = MagicMock(return_value=mock_conn) - mock_conn.cursor.return_value = mock_cursor - - with patch( - "mindsdb.integrations.handlers.snowflake_handler.snowflake_handler.psutil.virtual_memory", - return_value=SimpleNamespace(available=512), - ): - response = self.handler.native_query("SELECT * FROM big_table") - - self.assertEqual(response.type, RESPONSE_TYPE.ERROR) - self.assertIn("query result is too large", response.error_message) - def test_key_pair_authentication_success(self): """ Tests successful connection using key pair authentication @@ -645,7 +626,7 @@ def test_query_method(self): renderer_mock.get_string.return_value = "SELECT * FROM test_table_rendered" self.handler.native_query = MagicMock() - expected_response = Response(RESPONSE_TYPE.TABLE) + expected_response = TableResponse(data=DataFrame()) self.handler.native_query.return_value = expected_response try: @@ -673,11 +654,8 @@ def test_get_tables(self): """ Tests that get_tables calls native_query with the correct SQL for Snowflake """ - expected_response = Response( - RESPONSE_TYPE.TABLE, - data_frame=DataFrame( - [("table1", "SCHEMA1", "BASE TABLE")], columns=["TABLE_NAME", "TABLE_SCHEMA", "TABLE_TYPE"] - ), + expected_response = TableResponse( + data=DataFrame([("table1", "SCHEMA1", "BASE TABLE")], columns=["TABLE_NAME", "TABLE_SCHEMA", "TABLE_TYPE"]) ) self.handler.native_query = MagicMock(return_value=expected_response) @@ -751,7 +729,7 @@ def test_get_columns(self): ] expected_df = DataFrame(expected_df_data, columns=query_columns) - expected_response = Response(RESPONSE_TYPE.TABLE, data_frame=expected_df) + expected_response = TableResponse(data=expected_df) self.handler.native_query = MagicMock(return_value=expected_response) table_name = "test_table" @@ -794,7 +772,7 @@ def test_meta_get_tables_casts_rowcount(self): } ] ) - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=df)) + self.handler.native_query = MagicMock(return_value=TableResponse(data=df)) result = self.handler.meta_get_tables(table_names=["orders"]) @@ -815,7 +793,7 @@ def test_meta_get_columns_filters(self): } ] ) - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=df)) + self.handler.native_query = MagicMock(return_value=TableResponse(data=df)) result = self.handler.meta_get_columns(table_names=["orders"]) @@ -849,8 +827,8 @@ def test_meta_get_column_statistics_success(self): ) self.handler.native_query = MagicMock( side_effect=[ - Response(RESPONSE_TYPE.TABLE, data_frame=columns_df), - Response(RESPONSE_TYPE.TABLE, data_frame=stats_df), + TableResponse(data=columns_df), + TableResponse(data=stats_df), ] ) @@ -864,9 +842,7 @@ def test_meta_get_column_statistics_success(self): self.assertEqual(id_stats["maximum_value"], 10) def test_meta_get_column_statistics_handles_error_response(self): - self.handler.native_query = MagicMock( - return_value=Response(RESPONSE_TYPE.ERROR, error_message="boom", data_frame=None) - ) + self.handler.native_query = MagicMock(return_value=ErrorResponse(error_message="boom")) result = self.handler.meta_get_column_statistics(table_names=["orders"]) self.assertEqual(result.type, RESPONSE_TYPE.ERROR) @@ -877,7 +853,7 @@ def test_meta_get_primary_keys_filters(self): {"table_name": "CUSTOMERS", "column_name": "ID", "key_sequence": 1, "constraint_name": "PK_CUSTOMERS"}, ] ) - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=df)) + self.handler.native_query = MagicMock(return_value=TableResponse(data=df)) result = self.handler.meta_get_primary_keys(table_names=["ORDERS"]) @@ -909,7 +885,7 @@ def test_meta_get_foreign_keys_filters(self): }, ] ) - self.handler.native_query = MagicMock(return_value=Response(RESPONSE_TYPE.TABLE, data_frame=df)) + self.handler.native_query = MagicMock(return_value=TableResponse(data=df)) result = self.handler.meta_get_foreign_keys(table_names=["ORDERS", "CUSTOMERS"]) @@ -1195,7 +1171,8 @@ def test_types_casting(self): ] response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + actual_mysql_types = [col.type for col in response.columns] + self.assertEqual(actual_mysql_types, excepted_mysql_types) for column_name in input_data.columns: result_value = response.data_frame[column_name][0] self.assertEqual(result_value, input_data[column_name][0]) @@ -1346,7 +1323,8 @@ def test_types_casting(self): ] response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + actual_mysql_types = [col.type for col in response.columns] + self.assertEqual(actual_mysql_types, excepted_mysql_types) for column_name in input_data.columns: result_value = response.data_frame[column_name][0] self.assertEqual(result_value, input_data[column_name][0]) @@ -1380,7 +1358,8 @@ def test_types_casting(self): excepted_mysql_types = [MYSQL_DATA_TYPE.BOOLEAN] response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + actual_mysql_types = [col.type for col in response.columns] + self.assertEqual(actual_mysql_types, excepted_mysql_types) for column_name in input_data.columns: result_value = response.data_frame[column_name][0] self.assertEqual(result_value, input_data[column_name][0]) @@ -1616,7 +1595,8 @@ def test_types_casting(self): } ) response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + actual_mysql_types = [col.type for col in response.columns] + self.assertEqual(actual_mysql_types, excepted_mysql_types) self.assertTrue(response.data_frame.equals(expected_result_df)) # endregion @@ -1679,7 +1659,8 @@ def test_types_casting(self): } ) response = self.handler.native_query(query_str) - self.assertEqual(response.mysql_types, excepted_mysql_types) + actual_mysql_types = [col.type for col in response.columns] + self.assertEqual(actual_mysql_types, excepted_mysql_types) self.assertTrue(response.data_frame.equals(expected_result_df)) # endregion diff --git a/tests/unit/handlers/test_timescaledb.py b/tests/unit/handlers/test_timescaledb.py index 32c3efb46de..52cbd771908 100644 --- a/tests/unit/handlers/test_timescaledb.py +++ b/tests/unit/handlers/test_timescaledb.py @@ -7,22 +7,19 @@ from base_handler_test import BaseDatabaseHandlerTest, MockCursorContextManager from mindsdb.integrations.handlers.timescaledb_handler.timescaledb_handler import TimeScaleDBHandler -from mindsdb.integrations.libs.response import ( - HandlerResponse as Response -) +from mindsdb.integrations.libs.response import DataHandlerResponse as Response class TestTimescaleHandler(BaseDatabaseHandlerTest, unittest.TestCase): - @property def dummy_connection_data(self): return OrderedDict( - host='127.0.0.1', + host="127.0.0.1", port=5432, - user='example_user', - schema='public', - password='example_pass', - database='example_db' + user="example_user", + schema="public", + password="example_pass", + database="example_db", ) @property @@ -69,10 +66,10 @@ def get_columns_query(self): """ def create_handler(self): - return TimeScaleDBHandler('timescaledb', connection_data=self.dummy_connection_data) + return TimeScaleDBHandler("timescaledb", connection_data=self.dummy_connection_data) def create_patcher(self): - return patch('psycopg.connect') + return patch("psycopg.connect") def test_native_query(self): """ @@ -99,5 +96,5 @@ def test_native_query(self): self.assertFalse(data.error_code) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/integrations/__init__.py b/tests/unit/integrations/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/integrations/libs/__init__.py b/tests/unit/integrations/libs/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/integrations/libs/test_response.py b/tests/unit/integrations/libs/test_response.py new file mode 100644 index 00000000000..18aa870d939 --- /dev/null +++ b/tests/unit/integrations/libs/test_response.py @@ -0,0 +1,671 @@ +"""Unit tests for response classes in mindsdb.integrations.libs.response module. + +This module tests all response types used by handlers: +- TableResponse: for queries that return data (SELECT, SHOW, etc.) +- OkResponse: for successful operations without data (CREATE, DROP, etc.) +- ErrorResponse: for error cases +- HandlerStatusResponse: for connection status checks +- normalize_response: for converting legacy HandlerResponse to new types +- _safe_pandas_concat: memory-safe DataFrame concatenation +""" + +from unittest.mock import patch, MagicMock + +import pandas as pd +import pytest + +from mindsdb.integrations.libs.response import ( + TableResponse, + OkResponse, + ErrorResponse, + HandlerStatusResponse, + HandlerResponse, + normalize_response, + _safe_pandas_concat, + RESPONSE_TYPE, + DataHandlerResponse, +) +from mindsdb.utilities.types.column import Column +from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE + + +def _mock_virtual_memory(available_kb: int): + """Create a mock for psutil.virtual_memory() with given available memory in KB.""" + mock_mem = MagicMock() + mock_mem.available = available_kb << 10 # convert KB back to bytes + return mock_mem + + +class TestHandlerStatusResponse: + """Tests for HandlerStatusResponse class.""" + + def test_init_success(self): + """Test initialization with success status.""" + redirect_url = "https://example.com/auth" + copy_storage = "s3://bucket/path" + response = HandlerStatusResponse(success=True, redirect_url=redirect_url, copy_storage=copy_storage) + + assert response.success is True + assert response.error_message is None + assert response.redirect_url == redirect_url + assert response.copy_storage == copy_storage + + json_data = response.to_json() + assert json_data["success"] is True + assert json_data["error"] is None + assert json_data["redirect_url"] == redirect_url + assert json_data["copy_storage"] == copy_storage + + def test_init_failure(self): + """Test initialization with failure status.""" + error_msg = "Connection failed" + response = HandlerStatusResponse(success=False, error_message=error_msg) + + assert response.success is False + assert response.error_message == error_msg + assert response.redirect_url is None + assert response.copy_storage is None + + json_data = response.to_json() + assert json_data["success"] is False + assert json_data["error"] == error_msg + assert "redirect_url" not in json_data + assert "copy_storage" not in json_data + + +class TestErrorResponse: + """Unit tests for ErrorResponse class.""" + + def test_init_basic(self): + """Test basic initialization.""" + response = ErrorResponse(error_code=1, error_message="Test error", is_expected_error=True) + + assert response.type == RESPONSE_TYPE.ERROR + assert response.resp_type == RESPONSE_TYPE.ERROR + assert response.error_code == 1 + assert response.error_message == "Test error" + assert response.is_expected_error is True + assert response.exception is None + assert isinstance(response, DataHandlerResponse) + + def test_exception_capture(self): + """Test that exception is captured from current context.""" + try: + raise ValueError("Test exception") + except ValueError: + response = ErrorResponse(error_message="Caught exception") + assert response.exception is not None + assert isinstance(response.exception, ValueError) + + +class TestOkResponse: + """Unit tests for OkResponse class.""" + + def test_init(self): + """Test initialization with affected rows count.""" + response = OkResponse(affected_rows=5) + + assert response.type == RESPONSE_TYPE.OK + assert response.resp_type == RESPONSE_TYPE.OK + assert response.affected_rows == 5 + assert isinstance(response, DataHandlerResponse) + + def test_init_without_affected_rows(self): + """Test initialization without affected rows.""" + response = OkResponse() + + assert response.affected_rows is None + + +class TestTableResponse: + """Unit tests for TableResponse class.""" + + def test_init_with_data(self): + """Test initialization with DataFrame.""" + df = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + response = TableResponse(data=df) + + assert response.type == RESPONSE_TYPE.TABLE + assert response.resp_type == RESPONSE_TYPE.TABLE + assert response._fetched is True + pd.testing.assert_frame_equal(response._data, df) + # 'columns' was not provided as attr, so should be as in df + assert [c.name for c in response.columns] == ["id", "name"] + + def test_complex_init_with_generator(self): + """Test initialization with data generator.""" + column1 = Column(name="id", type=MYSQL_DATA_TYPE.INT) + column2 = Column(name="name", type=MYSQL_DATA_TYPE.VARCHAR) + columns = [column1, column2] + df = pd.DataFrame({"id": [0, 1], "name": ["a", "b"]}) + df1 = pd.DataFrame({"id": [2, 3], "name": ["d", "e"]}) + df2 = pd.DataFrame({"id": [4, 5], "name": ["f", "g"]}) + + def data_gen(): + yield df1 + yield df2 + + response = TableResponse(data=df, data_generator=data_gen(), columns=columns) + + assert response.columns[0] is column1 + assert response.columns[1] is column2 + assert response.data_generator is not None + pd.testing.assert_frame_equal(response._data, df) + assert response._fetched is False + pieces = [] + while isinstance(el := response.fetchmany(), pd.DataFrame): + pieces.append(el) + pd.testing.assert_frame_equal(pieces[0], df1) + pd.testing.assert_frame_equal(pieces[1], df2) + pd.testing.assert_frame_equal(response._data, pd.concat([df, df1, df2])) + assert response._fetched is True + assert response.data_generator is None + + def test_data_frame_property(self): + """Test initialization with explicit columns.""" + columns = [Column(name="id", type=MYSQL_DATA_TYPE.INT), Column(name="name", type=MYSQL_DATA_TYPE.VARCHAR)] + df = pd.DataFrame({"id": [0, 1], "name": ["a", "b"]}) + df1 = pd.DataFrame({"id": [2, 3], "name": ["d", "e"]}) + df2 = pd.DataFrame({"id": [4, 5], "name": ["f", "g"]}) + + def data_gen(): + yield df1 + yield df2 + + response = TableResponse(data=df, data_generator=data_gen(), columns=columns) + assert response._fetched is False + pd.testing.assert_frame_equal(response._data, df) + pd.testing.assert_frame_equal(response.data_frame, pd.concat([df, df1, df2])) + assert response._fetched is True + + # should not change result + response.fetchall() + pd.testing.assert_frame_equal(response.data_frame, pd.concat([df, df1, df2])) + + def test_init_with_affected_rows(self): + """Test initialization with affected_rows.""" + df = pd.DataFrame({"id": [1, 2, 3]}) + response = TableResponse(data=df, affected_rows=100) + + assert response.affected_rows == 100 + + def test_iterate_no_save_no_generator(self): + """Test iterate_no_save yields existing data.""" + df = pd.DataFrame({"id": [1, 2, 3]}) + # Need to provide a generator (even empty) to avoid TypeError + response = TableResponse(data=df, data_generator=iter([])) + + chunks = list(response.iterate_no_save()) + + assert len(chunks) == 1 + pd.testing.assert_frame_equal(chunks[0], df) + + # after `iterate_no_save` result should be invalid + with pytest.raises(ValueError): + pd.testing.assert_frame_equal(response.data_frame, df) + + def test_iterate_no_save_with_generator(self): + """Test iterate_no_save yields all chunks without saving.""" + df1 = pd.DataFrame({"id": [4, 5]}) + df2 = pd.DataFrame({"id": [6, 7]}) + + def data_gen(): + yield df1 + yield df2 + + df = pd.DataFrame({"id": [1, 2, 3]}) + response = TableResponse(data=df, data_generator=data_gen()) + chunks = list(response.iterate_no_save()) + + assert len(chunks) == 3 + pd.testing.assert_frame_equal(chunks[0], df) + pd.testing.assert_frame_equal(chunks[1], df1) + pd.testing.assert_frame_equal(chunks[2], df2) + + # after `iterate_no_save` result should be invalid + with pytest.raises(ValueError): + pd.testing.assert_frame_equal(response.data_frame, df) + + +class TestNormalizeResponse: + """Unit tests for normalize_response function.""" + + def test_normalize_table_response(self): + """Test that TableResponse is returned as-is.""" + original = TableResponse(data=pd.DataFrame({"id": [1, 2]})) + result = normalize_response(original) + + assert result is original + + def test_normalize_ok_response(self): + """Test that OkResponse is returned as-is.""" + original = OkResponse(affected_rows=5) + result = normalize_response(original) + + assert result is original + + def test_normalize_error_response(self): + """Test that ErrorResponse is returned as-is.""" + original = ErrorResponse(error_message="Test error") + result = normalize_response(original) + + assert result is original + + def test_normalize_legacy_error_response(self): + """Test conversion of legacy HandlerResponse with ERROR type.""" + legacy = HandlerResponse(resp_type=RESPONSE_TYPE.ERROR, error_code=1, error_message="Legacy error") + result = normalize_response(legacy) + + assert isinstance(result, ErrorResponse) + assert result.error_code == 1 + assert result.error_message == "Legacy error" + + def test_normalize_legacy_ok_response(self): + """Test conversion of legacy HandlerResponse with OK type.""" + legacy = HandlerResponse(resp_type=RESPONSE_TYPE.OK, affected_rows=10) + result = normalize_response(legacy) + + assert isinstance(result, OkResponse) + assert result.affected_rows == 10 + + def test_normalize_legacy_table_response(self): + """Test conversion of legacy HandlerResponse with TABLE type.""" + df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]}) + legacy = HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=df) + result = normalize_response(legacy) + + assert isinstance(result, TableResponse) + pd.testing.assert_frame_equal(result.data_frame, df) + + def test_normalize_legacy_table_response_with_mysql_types(self): + """Test conversion preserves mysql_types as column types.""" + df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]}) + mysql_types = [MYSQL_DATA_TYPE.INT, MYSQL_DATA_TYPE.VARCHAR] + legacy = HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=df, mysql_types=mysql_types) + result = normalize_response(legacy) + + assert isinstance(result, TableResponse) + assert len(result.columns) == 2 + assert result.columns[0].type == MYSQL_DATA_TYPE.INT + assert result.columns[1].type == MYSQL_DATA_TYPE.VARCHAR + + def test_normalize_legacy_table_response_empty_dataframe(self): + """Test conversion with empty DataFrame.""" + df = pd.DataFrame() + legacy = HandlerResponse(resp_type=RESPONSE_TYPE.TABLE, data_frame=df) + result = normalize_response(legacy) + + assert isinstance(result, TableResponse) + assert len(result.columns) == 0 + + +class TestSafePandasConcat: + """Unit tests for _safe_pandas_concat function.""" + + @patch("mindsdb.integrations.libs.response.psutil") + def test_concat_with_enough_memory(self, mock_psutil): + """Test successful concatenation when sufficient memory is available.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + df1 = pd.DataFrame({"id": [1, 2]}) + df2 = pd.DataFrame({"id": [3, 4]}) + result = _safe_pandas_concat([df1, df2]) + + pd.testing.assert_frame_equal(result, pd.concat([df1, df2])) + + @patch("mindsdb.integrations.libs.response.psutil") + def test_concat_raises_memory_error_when_not_enough_memory(self, mock_psutil): + """Test MemoryError is raised when available memory is too low.""" + # Set available memory to essentially 0 + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=10) + + df1 = pd.DataFrame({"x": list(range(1000))}) + df2 = pd.DataFrame({"x": list(range(1000))}) + + with pytest.raises(MemoryError): + _safe_pandas_concat([df1, df2]) + + @patch("mindsdb.integrations.libs.response.psutil") + def test_concat_single_piece(self, mock_psutil): + """Test concatenation with a single DataFrame.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + df = pd.DataFrame({"id": [1, 2, 3]}) + result = _safe_pandas_concat([df]) + + pd.testing.assert_frame_equal(result, df) + + +class TestRaiseIfLowMemory: + """Unit tests for TableResponse._raise_if_low_memory method.""" + + @patch("mindsdb.integrations.libs.response.psutil") + def test_with_known_affected_rows_enough_memory(self, mock_psutil): + """Test no error when affected_rows is known and memory is sufficient.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + response = TableResponse(data=pd.DataFrame({"id": [1, 2]}), affected_rows=100) + response._last_data_piece = pd.DataFrame({"id": list(range(10))}) + response.rows_fetched = 10 + + # Should not raise + response._raise_if_low_memory() + + @patch("mindsdb.integrations.libs.response.psutil") + def test_with_known_affected_rows_not_enough_memory(self, mock_psutil): + """Test MemoryError when affected_rows is known and memory is insufficient.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1) + + # Use strings to ensure DataFrame memory > 1KB after >> 10 + large_piece = pd.DataFrame({"text": ["x" * 200 for _ in range(100)]}) + response = TableResponse(data=pd.DataFrame({"text": ["a"]}), affected_rows=1000) + response._last_data_piece = large_piece + response.rows_fetched = 100 + + with pytest.raises(MemoryError, match="Not enough memory"): + response._raise_if_low_memory() + + @patch("mindsdb.integrations.libs.response.psutil") + def test_with_unknown_affected_rows_enough_memory(self, mock_psutil): + """Test no error when affected_rows is None and memory is sufficient.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + response = TableResponse(data=pd.DataFrame({"id": [1, 2]})) + response._last_data_piece = pd.DataFrame({"id": list(range(10))}) + + # Should not raise + response._raise_if_low_memory() + + @patch("mindsdb.integrations.libs.response.psutil") + def test_with_unknown_affected_rows_not_enough_memory(self, mock_psutil): + """Test MemoryError when affected_rows is None and memory is insufficient.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1) + + # Use strings to ensure DataFrame memory > 1KB after >> 10 + large_piece = pd.DataFrame({"text": ["x" * 200 for _ in range(100)]}) + response = TableResponse(data=pd.DataFrame({"text": ["a"]})) + response._last_data_piece = large_piece + + with pytest.raises(MemoryError, match="Not enough memory"): + response._raise_if_low_memory() + + @patch("mindsdb.integrations.libs.response.psutil") + def test_all_rows_already_fetched(self, mock_psutil): + """Test no error when all rows have been fetched (rows_expected = 0).""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=0) + + response = TableResponse(data=pd.DataFrame({"id": [1, 2]}), affected_rows=10) + response._last_data_piece = pd.DataFrame({"id": list(range(10))}) + response.rows_fetched = 10 # all rows fetched + + # rows_expected = min(10 - 10, 10) = 0, should not raise + response._raise_if_low_memory() + + +class TestIterateWithMemoryCheck: + """Unit tests for TableResponse._iterate_with_memory_check method.""" + + def test_none_generator_yields_nothing(self): + """Test that no chunks are yielded when data_generator is None.""" + response = TableResponse(data=pd.DataFrame({"id": [1]})) + assert response._data_generator is None + + chunks = list(response._iterate_with_memory_check()) + assert chunks == [] + + @patch("mindsdb.integrations.libs.response.psutil") + def test_normal_iteration(self, mock_psutil): + """Test that all chunks are yielded during normal iteration.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + df1 = pd.DataFrame({"id": [1, 2]}) + df2 = pd.DataFrame({"id": [3, 4]}) + + def data_gen(): + yield df1 + yield df2 + + columns = [Column(name="id")] + response = TableResponse(data_generator=data_gen(), columns=columns) + + chunks = list(response._iterate_with_memory_check()) + + assert len(chunks) == 2 + pd.testing.assert_frame_equal(chunks[0], df1) + pd.testing.assert_frame_equal(chunks[1], df2) + + @patch("mindsdb.integrations.libs.response.psutil") + def test_memory_error_stops_iteration_after_first_chunk(self, mock_psutil): + """Test that MemoryError is raised after the first chunk when memory runs out. + + The pre-loop _raise_if_low_memory() is a no-op (since _last_data_piece is None), + so the first real psutil.virtual_memory() call happens at the post-yield check. + """ + # Use strings to ensure DataFrame memory > 1KB after >> 10 + df1 = pd.DataFrame({"text": ["x" * 200 for _ in range(100)]}) + df2 = pd.DataFrame({"text": ["y" * 200 for _ in range(100)]}) + + def data_gen(): + yield df1 + yield df2 + + columns = [Column(name="text")] + response = TableResponse(data_generator=data_gen(), columns=columns) + + gen = response._iterate_with_memory_check() + + # First chunk succeeds — post-yield check will be the first real psutil call + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1) + first = next(gen) + pd.testing.assert_frame_equal(first, df1) + + # Resuming the generator triggers _raise_if_low_memory with 0 available memory + with pytest.raises(MemoryError): + next(gen) + + @patch("mindsdb.integrations.libs.response.psutil") + def test_updates_last_data_piece_and_rows_fetched(self, mock_psutil): + """Test that _last_data_piece and rows_fetched are updated during iteration.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + df1 = pd.DataFrame({"id": [1, 2, 3]}) + df2 = pd.DataFrame({"id": [4, 5]}) + + def data_gen(): + yield df1 + yield df2 + + columns = [Column(name="id")] + response = TableResponse(data_generator=data_gen(), columns=columns) + assert response.rows_fetched == 0 + + list(response._iterate_with_memory_check()) + + pd.testing.assert_frame_equal(response._last_data_piece, df2) + assert response.rows_fetched == 5 + + +class TestTableResponseFetchallEdgeCases: + """Additional edge-case tests for TableResponse.fetchall.""" + + def test_fetchall_no_generator_returns_existing_data(self): + """Test fetchall returns existing data when no generator is set.""" + df = pd.DataFrame({"id": [1, 2, 3]}) + response = TableResponse(data=df) + + result = response.fetchall() + pd.testing.assert_frame_equal(result, df) + + @patch("mindsdb.integrations.libs.response.psutil") + def test_fetchall_generator_only_no_initial_data(self, mock_psutil): + """Test fetchall with generator but no initial data.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + df1 = pd.DataFrame({"id": [1, 2]}) + df2 = pd.DataFrame({"id": [3, 4]}) + + def data_gen(): + yield df1 + yield df2 + + columns = [Column(name="id")] + response = TableResponse(data_generator=data_gen(), columns=columns) + + result = response.fetchall() + pd.testing.assert_frame_equal(result, pd.concat([df1, df2])) + assert response._fetched is True + assert response._data_generator is None + + @patch("mindsdb.integrations.libs.response.psutil") + def test_fetchall_empty_generator_creates_empty_df(self, mock_psutil): + """Test fetchall with empty generator creates DataFrame with column names.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + columns = [Column(name="id"), Column(name="name")] + response = TableResponse(data_generator=iter([]), columns=columns) + + result = response.fetchall() + assert list(result.columns) == ["id", "name"] + assert len(result) == 0 + + def test_fetchall_raises_if_invalid(self): + """Test fetchall raises ValueError if data was already consumed by iterate_no_save.""" + df = pd.DataFrame({"id": [1]}) + response = TableResponse(data=df, data_generator=iter([])) + list(response.iterate_no_save()) + + with pytest.raises(ValueError, match="Data has already been fetched"): + response.fetchall() + + +class TestTableResponseFetchmanyEdgeCases: + """Additional edge-case tests for TableResponse.fetchmany.""" + + @patch("mindsdb.integrations.libs.response.psutil") + def test_fetchmany_first_piece_with_no_initial_data(self, mock_psutil): + """Test fetchmany sets _data directly when no initial data exists.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + df1 = pd.DataFrame({"id": [1, 2]}) + columns = [Column(name="id")] + response = TableResponse(data_generator=iter([df1]), columns=columns) + + piece = response.fetchmany() + pd.testing.assert_frame_equal(piece, df1) + pd.testing.assert_frame_equal(response._data, df1) + + @patch("mindsdb.integrations.libs.response.psutil") + def test_fetchmany_accumulates_data(self, mock_psutil): + """Test fetchmany accumulates pieces in _data.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + df = pd.DataFrame({"id": [0]}) + df1 = pd.DataFrame({"id": [1]}) + df2 = pd.DataFrame({"id": [2]}) + + def data_gen(): + yield df1 + yield df2 + + columns = [Column(name="id")] + response = TableResponse(data=df, data_generator=data_gen(), columns=columns) + + response.fetchmany() # df1 + response.fetchmany() # df2 + + pd.testing.assert_frame_equal(response._data, pd.concat([df, df1, df2])) + + @patch("mindsdb.integrations.libs.response.psutil") + def test_fetchmany_returns_none_when_exhausted(self, mock_psutil): + """Test fetchmany returns None and marks response as fetched when generator is empty.""" + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + + df1 = pd.DataFrame({"id": [1]}) + columns = [Column(name="id")] + response = TableResponse(data_generator=iter([df1]), columns=columns) + + piece1 = response.fetchmany() + assert isinstance(piece1, pd.DataFrame) + + piece2 = response.fetchmany() + assert piece2 is None + assert response._fetched is True + assert response._data_generator is None + + def test_fetchmany_raises_if_invalid(self): + """Test fetchmany raises ValueError after iterate_no_save.""" + df = pd.DataFrame({"id": [1]}) + response = TableResponse(data=df, data_generator=iter([])) + list(response.iterate_no_save()) + + with pytest.raises(ValueError, match="Data has already been fetched"): + response.fetchmany() + + +class TestMemoryErrorPropagation: + """Tests for MemoryError propagation through fetchall, fetchmany, and iterate_no_save.""" + + @patch("mindsdb.integrations.libs.response.psutil") + def test_fetchall_raises_memory_error(self, mock_psutil): + """Test MemoryError propagates through fetchall.""" + # Enough memory for first chunk, then out of memory + mock_psutil.virtual_memory.side_effect = [ + _mock_virtual_memory(available_kb=1_000_000), # pre-loop check + _mock_virtual_memory(available_kb=0), # post-yield check + ] + + df1 = pd.DataFrame({"x": list(range(1000))}) + df2 = pd.DataFrame({"x": list(range(1000))}) + + def data_gen(): + yield df1 + yield df2 + + columns = [Column(name="x")] + response = TableResponse(data_generator=data_gen(), columns=columns) + + with pytest.raises(MemoryError): + response.fetchall() + + @patch("mindsdb.integrations.libs.response.psutil") + def test_fetchmany_raises_memory_error(self, mock_psutil): + """Test MemoryError propagates through fetchmany on second call.""" + df1 = pd.DataFrame({"x": list(range(1000))}) + df2 = pd.DataFrame({"x": list(range(1000))}) + + def data_gen(): + yield df1 + yield df2 + + columns = [Column(name="x")] + response = TableResponse(data_generator=data_gen(), columns=columns) + + # First fetchmany: enough memory (pre-loop check is no-op since _last_data_piece is None) + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=1_000_000) + response.fetchmany() + + # Second fetchmany: pre-loop check fails because we now have _last_data_piece set + mock_psutil.virtual_memory.return_value = _mock_virtual_memory(available_kb=0) + with pytest.raises(MemoryError): + response.fetchmany() + + @patch("mindsdb.integrations.libs.response.psutil") + def test_iterate_no_save_raises_memory_error(self, mock_psutil): + """Test MemoryError propagates through iterate_no_save.""" + mock_psutil.virtual_memory.side_effect = [ + _mock_virtual_memory(available_kb=1_000_000), # pre-loop check + _mock_virtual_memory(available_kb=0), # post-yield check after first chunk + ] + + df1 = pd.DataFrame({"x": list(range(1000))}) + df2 = pd.DataFrame({"x": list(range(1000))}) + + def data_gen(): + yield df1 + yield df2 + + columns = [Column(name="x")] + response = TableResponse(data_generator=data_gen(), columns=columns) + + with pytest.raises(MemoryError): + list(response.iterate_no_save()) From 7ff5b4f60cf2aa68c9f75f034a05e78a9e2ce170 Mon Sep 17 00:00:00 2001 From: Andrey Date: Fri, 6 Mar 2026 16:04:51 +0300 Subject: [PATCH 033/125] Faiss ivf file (#12162) --- .../duckdb_faiss_handler.py | 12 +- .../duckdb_faiss_handler/faiss_index.py | 260 ++++++++++++++---- .../test_faiss_handler.py | 42 ++- .../interfaces/query_context/query_task.py | 1 - 4 files changed, 235 insertions(+), 80 deletions(-) diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py index dc536a6430f..fc413f14d68 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py @@ -70,7 +70,7 @@ def __init__(self, name: str, **kwargs): # Initialize storage paths self.duckdb_path = os.path.join(self.persist_directory, "duckdb.db") - self.faiss_index_path = os.path.join(self.persist_directory, "faiss_index") + self.faiss_index_path = self.persist_directory self.connect() # check keyword index @@ -129,11 +129,11 @@ def drop_table(self, table_name: str, if_exists=True): if self.faiss_index: self.faiss_index.drop() - def create_index(self, table_name: str, type: str = "ivf", nlist: int = 1024, train_count: int = 10000): - if type != "ivf": - raise NotImplementedError("Only ivf index is supported") + def create_index(self, table_name: str, type: str = "ivf_file", nlist: int = None, train_count: int = None): + if type not in ("ivf", "ivf_file"): + raise NotImplementedError("Only ivf or ivf_file indexes are supported") - self.faiss_index.create_index(nlist=nlist, train_count=train_count) + self.faiss_index.create_index(type, nlist=nlist, train_count=train_count) def insert(self, table_name: str, data: pd.DataFrame): """Insert data into both DuckDB and Faiss.""" @@ -421,7 +421,7 @@ def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> R self._sync() def get_dimension(self, table_name: str) -> int: - if self.faiss_index: + if self.faiss_index and self.faiss_index.index is not None: return self.faiss_index.dim def _sync(self): diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py index 8aef1808004..e596eaf0cf6 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py @@ -1,5 +1,5 @@ import os -from typing import Iterable, List +from typing import Iterable, List, Callable import numpy as np import psutil from pathlib import Path @@ -7,6 +7,7 @@ import portalocker import faiss # faiss or faiss-gpu +from faiss.contrib.ondisk import merge_ondisk from pydantic import BaseModel @@ -43,7 +44,7 @@ def __init__(self, path: str, config: dict): else: raise ValueError(f"Unknown metric: {metric}") - self.path = path + self.path = os.path.join(path, "faiss_index") self._since_ram_checked = 0 @@ -51,6 +52,15 @@ def __init__(self, path: str, config: dict): self.index_type = "flat" self.dim = None self.index_fd = None + + recover_path = Path(self.path).parent / "recover" + if recover_path.exists(): + # move all files from recover dir that might be left after index failing + for item in recover_path.iterdir(): + if item.is_dir(): + continue + item.rename(Path(self.path).parent / item.name) + if os.path.exists(self.path): self._load_index() @@ -79,8 +89,10 @@ def _load_index(self): self.index = faiss.read_index(self.path) self.dim = self.index.d - sub_index = faiss.downcast_index(self.index.index) - if isinstance(sub_index, faiss.IndexIVFFlat): + index = self.index + if hasattr(index, "index"): + index = faiss.downcast_index(index.index) + if isinstance(index, faiss.IndexIVFFlat): self.index_type = "ivf" def close(self): @@ -131,7 +143,7 @@ def _check_ram_usage(self, count_vectors, index_type: str = "flat", m=32, nlist= def insert( self, vectors: Iterable[Iterable[float]], - ids: Iterable[float], + ids: Iterable[int], ) -> None: if len(vectors) == 0: return @@ -170,12 +182,16 @@ def dump(self): def drop(self): self.close() - if os.path.exists(self.path): - os.remove(self.path) + + # remove index files (everything except duckdb) + for item in Path(self.path).parent.iterdir(): + if item.is_dir() or item.name.startswith("duckdb."): + continue + item.unlink() def search( self, - query: Iterable[Iterable[float]], + query: Iterable[float], limit: int = 10, # allowed_ids: Optional[Sequence[int]] = None, ): @@ -196,31 +212,70 @@ def search( class FaissIVFIndex(FaissIndex): - def _dump_vectors(self, index, path, batch_size: int = 10000): + def _dump_vectors(self, index, path: Path, batch_size: int = 30000): + """ + Extract and dump vectors and ids from index. Method is dependent on index type """ - Save vectors from a Faiss IndexIDMap to disk in batches using numpy memmap. - - Writes the one memmap for ids and batches for vectors + if hasattr(index, "id_map"): + ids = faiss.vector_to_array(index.id_map).astype(np.int64, copy=False) + inner = index.index + + def get_batch_vectors(start, size): + return inner.reconstruct_n(start, size).astype(np.float32, copy=False) + + return self._dump_vectors_to_file(ids, path, index.ntotal, batch_size, get_batch_vectors) + else: + invlists = index.invlists + + index.set_direct_map_type(faiss.DirectMap.Hashtable) + + ids_list = [] + for list_no in range(index.nlist): + list_size = invlists.list_size(list_no) + if list_size == 0: + continue + + # Get IDs stored in this inverted list + id_array = faiss.rev_swig_ptr(invlists.get_ids(list_no), list_size) + ids_list.append(id_array) - :param index: Faiss IndexIDMap - :param path: Output directory where batch files will be written - :param batch_size: Number of vectors per batch file + ids = np.hstack(ids_list).astype(np.int64) + + # to train index first batches will be used. shuffle ids to prevent using the same lists + # TODO shuffle only part of data? + np.random.shuffle(ids) + + def get_batch_vectors(start, size): + ids_batch = ids[start : start + size] + return index.reconstruct_batch(ids_batch).astype(np.float32, copy=False) + + return self._dump_vectors_to_file(ids, path, index.ntotal, batch_size, get_batch_vectors) + + def _dump_vectors_to_file( + self, + ids: np.ndarray, + path: Path, + ntotal: int, + batch_size: int, + get_batch_content: Callable[[int, int], np.ndarray], + ) -> int: """ - if not hasattr(index, "id_map") or not hasattr(index, "index"): - raise ValueError("Expected a Faiss IndexIDMap-like object with 'id_map' and 'index' attributes") + Write ids and vectors to memmap files in batches. - ntotal = index.ntotal + :param ids: vector IDs in the same order as vectors will be dumped. + :param path: directory to store dumps. + :param ntotal: total number of vectors. + :param batch_size: number of vectors per batch file. + :param get_batch_content: function to get a batch content - ids = faiss.vector_to_array(index.id_map).astype(np.int64, copy=False) + """ # Write all ids once to a single memmap file ids_path = path / "ids.mmap" mmap_ids = np.memmap(ids_path, dtype=np.int64, mode="w+", shape=(ntotal,)) mmap_ids[:] = ids - del mmap_ids # flush - - inner = index.index batch_num = 0 while True: @@ -233,8 +288,7 @@ def _dump_vectors(self, index, path, batch_size: int = 10000): ntotal -= size batch_num += 1 - # Reconstruct a contiguous block when possible - vecs = inner.reconstruct_n(start, size).astype(np.float32, copy=False) + vecs = get_batch_content(start, size) vecs_path = path / f"batch_{batch_num:05d}_vecs.mmap" @@ -244,14 +298,48 @@ def _dump_vectors(self, index, path, batch_size: int = 10000): mmap_vecs.flush() del mmap_vecs + del mmap_ids return batch_num - def _create_ifv_index_from_dump(self, path, train_count=10000, nlist=1024): + def _train_ivf(self, dump_path, train_count, nlist): + # Accumulate training data up to train_count + train_left = train_count + train_chunks = [] + + vec_files = self._get_dump_vector_files(dump_path) + + for fname in vec_files: + fpath = dump_path / fname + batch_data = np.fromfile(fpath, dtype="float32") + rows = int(batch_data.shape[0] / self.dim) + + train_chunks.append(batch_data.reshape([rows, self.dim])) + + train_left -= rows + if train_left <= 0: + break + + train_data = np.vstack(train_chunks) + train_data = train_data[:train_count, :] + + quantizer = faiss.IndexFlat(self.dim, self.metric) + ivf = faiss.IndexIVFFlat(quantizer, self.dim, nlist, self.metric) + + ivf.train(train_data) + return ivf + + def _get_dump_vector_files(self, dump_path): + # Collect vector batch files and sort by batch index + vec_files = [f for f in os.listdir(dump_path) if f.startswith("batch_")] + if not vec_files: + raise FileNotFoundError(f"No vector batch memmaps found in {dump_path}") + + vec_files.sort() + return vec_files + + def _create_ivf_index(self, path, train_count, nlist): """ - Build an IVF index (wrapped in IndexIDMap) from memmap batches - - Reads a single `ids.mmap` and multiple `batch_{i}_vecs.mmap` files from `path`. - - Accumulates up to `train_count` vectors to train the IVF quantizer. - - Creates IndexIVFFlat and adds all vectors with their ids to it. + Build an in-memory IVF index :param path: Directory containing memmap files :param train_count: Number of vectors to use for training @@ -262,45 +350,48 @@ def _create_ifv_index_from_dump(self, path, train_count=10000, nlist=1024): ids_path = path / "ids.mmap" if not os.path.exists(ids_path): raise FileNotFoundError(f"Missing ids memmap: {ids_path}") - ids = np.fromfile(ids_path, dtype="int64") - # Collect vector batch files and sort by batch index - vec_files = [f for f in os.listdir(path) if f.startswith("batch_")] - if not vec_files: - raise FileNotFoundError(f"No vector batch memmaps found in {path}") - - vec_files.sort() + ivf = self._train_ivf(path, nlist=nlist, train_count=train_count) - # Accumulate training data up to train_count - train_left = train_count - train_chunks = [] + vec_files = self._get_dump_vector_files(path) + # load data + start = 0 for fname in vec_files: fpath = path / fname + batch_data = np.fromfile(fpath, dtype="float32") rows = int(batch_data.shape[0] / self.dim) - train_chunks.append(batch_data.reshape([rows, self.dim])) + batch_vectors = batch_data.reshape([rows, self.dim]) - train_left -= rows - if train_left <= 0: - break + ids_batch = np.asarray(ids[start : start + rows]) + ivf.add_with_ids(batch_vectors, ids_batch) + start += rows - train_data = np.vstack(train_chunks) + return ivf - # nlist can't be less than train data - nlist = min(nlist, len(train_data)) + def _create_ivf_file_index(self, path, train_count, nlist): + """Build an IVF on disk index""" - quantizer = faiss.IndexFlat(self.dim, self.metric) - ivf = faiss.IndexIVFFlat(quantizer, self.dim, nlist, self.metric) + index_path = path.parent + trained_index = self._train_ivf(path, train_count=train_count, nlist=nlist) + # store trained index + trained_path = str(index_path / "faiss_index.trained") + faiss.write_index(trained_index, trained_path) - ivf.train(train_data) - ivf_id_map = faiss.IndexIDMap(ivf) + ids_path = path / "ids.mmap" + if not os.path.exists(ids_path): + raise FileNotFoundError(f"Missing ids memmap: {ids_path}") + ids = np.fromfile(ids_path, dtype="int64") + + vec_files = self._get_dump_vector_files(path) - # load data start = 0 - for fname in vec_files: + block_fnames = [] + for num, fname in enumerate(vec_files): + index = faiss.read_index(trained_path) fpath = path / fname batch_data = np.fromfile(fpath, dtype="float32") @@ -309,24 +400,58 @@ def _create_ifv_index_from_dump(self, path, train_count=10000, nlist=1024): batch_vectors = batch_data.reshape([rows, self.dim]) ids_batch = np.asarray(ids[start : start + rows]) - ivf_id_map.add_with_ids(batch_vectors, ids_batch) + index.add_with_ids(batch_vectors, ids_batch) + block_fname = str(index_path / f"faiss_index_block.{num}") + block_fnames.append(block_fname) + faiss.write_index(index, block_fname) start += rows - return ivf_id_map + index = faiss.read_index(trained_path) + + merge_ondisk(index, block_fnames, str(index_path / "faiss_index_merged")) + os.unlink(trained_path) + for block_fname in block_fnames: + os.unlink(block_fname) + + return index + + def create_index(self, index_type, nlist=None, train_count=None): + """ + Create or recreate IVF index + + :param index_type: options are: 'ivf' (in RAM) or 'ivf_file' (on disk) + :param nlist: number of inverted lists + :param train_count: count of vectors to use for training. + + """ - def create_index(self, nlist=1024, train_count=10000): # index might not fit into RAM, extract data to files - dump_path = Path(self.path).parent / "dump" + base_path = Path(self.path).parent + dump_path = base_path / "dump" # if self.index_type != 'flat': # raise ValueError('Index was already created') + # check params, apply defaults + if nlist is None: + nlist = self.config.nlist + if self.index is None: ntotal = 0 else: ntotal = self.index.ntotal - if nlist > ntotal: - raise ValueError(f"Not enough data to create: {ntotal}, required at lease {nlist} records") + + # faiss shows warning if train count is less than 39 * nlist and recommend to use at least this size for train data + nlist_k = 39 + if train_count is not None: + if train_count < nlist * nlist_k: + raise ValueError(f"Train_count can't be less than nlist * {nlist_k} (is {nlist * nlist_k})") + else: + # get 10k if possible but not less than nlist * k + train_count = max(nlist * nlist_k, min(ntotal, 10000)) + + if train_count > ntotal: + raise ValueError(f"Not enough data to create index: {ntotal}, at least {train_count} records are required") dump_path.mkdir(exist_ok=True) @@ -339,8 +464,22 @@ def create_index(self, nlist=1024, train_count=10000): # unload flat index from RAM self.close() + # buckup index files + recover_path = base_path / "recover" + recover_path.mkdir(exist_ok=True) + for item in base_path.iterdir(): + if item.is_dir() or item.name.startswith("duckdb."): + continue + item.rename(recover_path / item.name) + # create ivf index - ivf_index = self._create_ifv_index_from_dump(dump_path, train_count=train_count, nlist=nlist) + if index_type == "ivf": + ivf_index = self._create_ivf_index(dump_path, train_count=train_count, nlist=nlist) + + elif index_type == "ivf_file": + ivf_index = self._create_ivf_file_index(dump_path, train_count=train_count, nlist=nlist) + else: + raise ValueError(f"Unknown index type: {index_type}") self.index = ivf_index self.index_type = "ivf" @@ -350,3 +489,8 @@ def create_index(self, nlist=1024, train_count=10000): # remove unused items for item in dump_path.iterdir(): item.unlink() + dump_path.rmdir() + + for item in recover_path.iterdir(): + item.unlink() + recover_path.rmdir() diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py index 915d89f64ab..01eb44b2cae 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py @@ -1,3 +1,4 @@ +import pytest from unittest.mock import patch import pandas as pd @@ -30,8 +31,15 @@ def _get_storage_table(self, kb_name): return f"faiss_{kb_name}.kb_faiss" + @pytest.mark.parametrize("index_type", ["ivf", "ivf_file"]) @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_ivf_index(self, mock_litellm_embedding): + def test_ivf_index(self, mock_litellm_embedding, index_type): + """ + Run test two times: + - make ivf index and then reindex to ivf_file + - make ivf_file index and then reindex to ivf + """ + set_litellm_embedding(mock_litellm_embedding) df = self._get_ral_table() @@ -51,20 +59,24 @@ def test_ivf_index(self, mock_litellm_embedding): """ ) - self.run_sql("CREATE INDEX ON KNOWLEDGE_BASE kb_ral WITH (nlist=10)") + for i in range(2): + self.run_sql(f"CREATE INDEX ON KNOWLEDGE_BASE kb_ral WITH (nlist=10, type='{index_type}')") + + # search works + ret = self.run_sql("select * from kb_ral where k.content = 'white' limit 1") + assert "white" in ret["chunk_content"][0] - # search works - ret = self.run_sql("select * from kb_ral where k.content = 'white' limit 1") - assert "white" in ret["chunk_content"][0] + # -- test insert -- + self.run_sql("insert into kb_ral (id, english) values (10000, 'magpie')") + # search + ret = self.run_sql("select * from kb_ral where k.content = 'magpie' limit 1") + assert "magpie" in ret["chunk_content"][0] - # -- test insert -- - self.run_sql("insert into kb_ral (id, english) values (10000, 'magpie')") - # search - ret = self.run_sql("select * from kb_ral where k.content = 'magpie' limit 1") - assert "magpie" in ret["chunk_content"][0] + # -- test delete -- + self.run_sql("delete from kb_ral where id=10000") + # search + ret = self.run_sql("select * from kb_ral where k.content = 'magpie' limit 1") + assert len(ret) == 0 or "magpie" not in ret["chunk_content"][0] - # -- test delete -- - self.run_sql("delete from kb_ral where id=10000") - # search - ret = self.run_sql("select * from kb_ral where k.content = 'magpie' limit 1") - assert len(ret) == 0 or "magpie" not in ret["chunk_content"][0] + # toggle index type + index_type = "ivf_file" if index_type == "ivf" else "ivf" diff --git a/mindsdb/interfaces/query_context/query_task.py b/mindsdb/interfaces/query_context/query_task.py index 57cc62d7f81..97cbbdcbf26 100644 --- a/mindsdb/interfaces/query_context/query_task.py +++ b/mindsdb/interfaces/query_context/query_task.py @@ -10,7 +10,6 @@ def __init__(self, *args, **kwargs): self.query_id = self.object_id def run(self, stop_event): - try: session = SessionController() SQLQuery(None, query_id=self.query_id, session=session, stop_event=stop_event) From 2e62d37c35bfdd636885c959edf106a388c52cee Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Fri, 6 Mar 2026 17:03:28 +0300 Subject: [PATCH 034/125] Fix for queries to API handler with `IN` filter (#12226) --- mindsdb/integrations/utilities/sql_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mindsdb/integrations/utilities/sql_utils.py b/mindsdb/integrations/utilities/sql_utils.py index 1d796f49b02..e123b9ed837 100644 --- a/mindsdb/integrations/utilities/sql_utils.py +++ b/mindsdb/integrations/utilities/sql_utils.py @@ -458,7 +458,10 @@ def filter_dataframe(df: pd.DataFrame, conditions: list, raw_conditions=None, or else: item = ast.BinaryOperation(op=op, args=[arg1_identifier, ast.Constant(arg2)]) else: - item = ast.BinaryOperation(op=op, args=[arg1_identifier, ast.Constant(arg2)]) + if isinstance(arg2, ASTNode): + item = ast.BinaryOperation(op=op, args=[arg1_identifier, arg2]) + else: + item = ast.BinaryOperation(op=op, args=[arg1_identifier, ast.Constant(arg2)]) if where_query is None: where_query = item From 98781f652744c31662bb1811d3886473d218b3fe Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Fri, 6 Mar 2026 17:03:39 +0300 Subject: [PATCH 035/125] Use `uv` as a fallback if `pip` is not found when installing a handler's dependencies (#12234) --- mindsdb/integrations/utilities/install.py | 69 +++++++++++++---------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/mindsdb/integrations/utilities/install.py b/mindsdb/integrations/utilities/install.py index 388edc9703d..9a56b2e4ae4 100644 --- a/mindsdb/integrations/utilities/install.py +++ b/mindsdb/integrations/utilities/install.py @@ -1,66 +1,77 @@ import os import sys import subprocess +from enum import Enum from typing import Text, List -def install_dependencies(dependencies: List[Text]) -> dict: +class InstallTool(Enum): + pip = (sys.executable, "-m", "pip") + uv = ("uv", "pip") + + +def install_dependencies(dependencies: List[Text], tool: InstallTool = InstallTool.pip) -> dict: """ Installs the dependencies for a handler by calling the `pip install` command via subprocess. Args: dependencies (List[Text]): List of dependencies for the handler. + tool (InstallTool): tool the tool that will be used to install dependencies Returns: dict: A dictionary containing the success status and an error message if an error occurs. """ - outs = b'' - errs = b'' - result = { - 'success': False, - 'error_message': None - } + outs = b"" + errs = b"" + result = {"success": False, "error_message": None} code = None try: # Split the dependencies by parsing the contents of the requirements.txt file. split_dependencies = parse_dependencies(dependencies) except FileNotFoundError as file_not_found_error: - result['error_message'] = f"Error parsing dependencies, file not found: {str(file_not_found_error)}" + result["error_message"] = f"Error parsing dependencies, file not found: {str(file_not_found_error)}" return result except Exception as unknown_error: - result['error_message'] = f"Unknown error parsing dependencies: {str(unknown_error)}" + result["error_message"] = f"Unknown error parsing dependencies: {str(unknown_error)}" return result try: - # Install the dependencies using the `pip install` command. + # Install the dependencies using the selected tool. sp = subprocess.Popen( - [sys.executable, '-m', 'pip', 'install', *split_dependencies], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE + [*tool.value, "install", *split_dependencies], stdout=subprocess.PIPE, stderr=subprocess.PIPE ) code = sp.wait() outs, errs = sp.communicate(timeout=1) except subprocess.TimeoutExpired as timeout_error: sp.kill() - result['error_message'] = f"Timeout error while installing dependencies: {str(timeout_error)}" + result["error_message"] = f"Timeout error while installing dependencies: {str(timeout_error)}" + return result + except FileNotFoundError as e: + if e.filename == "uv": + result["error_message"] = "The 'pip' and 'uv' tools are not found. Please install them." + else: + result["error_message"] = f"FileNotFoundError error while installing dependencies: {str(e)}" return result except Exception as unknown_error: - result['error_message'] = f"Unknown error while installing dependencies: {str(unknown_error)}" + result["error_message"] = f"Unknown error while installing dependencies: {str(unknown_error)}" return result # Return the result of the installation if successful, otherwise return an error message. if code != 0: - output = '' + output = "" if isinstance(outs, bytes) and len(outs) > 0: - output = output + 'Output: ' + outs.decode() + output = output + "Output: " + outs.decode() if isinstance(errs, bytes) and len(errs) > 0: if len(output) > 0: - output = output + '\n' - output = output + 'Errors: ' + errs.decode() - result['error_message'] = output + output = output + "\n" + output = output + "Errors: " + errs.decode() + if "no module named pip" in output.lower() and tool is InstallTool.pip: + # try with uv + return install_dependencies(dependencies, InstallTool.uv) + result["error_message"] = output else: - result['success'] = True + result["success"] = True return result @@ -85,19 +96,19 @@ def parse_dependencies(dependencies: List[Text]) -> List[Text]: split_dependencies = [] for dependency in dependencies: # ignore standalone comments - if dependency.startswith('#'): + if dependency.startswith("#"): continue # remove inline comments - if '#' in dependency: - dependency = dependency.split('#')[0].strip() + if "#" in dependency: + dependency = dependency.split("#")[0].strip() # check if the dependency is a path to a requirements file - if dependency.startswith('-r'): + if dependency.startswith("-r"): # get the path to the requirements file - req_path = dependency.split(' ')[1] + req_path = dependency.split(" ")[1] # create the absolute path to the requirements file - abs_req_path = os.path.abspath(os.path.join(script_path, req_path.replace('mindsdb/integrations', '..'))) + abs_req_path = os.path.abspath(os.path.join(script_path, req_path.replace("mindsdb/integrations", ".."))) # check if the file exists if os.path.exists(abs_req_path): inner_dependencies, inner_split_dependencies = [], [] @@ -128,7 +139,7 @@ def read_dependencies(path: Text) -> List[Text]: """ dependencies = [] # read the dependencies from the file - with open(str(path), 'rt') as f: - dependencies = [x.strip(' \t\n') for x in f.readlines()] + with open(str(path), "rt") as f: + dependencies = [x.strip(" \t\n") for x in f.readlines()] dependencies = [x for x in dependencies if len(x) > 0] return dependencies From fab39290d5aee49cae9bfb20eb8f1a537c21c1ef Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Tue, 10 Mar 2026 16:59:56 +0300 Subject: [PATCH 036/125] fix: Optimize get_integration by fetching directly from DB (#11973) (#12281) Co-authored-by: SyedaAnshrahGillani --- mindsdb/interfaces/database/database.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mindsdb/interfaces/database/database.py b/mindsdb/interfaces/database/database.py index bbacc9c256a..3f0fb602ace 100644 --- a/mindsdb/interfaces/database/database.py +++ b/mindsdb/interfaces/database/database.py @@ -101,11 +101,15 @@ def get_dict(self, filter_type: Optional[str] = None, lowercase: bool = True): def get_integration(self, integration_id): # get integration by id - - # TODO get directly from db? - for rec in self.get_list(): - if rec["id"] == integration_id and rec["type"] == "data": - return {"name": rec["name"], "type": rec["type"], "engine": rec["engine"], "id": rec["id"]} + integration = self.integration_controller.get_by_id(integration_id) + if integration and integration.get("type", "data") == "data": + return { + "name": integration["name"], + "type": integration["type"], + "engine": integration["engine"], + "id": integration["id"], + } + return None def exists(self, db_name: str) -> bool: return db_name.lower() in self.get_dict() From 562f982c28ce8a8d2e0cd6c70a48e937c42dd8da Mon Sep 17 00:00:00 2001 From: andrew Date: Wed, 11 Mar 2026 11:39:27 +0300 Subject: [PATCH 037/125] review fixes --- .../sql_query/steps/fetch_dataframe_partition.py | 2 +- mindsdb/interfaces/database/projects.py | 10 +++++----- mindsdb/interfaces/query_context/last_query.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py b/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py index 819f7f1d563..77f53fc8bd3 100644 --- a/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +++ b/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py @@ -129,7 +129,7 @@ def repeat_till_reach_limit(self, step, limit): break # break if process is too long or to many tries - if try_num > 3 or started_at - time.time() > 5: + if try_num > 3 or time.time() - started_at > 5: # the last try without the limit first_table_limit = None continue diff --git a/mindsdb/interfaces/database/projects.py b/mindsdb/interfaces/database/projects.py index bd7b443f833..bdf84bb2c10 100644 --- a/mindsdb/interfaces/database/projects.py +++ b/mindsdb/interfaces/database/projects.py @@ -203,17 +203,17 @@ def get_conditions_to_move(node): if isinstance(arg1, Identifier): name = arg1.parts[-1].lower() - if name in white_list: - arg1 = white_list[name] # don't move condition for join with Star - elif name in black_list or not (has_star and not is_join): + if name in black_list or not (has_star and not is_join): continue + elif name in white_list: + arg1 = white_list[name] if isinstance(arg2, Identifier): name = arg2.parts[-1].lower() - if name in white_list: - arg2 = white_list[name] if name in black_list or not (has_star and not is_join): continue + elif name in white_list: + arg2 = white_list[name] # condition can be moved into view condition2 = BinaryOperation(condition.op, [arg1, arg2]) diff --git a/mindsdb/interfaces/query_context/last_query.py b/mindsdb/interfaces/query_context/last_query.py index 0043c55aa1a..7e00a08c846 100644 --- a/mindsdb/interfaces/query_context/last_query.py +++ b/mindsdb/interfaces/query_context/last_query.py @@ -224,7 +224,7 @@ def to_string(self) -> str: and query.offset is None and query.cte is None ): - query = query.from_table + query = copy.deepcopy(query.from_table) query.parentheses = False query.alias = None From ef27f73300f0444fc4e37298ef7d3daca65bc685 Mon Sep 17 00:00:00 2001 From: Andrey Date: Wed, 11 Mar 2026 16:04:34 +0300 Subject: [PATCH 038/125] FQE fixes for agent (#12263) --- mindsdb/api/executor/planner/plan_join.py | 16 +- .../steps/fetch_dataframe_partition.py | 5 +- .../sql_query/steps/subselect_step.py | 2 + mindsdb/interfaces/database/projects.py | 22 ++- .../interfaces/query_context/last_query.py | 139 +++++++++++------- tests/unit/executor/test_base_queires.py | 21 +++ tests/unit/planner/test_join_tables.py | 59 ++++++-- 7 files changed, 182 insertions(+), 82 deletions(-) diff --git a/mindsdb/api/executor/planner/plan_join.py b/mindsdb/api/executor/planner/plan_join.py index 603528ac6f8..a7eb26800ef 100644 --- a/mindsdb/api/executor/planner/plan_join.py +++ b/mindsdb/api/executor/planner/plan_join.py @@ -358,6 +358,7 @@ def _check_identifiers(node, is_table, **kwargs): else: self.has_ambiguous_columns = True + query.cte = None # already used before query_traversal(query, _check_identifiers) self.check_query_conditions(query) @@ -371,6 +372,8 @@ def _check_identifiers(node, is_table, **kwargs): # create plan # TODO add optimization: one integration without predictor + planned_steps_before_join = len(self.planner.plan.steps) + self.step_stack = [] for item in join_sequence: if isinstance(item, TableInfo): @@ -400,20 +403,25 @@ def _check_identifiers(node, is_table, **kwargs): query_in.where = query.where if self.query_context["optimize_inner_join"]: - self.planner.plan.steps = self.optimize_inner_join(self.planner.plan.steps) + self.planner.plan.steps = self.optimize_inner_join(self.planner.plan.steps, planned_steps_before_join) self.close_partition() return self.planner.plan.steps[-1] - def optimize_inner_join(self, steps_in): + def optimize_inner_join(self, steps_in, min_step_num): steps_out = [] partition_step = None partition_used = False - for step in steps_in: + for i, step in enumerate(steps_in): if partition_step is None: - if isinstance(step, FetchDataframeStep) and not partition_used and step.query.limit is not None: + if ( + i >= min_step_num + and isinstance(step, FetchDataframeStep) + and not partition_used + and step.query.limit is not None + ): limit = step.query.limit.value step.query.limit = None partition_used = True diff --git a/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py b/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py index 30de48b9442..77f53fc8bd3 100644 --- a/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +++ b/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py @@ -1,3 +1,4 @@ +import time import copy import pandas as pd from typing import List @@ -105,6 +106,7 @@ def repeat_till_reach_limit(self, step, limit): query, context_callback = query_context_controller.handle_db_context_vars(query, dn, self.session) try_num = 1 + started_at = time.time() while True: self.substeps = copy.deepcopy(step.steps) query2 = copy.deepcopy(query) @@ -126,7 +128,8 @@ def repeat_till_reach_limit(self, step, limit): result = result[:limit] break - if try_num > 3: + # break if process is too long or to many tries + if try_num > 3 or time.time() - started_at > 5: # the last try without the limit first_table_limit = None continue diff --git a/mindsdb/api/executor/sql_query/steps/subselect_step.py b/mindsdb/api/executor/sql_query/steps/subselect_step.py index ac07cd47d3f..40e3dfbd2f2 100644 --- a/mindsdb/api/executor/sql_query/steps/subselect_step.py +++ b/mindsdb/api/executor/sql_query/steps/subselect_step.py @@ -184,6 +184,8 @@ def check_fields(node, is_target=None, **kwargs): "version for the right syntax to use near '$$' at line 1" ) + key, column_quoted = (), False + match node.parts, node.is_quoted: case [column_name], [column_quoted]: if column_name in aliases: diff --git a/mindsdb/interfaces/database/projects.py b/mindsdb/interfaces/database/projects.py index d51811d3f04..bdf84bb2c10 100644 --- a/mindsdb/interfaces/database/projects.py +++ b/mindsdb/interfaces/database/projects.py @@ -8,7 +8,7 @@ import numpy as np from mindsdb_sql_parser.ast.base import ASTNode -from mindsdb_sql_parser.ast import Select, Star, Constant, Identifier, BinaryOperation +from mindsdb_sql_parser.ast import Select, Star, Constant, Identifier, BinaryOperation, Join from mindsdb_sql_parser import parse_sql from mindsdb.interfaces.storage import db @@ -185,29 +185,35 @@ def get_conditions_to_move(node): # column is not in black list AND (query has star(*) OR column in white list) has_star = False - white_list, black_list = [], [] + white_list, black_list = {}, [] for target in view_query.targets: if isinstance(target, Star): has_star = True if isinstance(target, Identifier): name = target.parts[-1].lower() if target.alias is None or target.alias.parts[-1].lower() == name: - white_list.append(name) + white_list[name] = target elif target.alias is not None: black_list.append(target.alias.parts[-1].lower()) + is_join = isinstance(view_query.from_table, Join) view_where = view_query.where for condition in conditions: arg1, arg2 = condition.args if isinstance(arg1, Identifier): name = arg1.parts[-1].lower() - if name in black_list or not (has_star or name in white_list): + # don't move condition for join with Star + if name in black_list or not (has_star and not is_join): continue + elif name in white_list: + arg1 = white_list[name] if isinstance(arg2, Identifier): name = arg2.parts[-1].lower() - if name in black_list or not (has_star or name in white_list): + if name in black_list or not (has_star and not is_join): continue + elif name in white_list: + arg2 = white_list[name] # condition can be moved into view condition2 = BinaryOperation(condition.op, [arg1, arg2]) @@ -224,7 +230,13 @@ def get_conditions_to_move(node): # combine outer query with view's query view_query.parentheses = True + + # keep alias (column of the query might relate to it) + alias = query.from_table.alias if query.from_table.alias is not None else query.from_table + view_query.alias = Identifier(parts=[alias.parts[-1]]) + query.from_table = view_query + return query def query_view(self, query: Select, session) -> pd.DataFrame: diff --git a/mindsdb/interfaces/query_context/last_query.py b/mindsdb/interfaces/query_context/last_query.py index 1df233d4405..7e00a08c846 100644 --- a/mindsdb/interfaces/query_context/last_query.py +++ b/mindsdb/interfaces/query_context/last_query.py @@ -3,7 +3,17 @@ from collections import defaultdict from mindsdb_sql_parser.ast import ( - Identifier, Select, BinaryOperation, Last, Constant, Star, ASTNode, NullConstant, OrderBy, Function, TypeCast + Identifier, + Select, + BinaryOperation, + Last, + Constant, + Star, + ASTNode, + NullConstant, + OrderBy, + Function, + TypeCast, ) from mindsdb.integrations.utilities.query_traversal import query_traversal @@ -34,21 +44,21 @@ def __init__(self, query: ASTNode): def _find_last_columns(self, query: ASTNode) -> Union[dict, None]: """ - This function: - - Searches LAST column in the input query - - Replaces it with constants and memorises link to these constants - - Link to constants will be used to inject values to query instead of LAST - - Provide checks: - - if it is possible to find the table for column - - if column in select target - - Generates and returns last_column variable which is dict - last_columns[table_name] = { - 'table':
, - 'column': , - 'links': [, ... ], - 'target_idx': , - 'gen_init_query': if true: to generate query to initial values for LAST - } + This function: + - Searches LAST column in the input query + - Replaces it with constants and memorises link to these constants + - Link to constants will be used to inject values to query instead of LAST + - Provide checks: + - if it is possible to find the table for column + - if column in select target + - Generates and returns last_column variable which is dict + last_columns[table_name] = { + 'table':
, + 'column': , + 'links': [, ... ], + 'target_idx': , + 'gen_init_query': if true: to generate query to initial values for LAST + } """ # index last variables in query @@ -76,7 +86,6 @@ def replace_last_in_tree(node: ASTNode, injected: Constant): return found def index_query(node, is_table, parent_query, **kwargs): - parent_query_id = id(parent_query) last = None if is_table and isinstance(node, Identifier): @@ -105,13 +114,15 @@ def index_query(node, is_table, parent_query, **kwargs): if last is not None: # memorize - conditions.append({ - 'query_id': parent_query_id, - 'condition': node, - 'last': last, - 'column': col, - 'gen_init_query': gen_init_query # generate query to fetch initial last values from table - }) + conditions.append( + { + "query_id": parent_query_id, + "condition": node, + "last": last, + "column": col, + "gen_init_query": gen_init_query, # generate query to fetch initial last values from table + } + ) # find lasts query_traversal(query, index_query) @@ -122,7 +133,7 @@ def index_query(node, is_table, parent_query, **kwargs): self.query_orig = copy.deepcopy(query) for info in conditions: - self.last_idx[info['query_id']].append(info) + self.last_idx[info["query_id"]].append(info) # index query targets query_id = id(query) @@ -152,21 +163,20 @@ def index_query(node, is_table, parent_query, **kwargs): last_columns = {} for parent_query_id, items in self.last_idx.items(): for info in items: - col = info['column'] - last = info['last'] + col = info["column"] + last = info["last"] tables = tables_idx[parent_query_id] uniq_tables = len(set([id(v) for v in tables.values()])) if len(col.parts) > 1: - table = tables.get(col.parts[-2]) if table is None: - raise ValueError('cant find table') + raise ValueError("cant find table") elif uniq_tables == 1: table = list(tables.values())[0] else: # or just skip it? - raise ValueError('cant find table') + raise ValueError("cant find table") col_name = col.parts[-1] @@ -179,29 +189,46 @@ def index_query(node, is_table, parent_query, **kwargs): # will try to get by name ... else: - raise ValueError('Last value should be in query target') + raise ValueError("Last value should be in query target") last_columns[table_name] = { - 'table': table, - 'column': col_name, - 'links': [last], - 'target_idx': target_idx, - 'gen_init_query': info['gen_init_query'] + "table": table, + "column": col_name, + "links": [last], + "target_idx": target_idx, + "gen_init_query": info["gen_init_query"], } - elif last_columns[table_name]['column'] == col_name: - last_columns[table_name]['column'].append(last) + elif last_columns[table_name]["column"] == col_name: + last_columns[table_name]["column"].append(last) else: - raise ValueError('possible to use only one column') + raise ValueError("possible to use only one column") return last_columns def to_string(self) -> str: """ - String representation of the query - Used to identify query in query_context table + String representation of the query + Used to identify query in query_context table """ - return self.query_orig.to_string() + query = self.query_orig + if isinstance(query.from_table, Select) and query.targets == [Star()]: + # simplify nested query + if ( + query.group_by is None + and query.order_by is None + and query.having is None + and query.distinct is False + and query.where is None + and query.limit is None + and query.offset is None + and query.cte is None + ): + query = copy.deepcopy(query.from_table) + query.parentheses = False + query.alias = None + + return query.to_string() def get_last_columns(self) -> List[dict]: """ @@ -210,11 +237,11 @@ def get_last_columns(self) -> List[dict]: """ return [ { - 'table': info['table'], - 'table_name': table_name, - 'column_name': info['column'], - 'target_idx': info['target_idx'], - 'gen_init_query': info['gen_init_query'], + "table": info["table"], + "table_name": table_name, + "column_name": info["column"], + "target_idx": info["target_idx"], + "gen_init_query": info["gen_init_query"], } for table_name, info in self.last_tables.items() ] @@ -224,8 +251,8 @@ def apply_values(self, values: dict) -> ASTNode: Fills query with new values and return it """ for table_name, info in self.last_tables.items(): - value = values.get(table_name, {}).get(info['column']) - for last in info['links']: + value = values.get(table_name, {}).get(info["column"]) + for last in info["links"]: last.value = value return self.query @@ -239,9 +266,9 @@ def get_init_queries(self): # replace values for items in self.last_idx.values(): for info in items: - node = info['condition'] + node = info["condition"] back_up_values.append([node.op, node.args[1]]) - node.op = 'is not' + node.op = "is not" node.args[1] = NullConstant() query2 = copy.deepcopy(self.query) @@ -249,18 +276,16 @@ def get_init_queries(self): # return values for items in self.last_idx.values(): for info in items: - node = info['condition'] + node = info["condition"] op, arg1 = back_up_values.pop(0) node.op = op node.args[1] = arg1 for info in self.get_last_columns(): - if not info['gen_init_query']: + if not info["gen_init_query"]: continue - col = Identifier(info['column_name']) + col = Identifier(info["column_name"]) query2.targets = [col] - query2.order_by = [ - OrderBy(col, direction='DESC') - ] + query2.order_by = [OrderBy(col, direction="DESC")] query2.limit = Constant(1) yield query2, info diff --git a/tests/unit/executor/test_base_queires.py b/tests/unit/executor/test_base_queires.py index 5fbece5c4d3..468312d40a5 100644 --- a/tests/unit/executor/test_base_queires.py +++ b/tests/unit/executor/test_base_queires.py @@ -899,6 +899,27 @@ def test_subselect_1row_aggregate(self, data_handler): assert len(ret) == 1 assert ret["result"][0] == 1 + @patch("mindsdb.integrations.handlers.postgres_handler.Handler") + def test_cte_join(self, data_handler): + self.set_handler(data_handler, name="pg", tables={"stores": get_stores_df()}) + self.save_file("regions", get_regions_df()) + + ret = self.run_sql(""" + WITH regions AS ( + SELECT DISTINCT id, name FROM files.regions + ), + stores AS ( + SELECT * FROM pg.stores + LIMIT 10 + ) + SELECT format, region_id FROM pg.stores s + JOIN regions r on r.id = s.region_id + WHERE s.format IN (SELECT format FROM stores WHERE format='a') + LIMIT 100; + """) + assert len(ret) > 1 + assert ret["format"][0] == "a" + class TestSet(BaseExecutorTest): @pytest.mark.parametrize("var", ["var", "@@var", "@@session.var", "session var"]) diff --git a/tests/unit/planner/test_join_tables.py b/tests/unit/planner/test_join_tables.py index 24cef73b8fa..7bd8a463d7a 100644 --- a/tests/unit/planner/test_join_tables.py +++ b/tests/unit/planner/test_join_tables.py @@ -11,6 +11,7 @@ Star, BinaryOperation, Function, + Parameter, ) from mindsdb_sql_parser.utils import JoinType @@ -319,43 +320,71 @@ def test_join_tables_plan_limit_offset(self): def test_join_tables_plan_order_by(self): query = parse_sql(""" + WITH tab2 AS ( + SELECT * FROM int2.tab2 limit 100 + ), + categories as ( + SELECT * FROM int3.cats + ) SELECT tab1.column1, tab2.column1, tab2.column2 - FROM int.tab1 INNER - JOIN int2.tab2 ON tab1.column1 > tab2.column1 + FROM int.tab1 tab1 + INNER JOIN tab2 ON tab1.column1 > tab2.column1 + WHERE tab2.category_id = (SELECT id FROM categories WHERE name='book') ORDER BY tab1.column1 LIMIT 10 """) subquery = copy.deepcopy(query) + subquery.cte = None subquery.from_table = None subquery.offset = None + subquery.where.args[1] = Parameter(Result(2)) - plan = plan_query(query, integrations=["int", "int2"]) + plan = plan_query(query, integrations=["int", "int2", "int3"], default_namespace="mindsdb") expected_plan = QueryPlan( integrations=["int"], steps=[ - FetchDataframeStepPartition( + FetchDataframeStep( step_num=0, + integration="int2", + query=parse_sql("select * from tab2 limit 100"), + ), + FetchDataframeStep( + step_num=1, + integration="int3", + query=parse_sql("select * from cats"), + ), + SubSelectStep( + step_num=2, + query=Select( + targets=[Identifier("id")], + where=BinaryOperation(op="=", args=[Identifier("name"), Constant("book")]), + ), + dataframe=Result(1), + table_name="categories", + ), + FetchDataframeStepPartition( + step_num=3, integration="int", - query=parse_sql("select column1 AS column1 from tab1 order by column1"), + query=parse_sql("select column1 AS column1 from tab1 AS tab1 order by column1"), condition={"limit": 10}, steps=[ - FetchDataframeStep( - step_num=1, - integration="int2", + SubSelectStep( + step_num=4, + dataframe=Result(0), query=Select( targets=[ - Identifier("column1", alias=Identifier("column1")), - Identifier("column2", alias=Identifier("column2")), + Star(), ], # Column pruning - from_table=Identifier("tab2"), + where=BinaryOperation(op="=", args=[Identifier("category_id"), Parameter(Result(2))]), ), + table_name="tab2", ), JoinStep( - step_num=2, - left=Result(0), - right=Result(1), + step_num=5, + left=Result(3), + right=Result(4), query=Join( left=Identifier("tab1"), right=Identifier("tab2"), @@ -367,7 +396,7 @@ def test_join_tables_plan_order_by(self): ), ], ), - QueryStep(subquery, from_table=Result(0), strict_where=False), + QueryStep(subquery, from_table=Result(3), strict_where=False), ], ) From af08622e9f92beac03954273432250bef2c647ed Mon Sep 17 00:00:00 2001 From: Lucas Koontz Date: Wed, 11 Mar 2026 21:19:28 -0700 Subject: [PATCH 039/125] Clean up blank lines in mindsdb.Dockerfile Removed unnecessary blank lines in Dockerfile. --- docker/mindsdb.Dockerfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/docker/mindsdb.Dockerfile b/docker/mindsdb.Dockerfile index 2712c854daf..84366dacfb0 100644 --- a/docker/mindsdb.Dockerfile +++ b/docker/mindsdb.Dockerfile @@ -18,8 +18,6 @@ COPY mindsdb/__about__.py mindsdb/ # Which will mean the next stage can be cached, even if the cache for the above stage was invalidated. - - # Use the stage from above to install our deps with as much caching as possible FROM python:3.10 AS build WORKDIR /mindsdb From 919cec464cc87e866ffacb968c7552856c2345c4 Mon Sep 17 00:00:00 2001 From: "Vignesh S.M" <90998381+vigbav36@users.noreply.github.com> Date: Thu, 12 Mar 2026 14:30:03 +0530 Subject: [PATCH 040/125] Support additional gitlab configurations (#11741) --- .../handlers/gitlab_handler/gitlab_handler.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mindsdb/integrations/handlers/gitlab_handler/gitlab_handler.py b/mindsdb/integrations/handlers/gitlab_handler/gitlab_handler.py index cae572732d8..1a8bb0b2d6f 100644 --- a/mindsdb/integrations/handlers/gitlab_handler/gitlab_handler.py +++ b/mindsdb/integrations/handlers/gitlab_handler/gitlab_handler.py @@ -43,11 +43,14 @@ def connect(self) -> StatusResponse: connection_kwargs = {} - if self.connection_data.get("url", None): - connection_kwargs["url"] = self.connection_data["url"] - - if self.connection_data.get("api_key", None): - connection_kwargs["private_token"] = self.connection_data["api_key"] + connection_params = ["url", "api_key", "http_username", "http_password"] + + for connection_param in connection_params: + if connection_param in self.connection_data.keys(): + if connection_param == "api_key": + connection_kwargs["private_token"] = self.connection_data["api_key"] + else: + connection_kwargs[connection_param] = self.connection_data.get(connection_param, None) self.connection = gitlab.Gitlab(**connection_kwargs) self.is_connected = True From 27c6ead5a2f6c94009019782e584d84067b77165 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Thu, 12 Mar 2026 12:04:08 +0300 Subject: [PATCH 041/125] Update gitlab handler readme (#12285) --- mindsdb/integrations/handlers/gitlab_handler/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mindsdb/integrations/handlers/gitlab_handler/README.md b/mindsdb/integrations/handlers/gitlab_handler/README.md index d8ab4df33d3..4a960ee2cff 100644 --- a/mindsdb/integrations/handlers/gitlab_handler/README.md +++ b/mindsdb/integrations/handlers/gitlab_handler/README.md @@ -14,6 +14,8 @@ The GitLab handler is initialized with the following parameters: - `repository`: a required name of a GitLab repository to connect to - `api_key`: an optional GitLab API key to use for authentication - `url`: an optional GitLab server URL (defaults to https://gitlab.com) +- `http_username`: an optional username for HTTP authentication +- `http_password`: an optional password for HTTP authentication ## Implemented Features From c859bf3d1b296296a8f5c9809f9fb53774fc7f75 Mon Sep 17 00:00:00 2001 From: Hamoon Mohammadian Pour Date: Thu, 12 Mar 2026 15:59:38 +0330 Subject: [PATCH 042/125] Add ClickHouse data catalog (#11858) --- .../clickhouse_handler/clickhouse_handler.py | 318 +++++++++++++++++- .../handlers/hubspot_handler/README.md | 5 + 2 files changed, 321 insertions(+), 2 deletions(-) diff --git a/mindsdb/integrations/handlers/clickhouse_handler/clickhouse_handler.py b/mindsdb/integrations/handlers/clickhouse_handler/clickhouse_handler.py index feda48c1323..28836020e73 100644 --- a/mindsdb/integrations/handlers/clickhouse_handler/clickhouse_handler.py +++ b/mindsdb/integrations/handlers/clickhouse_handler/clickhouse_handler.py @@ -1,4 +1,5 @@ from urllib.parse import quote, urlencode +from typing import Optional, List import pandas as pd from sqlalchemy import create_engine @@ -8,7 +9,7 @@ from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender from mindsdb.utilities import log -from mindsdb.integrations.libs.base import DatabaseHandler +from mindsdb.integrations.libs.base import MetaDatabaseHandler from mindsdb.integrations.libs.response import ( HandlerStatusResponse as StatusResponse, HandlerResponse as Response, @@ -18,7 +19,7 @@ logger = log.getLogger(__name__) -class ClickHouseHandler(DatabaseHandler): +class ClickHouseHandler(MetaDatabaseHandler): """ This handler handles connection and execution of the ClickHouse statements. """ @@ -32,6 +33,7 @@ def __init__(self, name, connection_data, **kwargs): self.renderer = SqlalchemyRender(ClickHouseDialect) self.is_connected = False self.protocol = connection_data.get("protocol", "native") + self._has_is_nullable_column = None # Cache for version check def __del__(self): if self.is_connected is True: @@ -165,3 +167,315 @@ def get_columns(self, table_name) -> Response: q = f"DESCRIBE {table_name}" result = self.native_query(q) return result + + def _check_has_is_nullable_column(self) -> bool: + """ + Checks if the is_nullable column exists in system.columns table. + This column was added in ClickHouse 23.x. + + Returns: + bool: True if is_nullable column exists, False otherwise. + """ + if self._has_is_nullable_column is not None: + return self._has_is_nullable_column + + try: + check_query = """ + SELECT name + FROM system.columns + WHERE database = 'system' + AND table = 'columns' + AND name = 'is_nullable' + """ + result = self.native_query(check_query) + self._has_is_nullable_column = result.resp_type == RESPONSE_TYPE.TABLE and not result.data_frame.empty + except Exception as e: + logger.warning(f"Could not check for is_nullable column: {e}") + self._has_is_nullable_column = False + + return self._has_is_nullable_column + + def meta_get_tables(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves metadata information about the tables in the ClickHouse database + to be stored in the data catalog. + + Args: + table_names (list): A list of table names for which to retrieve metadata information. + + Returns: + Response: A response object containing the metadata information. + """ + database = self.connection_data["database"] + + query = f""" + SELECT + name as table_name, + database as table_schema, + engine as table_type, + comment as table_description, + total_rows as row_count + FROM system.tables + WHERE database = '{database}' + """ + + if table_names is not None and len(table_names) > 0: + quoted_names = [f"'{t}'" for t in table_names] + query += f" AND name IN ({','.join(quoted_names)})" + + query += " ORDER BY name" + + result = self.native_query(query) + return result + + def meta_get_columns(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves column metadata for the specified tables (or all tables if no list is provided). + This includes column comments that you can set in ClickHouse using: + ALTER TABLE table_name MODIFY COLUMN column_name Type COMMENT 'description' + + Args: + table_names (list): A list of table names for which to retrieve column metadata. + + Returns: + Response: A response object containing the column metadata. + """ + database = self.connection_data["database"] + + # Check if is_nullable column is available (ClickHouse 23.x+) + has_is_nullable = self._check_has_is_nullable_column() + + # Build the SELECT clause based on available columns + select_clause = """ + table as table_name, + name as column_name, + type as data_type, + comment as column_description, + default_expression as column_default""" + + if has_is_nullable: + select_clause += """, + is_nullable as is_nullable""" + + query = f""" + SELECT {select_clause} + FROM system.columns + WHERE database = '{database}' + """ + + if table_names is not None and len(table_names) > 0: + quoted_names = [f"'{t}'" for t in table_names] + query += f" AND table IN ({','.join(quoted_names)})" + + query += " ORDER BY table, position" + + result = self.native_query(query) + return result + + def meta_get_column_statistics(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves column statistics for the specified tables (or all tables if no list is provided). + Uses the base class implementation which calls meta_get_column_statistics_for_table for each table. + + Args: + table_names (list): A list of table names for which to retrieve column statistics. + + Returns: + Response: A response object containing the column statistics. + """ + # Use the base class implementation that calls meta_get_column_statistics_for_table + return super().meta_get_column_statistics(table_names) + + def meta_get_column_statistics_for_table( + self, table_name: str, column_names: Optional[List[str]] = None + ) -> Response: + """ + Retrieves column statistics for a specific table. + + Args: + table_name (str): The name of the table. + column_names (Optional[List[str]]): List of column names to retrieve statistics for. + If None, statistics for all columns will be returned. + Returns: + Response: A response object containing the column statistics for the table. + """ + database = self.connection_data["database"] + + # Get the list of columns for this table + columns_query = f""" + SELECT name, type + FROM system.columns + WHERE database = '{database}' AND table = '{table_name}' + """ + + if column_names: + quoted_names = [f"'{c}'" for c in column_names] + columns_query += f" AND name IN ({','.join(quoted_names)})" + + try: + columns_result = self.native_query(columns_query) + + if columns_result.resp_type == RESPONSE_TYPE.ERROR or columns_result.data_frame.empty: + logger.warning(f"No columns found for table {table_name}") + return Response(RESPONSE_TYPE.TABLE, pd.DataFrame()) + + # Build statistics query - collect all stats in one query + select_parts = [] + for _, row in columns_result.data_frame.iterrows(): + col = row["name"] + # Use backticks to handle special characters in column names + select_parts.extend( + [ + f"countIf(`{col}` IS NULL) AS nulls_{col}", + f"uniq(`{col}`) AS distincts_{col}", + f"toString(min(`{col}`)) AS min_{col}", + f"toString(max(`{col}`)) AS max_{col}", + ] + ) + + if not select_parts: + return Response(RESPONSE_TYPE.TABLE, pd.DataFrame()) + + # Build the query to get stats for all columns at once + stats_query = f""" + SELECT + count(*) AS total_rows, + {", ".join(select_parts)} + FROM `{database}`.`{table_name}` + """ + + stats_result = self.native_query(stats_query) + + if stats_result.resp_type != RESPONSE_TYPE.TABLE or stats_result.data_frame.empty: + logger.warning(f"Could not retrieve stats for table {table_name}") + # Return placeholder stats + placeholder_data = [] + for _, row in columns_result.data_frame.iterrows(): + placeholder_data.append( + { + "table_name": table_name, + "column_name": row["name"], + "null_percentage": None, + "distinct_values_count": None, + "most_common_values": None, + "most_common_frequencies": None, + "minimum_value": None, + "maximum_value": None, + } + ) + return Response(RESPONSE_TYPE.TABLE, pd.DataFrame(placeholder_data)) + + # Parse the stats result + stats_data = stats_result.data_frame.iloc[0] + total_rows = stats_data.get("total_rows", 0) + + # Build the final statistics DataFrame + all_stats = [] + for _, row in columns_result.data_frame.iterrows(): + col = row["name"] + nulls = stats_data.get(f"nulls_{col}", 0) + distincts = stats_data.get(f"distincts_{col}", None) + min_val = stats_data.get(f"min_{col}", None) + max_val = stats_data.get(f"max_{col}", None) + + # Calculate null percentage + null_pct = None + if total_rows is not None and total_rows > 0: + null_pct = round((nulls / total_rows) * 100, 2) + + all_stats.append( + { + "table_name": table_name, + "column_name": col, + "null_percentage": null_pct, + "distinct_values_count": distincts, + "most_common_values": None, + "most_common_frequencies": None, + "minimum_value": min_val, + "maximum_value": max_val, + } + ) + + return Response(RESPONSE_TYPE.TABLE, pd.DataFrame(all_stats)) + + except Exception as e: + logger.error(f"Exception while fetching statistics for table {table_name}: {e}") + # Return empty stats on error + return Response( + RESPONSE_TYPE.ERROR, error_message=f"Could not retrieve statistics for table {table_name}: {str(e)}" + ) + + def meta_get_primary_keys(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves primary key information for the specified tables (or all tables if no list is provided). + + Args: + table_names (list): A list of table names for which to retrieve primary key information. + + Returns: + Response: A response object containing the primary key information. + """ + database = self.connection_data["database"] + + query = f""" + SELECT + table as table_name, + name as column_name, + position as ordinal_position, + 'PRIMARY' as constraint_name + FROM system.columns + WHERE database = '{database}' + AND is_in_primary_key = 1 + """ + + if table_names is not None and len(table_names) > 0: + quoted_names = [f"'{t}'" for t in table_names] + query += f" AND table IN ({','.join(quoted_names)})" + + query += " ORDER BY table, position" + + result = self.native_query(query) + return result + + def meta_get_foreign_keys(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves foreign key information for the specified tables (or all tables if no list is provided). + Note: ClickHouse does not enforce foreign key constraints, but this method is provided for completeness. + + Args: + table_names (list): A list of table names for which to retrieve foreign key information. + + Returns: + Response: A response object containing an empty DataFrame (ClickHouse doesn't support foreign keys). + """ + # ClickHouse does not support foreign key constraints + # Return an empty DataFrame with the expected columns + df = pd.DataFrame( + columns=[ + "parent_table_name", + "parent_column_name", + "child_table_name", + "child_column_name", + "constraint_name", + ] + ) + return Response(RESPONSE_TYPE.TABLE, df) + + def meta_get_handler_info(self, **kwargs) -> str: + """ + Retrieves information about the ClickHouse handler design and implementation. + + Returns: + str: A string containing information about the ClickHouse handler's capabilities. + """ + return ( + "ClickHouse is a fast open-source column-oriented database management system.\n" + "Key features:\n" + "- Supports standard SQL syntax with some extensions\n" + "- Use backticks (`) to quote table and column names with special characters\n" + "- Does NOT support traditional foreign key constraints (they are not enforced)\n" + "- Optimized for analytical queries (OLAP) rather than transactional operations (OLTP)\n" + "- Supports various table engines (MergeTree, ReplacingMergeTree, SummingMergeTree, etc.)\n" + "- All ClickHouse functions are case-sensitive\n" + "- Native support for arrays, nested structures, and approximate algorithms\n" + ) diff --git a/mindsdb/integrations/handlers/hubspot_handler/README.md b/mindsdb/integrations/handlers/hubspot_handler/README.md index 032024df64e..2bd51529968 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/README.md +++ b/mindsdb/integrations/handlers/hubspot_handler/README.md @@ -111,6 +111,11 @@ Association tables are read-only and support `SELECT` only. They expose relation The handler provides `SHOW TABLES` and `information_schema.columns` support for all tables. Column statistics are sampled for core CRM and engagement tables. +**Important Notes on Field Values:** +- **Industry codes**: HubSpot uses predefined industry values (e.g., `COMPUTER_SOFTWARE`, `BIOTECHNOLOGY`, `FINANCIAL_SERVICES`). See [HubSpot's industry list](https://knowledge.hubspot.com/properties/hubspots-default-company-properties#industry) for all valid options. +- **Deal stages**: Each HubSpot account has custom pipeline stages. Use the stage IDs from your account (e.g., `presentationscheduled`, `closedwon`, `closedlost`, or numeric IDs like `110382973`). +- **Email validation**: Contact email addresses must be valid email formats (e.g., `user@example.com`). + ## Example Usage ### Basic Connection From 67d86983bcaf38d0b86a1fc395f4070a6772d2a1 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Thu, 12 Mar 2026 15:36:28 +0300 Subject: [PATCH 043/125] Fix HubSpot handler readme (#12287) --- mindsdb/integrations/handlers/hubspot_handler/README.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/README.md b/mindsdb/integrations/handlers/hubspot_handler/README.md index 2bd51529968..032024df64e 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/README.md +++ b/mindsdb/integrations/handlers/hubspot_handler/README.md @@ -111,11 +111,6 @@ Association tables are read-only and support `SELECT` only. They expose relation The handler provides `SHOW TABLES` and `information_schema.columns` support for all tables. Column statistics are sampled for core CRM and engagement tables. -**Important Notes on Field Values:** -- **Industry codes**: HubSpot uses predefined industry values (e.g., `COMPUTER_SOFTWARE`, `BIOTECHNOLOGY`, `FINANCIAL_SERVICES`). See [HubSpot's industry list](https://knowledge.hubspot.com/properties/hubspots-default-company-properties#industry) for all valid options. -- **Deal stages**: Each HubSpot account has custom pipeline stages. Use the stage IDs from your account (e.g., `presentationscheduled`, `closedwon`, `closedlost`, or numeric IDs like `110382973`). -- **Email validation**: Contact email addresses must be valid email formats (e.g., `user@example.com`). - ## Example Usage ### Basic Connection From 67ed223c3279a0d72ae007759d2ea89d0a428c1b Mon Sep 17 00:00:00 2001 From: Lukas Wolfsteiner Date: Thu, 12 Mar 2026 16:43:38 +0100 Subject: [PATCH 044/125] feat: Add Raindrop.io integration handler w/ full CRUD support & complete test coverage (#11300) --- .../handlers/raindrop_handler/README.md | 651 ++++++ .../handlers/raindrop_handler/__about__.py | 9 + .../handlers/raindrop_handler/__init__.py | 30 + .../raindrop_handler/connection_args.py | 16 + .../handlers/raindrop_handler/icon.svg | 1 + .../raindrop_handler/raindrop_handler.py | 379 +++ .../raindrop_handler/raindrop_tables.py | 1784 ++++++++++++++ .../raindrop_handler/requirements.txt | 0 .../raindrop_handler/tests/__init__.py | 1 + .../tests/test_raindrop_handler.py | 2068 +++++++++++++++++ .../tests/test_raindrop_integration.py | 178 ++ .../raindrop_handler/verify_implementation.py | 182 ++ 12 files changed, 5299 insertions(+) create mode 100644 mindsdb/integrations/handlers/raindrop_handler/README.md create mode 100644 mindsdb/integrations/handlers/raindrop_handler/__about__.py create mode 100644 mindsdb/integrations/handlers/raindrop_handler/__init__.py create mode 100644 mindsdb/integrations/handlers/raindrop_handler/connection_args.py create mode 100644 mindsdb/integrations/handlers/raindrop_handler/icon.svg create mode 100644 mindsdb/integrations/handlers/raindrop_handler/raindrop_handler.py create mode 100644 mindsdb/integrations/handlers/raindrop_handler/raindrop_tables.py create mode 100644 mindsdb/integrations/handlers/raindrop_handler/requirements.txt create mode 100644 mindsdb/integrations/handlers/raindrop_handler/tests/__init__.py create mode 100644 mindsdb/integrations/handlers/raindrop_handler/tests/test_raindrop_handler.py create mode 100644 mindsdb/integrations/handlers/raindrop_handler/tests/test_raindrop_integration.py create mode 100644 mindsdb/integrations/handlers/raindrop_handler/verify_implementation.py diff --git a/mindsdb/integrations/handlers/raindrop_handler/README.md b/mindsdb/integrations/handlers/raindrop_handler/README.md new file mode 100644 index 00000000000..835d24fd0d7 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/README.md @@ -0,0 +1,651 @@ +# Raindrop.io Handler + +Raindrop.io handler for MindsDB provides interfaces to connect to Raindrop.io via APIs and pull data into MindsDB. This handler also supports creating, updating, and deleting bookmarks and collections. + +--- + +## Table of Contents + +- [Raindrop.io Handler](#raindropio-handler) + - [Table of Contents](#table-of-contents) + - [About Raindrop.io](#about-raindropio) + - [Raindrop.io Handler Implementation](#raindropio-handler-implementation) + - [Raindrop.io Handler Initialization](#raindropio-handler-initialization) + - [How to Get Your Raindrop.io API Key](#how-to-get-your-raindropio-api-key) + - [Implemented Features](#implemented-features) + - [Tables](#tables) + - [Raindrops (Bookmarks)](#raindrops-bookmarks) + - [Collections](#collections) + - [Example Usage](#example-usage) + - [Connecting to Raindrop.io](#connecting-to-raindropio) + - [Selecting Bookmarks](#selecting-bookmarks) + - [Creating Bookmarks](#creating-bookmarks) + - [Updating Bookmarks](#updating-bookmarks) + - [Deleting Bookmarks](#deleting-bookmarks) + - [Working with Collections](#working-with-collections) + +--- + +## About Raindrop.io + +Raindrop.io is a bookmarking service that allows users to organize, save, and manage web bookmarks. It provides a clean interface for saving links, organizing them into collections, adding tags, notes, and highlights. Raindrop.io offers both personal and collaborative features for managing bookmarks across teams. + +Website: https://raindrop.io + +## Raindrop.io Handler Implementation + +This handler was implemented using the [Raindrop.io REST API v1](https://developer.raindrop.io/). The handler provides comprehensive support for managing bookmarks (called "raindrops" in the API) and collections through SQL-like operations. + +## Raindrop.io Handler Initialization + +The Raindrop.io handler is initialized with the following parameter: + +- `api_key`: a required Raindrop.io API access token + +## How to Get Your Raindrop.io API Key + +1. Sign up for an account on [Raindrop.io](https://raindrop.io) +2. Go to [App Management Console](https://app.raindrop.io/settings/integrations) +3. Create a new application or use an existing one +4. For testing purposes, you can copy the "Test token" from your application settings +5. For production use, implement OAuth2 flow as described in the [API documentation](https://developer.raindrop.io/v1/authentication/token) + +## Implemented Features + +### Raindrops (Bookmarks) Table +- [x] Support SELECT with advanced filtering and pagination + - [x] Filter by collection_id, search terms, title + - [x] Support for sorting by created, lastUpdate, sort, title + - [x] Support for LIMIT and pagination + - [x] Support for specific bookmark IDs + - [x] **NEW**: Advanced WHERE clause operators (>, <, >=, <=, BETWEEN, IN, LIKE) + - [x] **NEW**: Date range filtering with automatic datetime conversion + - [x] **NEW**: Complex condition combinations with AND/OR logic + - [x] **NEW**: Local filtering for non-API supported conditions +- [x] Support INSERT for creating new bookmarks + - [x] Single bookmark creation + - [x] Bulk bookmark creation +- [x] Support UPDATE for modifying existing bookmarks + - [x] Single bookmark updates + - [x] Bulk bookmark updates with WHERE conditions +- [x] Support DELETE for removing bookmarks + - [x] Single bookmark deletion + - [x] Bulk bookmark deletion with WHERE conditions + +### Collections Table +- [x] Support SELECT with filtering and pagination + - [x] Support for all collection fields + - [x] Support for LIMIT and ORDER BY +- [x] Support INSERT for creating new collections +- [x] Support UPDATE for modifying existing collections +- [x] Support DELETE for removing collections + - [x] Single collection deletion + - [x] Bulk collection deletion + +### Tags Table +- [x] Support SELECT for querying tag statistics + - [x] Tag usage counts and metadata + - [x] Support for filtering and sorting + - [x] Support for LIMIT and pagination +- [ ] Support INSERT for creating tags (not supported by API) +- [ ] Support UPDATE for modifying tags (not supported by API) +- [ ] Support DELETE for removing tags (not supported by API) + +### Parse Table +- [x] Support SELECT for URL metadata extraction + - [x] Extract title, description, and other metadata from URLs + - [x] Support for single URL parsing + - [x] Support for multiple URL parsing with IN operator + - [x] Error handling for invalid URLs +- [ ] Support INSERT for creating parsed URLs (read-only operation) +- [ ] Support UPDATE for modifying parsed URLs (read-only operation) +- [ ] Support DELETE for removing parsed URLs (read-only operation) + +### Bulk Operations Table +- [x] Support UPDATE for bulk collection moves + - [x] Move bookmarks by source collection ID + - [x] Move specific bookmarks by ID + - [x] Move bookmarks by search criteria + - [x] Batch processing with error handling +- [ ] Support SELECT for bulk operation status (not queryable) +- [ ] Support INSERT for bulk operations (not applicable) +- [ ] Support DELETE for bulk operations (use raindrops table) + +## Tables + +### Tags + +The `tags` table provides access to tag management and statistics from Raindrop.io. + +Available columns: +- `_id` (str): Unique tag identifier +- `label` (str): Tag name/label +- `count` (int): Number of bookmarks using this tag +- `created` (datetime): Tag creation timestamp +- `lastUpdate` (datetime): Last update timestamp + +**Note**: Direct tag creation, updates, and deletion are not supported by the Raindrop.io API. Tags are created automatically when bookmarks are tagged, and are removed automatically when no bookmarks use them. + +#### Selecting Tags + +```sql +-- Get all tags +SELECT * FROM raindrop_db.tags; + +-- Get tags sorted by usage count +SELECT label, count FROM raindrop_db.tags +ORDER BY count DESC; + +-- Get tags with specific usage count +SELECT label, count FROM raindrop_db.tags +WHERE count > 5 +ORDER BY count DESC; + +-- Get most popular tags (top 10) +SELECT label, count FROM raindrop_db.tags +ORDER BY count DESC +LIMIT 10; + +-- Get recently created tags +SELECT label, created FROM raindrop_db.tags +WHERE created > '2024-01-01' +ORDER BY created DESC; +``` + +### Parse + +The `parse` table provides URL metadata extraction functionality using Raindrop.io's parsing service. + +Available columns: +- `parsed_url` (str): The original URL that was parsed +- `title` (str): Extracted title from the URL +- `excerpt` (str): Brief description or excerpt from the URL +- `domain` (str): Domain name of the URL +- `type` (str): Content type (article, image, video, etc.) +- `cover` (str): Cover image URL if available +- `media` (list): Media attachments found on the page +- `lastUpdate` (datetime): Last update timestamp for the parsed content +- `error` (str): Error message if parsing failed + +**Note**: The parse table is read-only and used for extracting metadata from URLs before creating bookmarks. + +#### Parsing URLs + +```sql +-- Parse a single URL to extract metadata +SELECT parsed_url, title, excerpt, domain, type, cover +FROM raindrop_db.parse +WHERE url = 'https://example.com/article'; + +-- Parse multiple URLs at once +SELECT parsed_url, title, excerpt, domain +FROM raindrop_db.parse +WHERE url IN ('https://example1.com', 'https://example2.com', 'https://example3.com'); + +-- Get detailed metadata including media +SELECT parsed_url, title, excerpt, media, lastUpdate +FROM raindrop_db.parse +WHERE url = 'https://news.example.com/article'; + +-- Parse URLs with error handling +SELECT parsed_url, title, excerpt, error +FROM raindrop_db.parse +WHERE url IN ('https://valid-url.com', 'https://invalid-url.com'); + +-- Use parsed data to create bookmarks (combined query) +INSERT INTO raindrop_db.raindrops (link, title, excerpt, collection_id) +SELECT parsed_url, title, excerpt, 123 +FROM raindrop_db.parse +WHERE url = 'https://example.com/article-to-bookmark'; +``` + +### Raindrops (Bookmarks) + +Available columns: +- `_id` (int): Unique bookmark ID +- `link` (str): The URL of the bookmark +- `title` (str): Title of the bookmark +- `excerpt` (str): Brief description or excerpt +- `note` (str): Personal notes +- `type` (str): Type of bookmark (link, article, image, video, etc.) +- `cover` (str): Cover image URL +- `tags` (str): Comma-separated tags +- `important` (bool): Whether the bookmark is marked as important +- `reminder` (datetime): Reminder date/time +- `removed` (bool): Whether the bookmark is removed/trashed +- `created` (datetime): Creation timestamp +- `lastUpdate` (datetime): Last update timestamp +- `domain` (str): Domain of the bookmarked URL +- `collection.id` (int): ID of the collection containing this bookmark +- `collection.title` (str): Title of the collection +- `user.id` (int): ID of the user who owns this bookmark +- `broken` (bool): Whether the link is broken +- `cache` (str): Whether a cached copy exists +- `file.name` (str): File name (for file bookmarks) +- `file.size` (int): File size (for file bookmarks) +- `file.type` (str): File type (for file bookmarks) + +### Collections + +Available columns: +- `_id` (int): Unique collection ID +- `title` (str): Collection title +- `description` (str): Collection description +- `color` (str): Collection color (hex code) +- `view` (str): View type (list, grid, etc.) +- `public` (bool): Whether the collection is public +- `sort` (int): Sort order +- `count` (int): Number of bookmarks in collection +- `created` (datetime): Creation timestamp +- `lastUpdate` (datetime): Last update timestamp +- `expanded` (bool): Whether the collection is expanded in UI +- `parent.id` (int): Parent collection ID (for nested collections) +- `user.id` (int): ID of the user who owns this collection +- `cover` (str): Cover image URL +- `access.level` (int): Access level +- `access.draggable` (bool): Whether the collection can be dragged + +## Example Usage + +### Connecting to Raindrop.io + +```sql +CREATE DATABASE raindrop_db +WITH ENGINE = 'raindrop', +PARAMETERS = { + "api_key": "your_raindrop_api_token_here" +}; +``` + +### Selecting Bookmarks + +```sql +-- Get all bookmarks +SELECT * FROM raindrop_db.raindrops; + +-- Get bookmarks from a specific collection +SELECT * FROM raindrop_db.raindrops +WHERE collection_id = 12345; + +-- Search for bookmarks (enhanced search capabilities) +SELECT title, link, tags FROM raindrop_db.raindrops +WHERE search = 'programming' +LIMIT 10; + +-- Advanced search with multiple field searches (automatically optimized) +SELECT * FROM raindrop_db.raindrops +WHERE title = 'Python Tutorial' AND excerpt = 'Learn Python'; + +-- Optimized LIKE patterns (automatically converted to API search) +SELECT * FROM raindrop_db.raindrops +WHERE title LIKE '%python%' OR excerpt LIKE '%tutorial%'; + +-- Get bookmarks with specific tags +SELECT * FROM raindrop_db.raindrops +WHERE title LIKE '%python%' +ORDER BY created DESC; + +-- Get important bookmarks +SELECT title, link, created FROM raindrop_db.raindrops +WHERE important = true; + +-- Advanced filtering with comparison operators +SELECT * FROM raindrop_db.raindrops +WHERE created > '2024-01-01' +ORDER BY created DESC; + +SELECT title, link FROM raindrop_db.raindrops +WHERE sort <= 50 + AND important = true; + +-- Date range filtering +SELECT * FROM raindrop_db.raindrops +WHERE created BETWEEN '2024-01-01' AND '2024-12-31'; + +-- IN operator for multiple values +SELECT * FROM raindrop_db.raindrops +WHERE _id IN (123, 456, 789); + +SELECT * FROM raindrop_db.raindrops +WHERE collection_id IN (0, 1, 2); + +-- LIKE operator for pattern matching +SELECT * FROM raindrop_db.raindrops +WHERE title LIKE '%python%' + OR excerpt LIKE '%tutorial%'; + +-- Complex conditions with multiple filters +SELECT title, link, tags, created FROM raindrop_db.raindrops +WHERE created >= '2024-06-01' + AND important = true + AND (title LIKE '%project%' OR tags LIKE '%work%') +ORDER BY created DESC +LIMIT 20; + +-- Advanced filtering with multiple AND conditions +SELECT * FROM raindrop_db.raindrops +WHERE collection_id = 123 + AND created BETWEEN '2024-01-01' AND '2024-12-31' + AND important = true + AND sort > 10; + +-- Complex queries with mixed operators +SELECT title, link, excerpt FROM raindrop_db.raindrops +WHERE (title LIKE '%tutorial%' OR excerpt LIKE '%guide%') + AND created >= '2024-06-01' + AND _id NOT IN (123, 456, 789) +ORDER BY sort DESC; + +-- Query untagged bookmarks +SELECT * FROM raindrop_db.raindrops +WHERE tags = ""; + +-- Get specific columns for untagged bookmarks +SELECT _id, title, link, excerpt FROM raindrop_db.raindrops +WHERE tags = ""; + +-- Count untagged bookmarks +SELECT COUNT(*) FROM raindrop_db.raindrops +WHERE tags = ""; + +-- Get untagged bookmarks from a specific collection +SELECT * FROM raindrop_db.raindrops +WHERE collection_id = 0 AND tags = ""; + +-- Get untagged bookmarks sorted by creation date +SELECT title, link, created FROM raindrop_db.raindrops +WHERE tags = "" +ORDER BY created DESC; + +-- Get untagged bookmarks with additional filters +SELECT * FROM raindrop_db.raindrops +WHERE tags = "" AND important = true; + +-- Get recent untagged bookmarks +SELECT title, link, created FROM raindrop_db.raindrops +WHERE tags = "" AND created > '2024-01-01' +ORDER BY created DESC; + +-- Query broken links +SELECT * FROM raindrop_db.raindrops +WHERE broken = true; + +-- Count broken links (manual counting approach) +SELECT _id FROM raindrop_db.raindrops +WHERE broken = true; +-- Note: Use application-side counting for total count + +-- Get broken links with details +SELECT _id, title, link, domain, lastUpdate FROM raindrop_db.raindrops +WHERE broken = true; + +-- Get broken links from a specific collection +SELECT * FROM raindrop_db.raindrops +WHERE collection_id = 0 AND broken = true; + +-- Get broken links sorted by last update +SELECT title, link, lastUpdate FROM raindrop_db.raindrops +WHERE broken = true +ORDER BY lastUpdate DESC; + +-- Get broken links that are also important +SELECT * FROM raindrop_db.raindrops +WHERE broken = true AND important = true; +``` + +### Creating Bookmarks + +```sql +-- Create a single bookmark +INSERT INTO raindrop_db.raindrops (link, title, note, tags, collection_id) +VALUES ( + 'https://example.com', + 'Example Website', + 'This is a great example', + 'example,website,test', + 12345 +); + +-- Create multiple bookmarks +INSERT INTO raindrop_db.raindrops (link, title, collection_id) +VALUES + ('https://github.com', 'GitHub', 0), + ('https://stackoverflow.com', 'Stack Overflow', 0); +``` + +### Updating Bookmarks + +```sql +-- Update a specific bookmark +UPDATE raindrop_db.raindrops +SET title = 'New Title', note = 'Updated note', important = true +WHERE _id = 123456; + +-- Update multiple bookmarks +UPDATE raindrop_db.raindrops +SET collection_id = 54321 +WHERE tags LIKE '%oldtag%'; + +-- Mark bookmarks as important +UPDATE raindrop_db.raindrops +SET important = true +WHERE title LIKE '%important%'; +``` + +### Deleting Bookmarks + +```sql +-- Delete a specific bookmark +DELETE FROM raindrop_db.raindrops +WHERE _id = 123456; + +-- Delete bookmarks by search criteria +DELETE FROM raindrop_db.raindrops +WHERE tags LIKE '%obsolete%'; + +-- Delete old bookmarks +DELETE FROM raindrop_db.raindrops +WHERE created < '2023-01-01'; +``` + +### Working with Collections + +```sql +-- Get all collections +SELECT * FROM raindrop_db.collections; + +-- Create a new collection +INSERT INTO raindrop_db.collections (title, description, color, view) +VALUES ('Programming', 'Programming related bookmarks', '#FF0000', 'list'); + +-- Update a collection +UPDATE raindrop_db.collections +SET title = 'Web Development', color = '#00FF00' +WHERE _id = 12345; + +-- Delete a collection +DELETE FROM raindrop_db.collections +WHERE _id = 12345; + +-- Get collections with bookmark counts +SELECT title, count, lastUpdate FROM raindrop_db.collections +ORDER BY count DESC; +``` + +### Advanced Queries + +```sql +-- Get bookmarks with collection information +SELECT r.title, r.link, r.tags, c.title as collection_name +FROM raindrop_db.raindrops r +JOIN raindrop_db.collections c ON r.collection_id = c._id +WHERE r.important = true; + +-- Get recent bookmarks from multiple collections +SELECT title, link, created, collection_id +FROM raindrop_db.raindrops +WHERE collection_id IN (123, 456, 789) +AND created > '2024-01-01' +ORDER BY created DESC +LIMIT 20; + +-- Search across title and notes +SELECT title, link, note +FROM raindrop_db.raindrops +WHERE title LIKE '%python%' OR note LIKE '%python%' +ORDER BY lastUpdate DESC; +``` + +### Bulk Operations + +```sql +-- Move all bookmarks from one collection to another +UPDATE raindrop_db.bulk_operations +SET collection_id = 456 +WHERE source_collection_id = 123; + +-- Move specific bookmarks to a collection +UPDATE raindrop_db.bulk_operations +SET collection_id = 789 +WHERE _id IN (1, 2, 3, 4, 5); + +-- Move bookmarks matching search criteria +UPDATE raindrop_db.bulk_operations +SET collection_id = 999 +WHERE search = 'python tutorial'; + +-- Combine with other operations - move then update +UPDATE raindrop_db.bulk_operations +SET collection_id = 456 +WHERE source_collection_id = 123; + +UPDATE raindrop_db.raindrops +SET important = true +WHERE collection_id = 456 AND created > '2024-01-01'; +``` + +## API Rate Limits + +The Raindrop.io API has the following rate limits: +- 120 requests per minute for authenticated users +- The handler automatically handles pagination (API returns max 50 items per request) +- Bulk operations are used when possible to minimize API calls + +### Rate Limiting Features + +The handler includes intelligent rate limiting to prevent hitting API quotas: + +- **Automatic Throttling**: Requests are automatically delayed to stay within 120 requests/minute limit +- **Smart Pagination**: Page sizes are optimized based on LIMIT clauses (e.g., LIMIT 5 uses smaller pages) +- **Request Tracking**: The handler tracks request times and adds delays when approaching limits +- **Graceful Degradation**: Continues working even with rate limit errors + +**Example**: A `SELECT * FROM raindrop_db.raindrops LIMIT 5` query will use smaller page sizes and make fewer requests compared to larger queries. + +## Error Handling + +The handler includes comprehensive error handling: +- Connection validation on initialization +- Graceful fallback from bulk operations to individual operations when needed +- Proper error logging for debugging +- Handles API rate limiting and network errors + +## Recent Improvements + +### Version 0.0.1 Improvements +- **Robust Data Normalization**: Enhanced column normalization to handle missing nested fields gracefully +- **Defensive Column Checks**: Added checks to ensure all expected columns exist before query execution +- **Empty Data Handling**: Improved handling of empty API responses with proper column structure +- **Error Resilience**: Added try-catch blocks around data processing operations to prevent crashes +- **Logging Integration**: Replaced print statements with proper logging for better integration with MindsDB +- **Rate Limiting**: Implemented intelligent rate limiting to prevent API quota exhaustion (120 requests/minute) +- **Optimized Pagination**: Smart page sizing based on LIMIT clauses to minimize API calls +- **Request Throttling**: Automatic delays between requests to stay within API limits + +### Version 0.0.2 Improvements +- **Advanced WHERE Clause Operators**: Added support for >, <, >=, <=, BETWEEN, IN, and LIKE operators +- **Date Range Filtering**: Automatic datetime conversion and comparison for date-based filtering +- **Complex Condition Combinations**: Support for multiple AND/OR conditions in WHERE clauses +- **Advanced Filtering API**: New `/filters` endpoint integration for complex queries +- **Intelligent Query Routing**: Automatic selection between standard and advanced filtering endpoints +- **Fallback Mechanisms**: Graceful degradation when advanced endpoints are unavailable +- **Enhanced Error Handling**: Comprehensive error handling for API failures and edge cases +- **Local Filtering Engine**: Intelligent routing between API-supported and locally-processed filters +- **Enhanced Query Performance**: Optimized data fetching based on filter types and complexity +- **Comprehensive Test Coverage**: 49 unit tests covering all new filtering capabilities + +### Version 0.0.3 Improvements +- **Tags Table**: New `tags` table for tag management and statistics +- **Tag Statistics**: Access to tag usage counts and metadata +- **Tag Filtering**: Support for filtering and sorting tags by usage and creation date +- **API Integration**: Full integration with Raindrop.io `/tags` endpoint +- **Read-Only Operations**: Proper handling of API limitations for tag CRUD operations +- **Enhanced Documentation**: Comprehensive examples for tag queries +- **Test Coverage**: Additional unit tests for tags table functionality + +### Version 0.0.4 Improvements +- **Parse Table**: New `parse` table for URL metadata extraction +- **URL Metadata Extraction**: Extract title, description, domain, and media from URLs +- **Batch URL Parsing**: Support for parsing multiple URLs with IN operator +- **Error Handling**: Graceful error handling for invalid or unreachable URLs +- **API Integration**: Full integration with Raindrop.io `/parse` endpoint +- **Read-Only Operations**: Parse table designed as read-only for metadata extraction +- **Enhanced Documentation**: Comprehensive examples for URL parsing queries +- **Test Coverage**: Complete unit test suite for parse table functionality + +### Version 0.0.5 Improvements +- **Bulk Operations Table**: New `bulk_operations` table for bulk collection moves +- **Bulk Collection Moves**: Move multiple bookmarks between collections efficiently +- **Flexible Move Criteria**: Support for moving by collection ID, bookmark IDs, or search terms +- **API Integration**: Full integration with Raindrop.io bulk update endpoints +- **Error Handling**: Comprehensive error handling for bulk operations +- **SQL Interface**: User-friendly SQL interface for bulk operations +- **Enhanced Documentation**: Comprehensive examples for bulk move operations +- **Test Coverage**: Complete unit test suite for bulk operations functionality + +### Version 0.0.6 Improvements +- **Enhanced Full-Text Search**: Improved search capabilities with automatic optimization +- **Multi-Field Search**: Support for searching across title, excerpt, note, and tags fields +- **Smart LIKE Optimization**: Automatic conversion of simple LIKE patterns to API search +- **Field-Specific Search**: Convert field-specific searches to optimized API queries +- **Search Query Combination**: Intelligent combination of multiple search conditions +- **Preserved User Intent**: Respect explicit search queries while optimizing others +- **Performance Optimization**: Reduced local filtering by leveraging API search capabilities +- **Backward Compatibility**: All existing search functionality remains unchanged +- **Comprehensive Test Coverage**: 10 additional tests for search optimization features + +### Version 0.0.7 Improvements +- **Full API Compatibility**: Comprehensive evaluation and fixes for official Raindrop API compatibility +- **Fixed Endpoint URLs**: Corrected `/collections/childrens` to `/collections/children` to match API spec +- **Corrected Sort Parameters**: Fixed sort parameter format from `{field},-{direction}` to `field`/`-field` +- **Parameter Name Compliance**: Ensured all parameter names match official API specification (`perpage`, etc.) +- **Enhanced Security**: Updated allowed endpoints list with correct API paths +- **Rate Limiting Validation**: Verified rate limiting implementation matches API limits (120 requests/minute) +- **Authentication Compliance**: Confirmed Bearer token authentication format matches API requirements +- **Response Format Compatibility**: Verified response structure expectations match API responses +- **Error Handling Compatibility**: Ensured error handling matches API error response formats +- **Comprehensive Compatibility Tests**: 9 additional tests covering all aspects of API compatibility + +### Version 0.0.8 Improvements +- **Fixed Collections Query Error**: Resolved 404 error when querying collections due to invalid `/collections/children` endpoint +- **Simplified Collections API**: Removed separate child collections call, using single `/collections` endpoint for all collections +- **Updated Collections Logic**: Modified `get_collections()` to return both root and nested collections from single API call +- **Enhanced Error Handling**: Improved error handling for collections queries +- **Updated Test Suite**: Fixed existing tests to work with new collections API approach +- **Collections Endpoint Tests**: Added comprehensive tests for collections endpoint fix and table integration +- **API Compatibility Enhancement**: Further improved compatibility by fixing collections endpoint issues + +### Dependency Management +- Removed duplicate `requests` dependency from handler-specific requirements.txt +- All dependencies are now properly managed through the main requirements.txt file + +## Notes + +- The `raindrops` table has an alias `bookmarks` for convenience +- All date fields are automatically converted to pandas datetime objects +- Tags are stored as comma-separated strings for easier querying +- The handler supports both single and bulk operations for better performance +- Collection ID 0 represents "All bookmarks" (unsorted) +- Collection ID -1 represents "Unsorted" bookmarks +- Collection ID -99 represents "Trash" +- The `requests` dependency is already declared in the main requirements.txt file, so it's not included in this handler's requirements.txt to avoid duplication diff --git a/mindsdb/integrations/handlers/raindrop_handler/__about__.py b/mindsdb/integrations/handlers/raindrop_handler/__about__.py new file mode 100644 index 00000000000..681a687cf8d --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/__about__.py @@ -0,0 +1,9 @@ +__title__ = "MindsDB Raindrop.io handler" +__package_name__ = "mindsdb_raindrop_handler" +__version__ = "0.0.8" +__description__ = "MindsDB handler for Raindrop.io" +__author__ = "Lukas Wolfsteiner " +__github__ = "https://github.com/mindsdb/mindsdb" +__pypi__ = "https://pypi.org/project/mindsdb/" +__license__ = "MIT" +__copyright__ = "Copyright 2025 - MindsDB" diff --git a/mindsdb/integrations/handlers/raindrop_handler/__init__.py b/mindsdb/integrations/handlers/raindrop_handler/__init__.py new file mode 100644 index 00000000000..ff1d4d2b171 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/__init__.py @@ -0,0 +1,30 @@ +from mindsdb.integrations.libs.const import HANDLER_TYPE + +from .__about__ import __version__ as version, __description__ as description +from .connection_args import connection_args, connection_args_example + +try: + from .raindrop_handler import RaindropHandler as Handler + + import_error = None +except Exception as e: + Handler = None + import_error = e + +title = "Raindrop.io" +name = "raindrop" +type = HANDLER_TYPE.DATA +icon_path = "icon.svg" + +__all__ = [ + "Handler", + "version", + "name", + "type", + "title", + "description", + "import_error", + "icon_path", + "connection_args_example", + "connection_args", +] diff --git a/mindsdb/integrations/handlers/raindrop_handler/connection_args.py b/mindsdb/integrations/handlers/raindrop_handler/connection_args.py new file mode 100644 index 00000000000..82e7fae2f64 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/connection_args.py @@ -0,0 +1,16 @@ +from collections import OrderedDict + +from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE + + +connection_args = OrderedDict( + api_key={ + "type": ARG_TYPE.PWD, + "description": "Raindrop.io API access token. You can get this from https://app.raindrop.io/settings/integrations", + "required": True, + "label": "API Key", + "secret": True, + }, +) + +connection_args_example = OrderedDict(api_key="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeee") diff --git a/mindsdb/integrations/handlers/raindrop_handler/icon.svg b/mindsdb/integrations/handlers/raindrop_handler/icon.svg new file mode 100644 index 00000000000..67ed32e1bfa --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mindsdb/integrations/handlers/raindrop_handler/raindrop_handler.py b/mindsdb/integrations/handlers/raindrop_handler/raindrop_handler.py new file mode 100644 index 00000000000..e2850a9f6b0 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/raindrop_handler.py @@ -0,0 +1,379 @@ +import requests +import time +from typing import Dict, Any, List + +from mindsdb_sql_parser import parse_sql + +from mindsdb.integrations.handlers.raindrop_handler.raindrop_tables import ( + RaindropsTable, + CollectionsTable, + TagsTable, + ParseTable, + BulkOperationsTable, +) +from mindsdb.integrations.libs.api_handler import APIHandler +from mindsdb.integrations.libs.response import ( + HandlerStatusResponse as StatusResponse, +) +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +class RaindropHandler(APIHandler): + """The Raindrop.io handler implementation""" + + def __init__(self, name: str, **kwargs): + """Initialize the Raindrop.io handler. + + Parameters + ---------- + name : str + name of a handler instance + """ + super().__init__(name) + + connection_data = kwargs.get("connection_data", {}) + self.connection_data = connection_data + self.kwargs = kwargs + + self.connection = None + self.is_connected = False + + # Register tables + self._register_table("raindrops", RaindropsTable(self)) + self._register_table("bookmarks", RaindropsTable(self)) # Alias for raindrops + self._register_table("collections", CollectionsTable(self)) + self._register_table("tags", TagsTable(self)) + self._register_table("parse", ParseTable(self)) + self._register_table("bulk_operations", BulkOperationsTable(self)) + + def connect(self) -> StatusResponse: + """Set up the connection required by the handler. + + Returns + ------- + StatusResponse + connection object + """ + if self.is_connected is True: + return self.connection + + api_key = self.connection_data.get("api_key") + if not api_key: + raise ValueError("API key is required for Raindrop.io connection") + + self.connection = RaindropAPIClient(api_key) + self.is_connected = True + + return self.connection + + def check_connection(self) -> StatusResponse: + """Check connection to the handler. + + Returns + ------- + StatusResponse + Status confirmation + """ + response = StatusResponse(False) + + try: + self.connect() + # Test the connection by getting user stats + test_response = self.connection.get_user_stats() + if test_response.get("result"): + logger.info("Successfully connected to Raindrop.io API") + response.success = True + else: + logger.error("Failed to connect to Raindrop.io API") + response.error_message = "Invalid API response" + except Exception as e: + logger.error(f"Error connecting to Raindrop.io API: {e}!") + response.error_message = str(e) + + self.is_connected = response.success + return response + + def native_query(self, query: str) -> StatusResponse: + """Receive and process a raw query. + + Parameters + ---------- + query : str + query in a native format + + Returns + ------- + StatusResponse + Request status + """ + ast = parse_sql(query) + return self.query(ast) + + +class RaindropAPIClient: + """A client for the Raindrop.io API""" + + def __init__(self, api_key: str): + self.api_key = api_key + self.base_url = "https://api.raindrop.io/rest/v1" + self.headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + # Rate limiting: 120 requests per minute = 2 requests per second + self.rate_limit_per_second = 2 + self.request_times = [] + + def _apply_rate_limit(self): + """Apply rate limiting to avoid hitting API limits""" + current_time = time.time() + + # Remove requests older than 1 second + self.request_times = [t for t in self.request_times if current_time - t < 1.0] + + # Check if we need to wait + if len(self.request_times) >= self.rate_limit_per_second: + # Calculate how long to wait + oldest_request = min(self.request_times) + wait_time = 1.0 - (current_time - oldest_request) + + if wait_time > 0: + logger.debug(f"Rate limit: waiting {wait_time:.2f} seconds") + time.sleep(wait_time) + # Update current_time after sleep + current_time = time.time() + # Clean up old requests again after sleep + self.request_times = [t for t in self.request_times if current_time - t < 1.0] + + # Record this request + self.request_times.append(current_time) + + def _make_request( + self, method: str, endpoint: str, params: Dict[str, Any] = None, data: Dict[str, Any] = None + ) -> Dict[str, Any]: + """Make a request to the Raindrop.io API with rate limiting""" + # Apply rate limiting + self._apply_rate_limit() + + # Validate endpoint to prevent path traversal/injection attacks + allowed_endpoints = [ + "/user/stats", + "/raindrops", + "/raindrop", + "/collections", + "/collection", + "/filters", + "/tags", + "/parse", + ] + + # Normalize endpoint by ensuring it starts with / + normalized_endpoint = f"/{endpoint.lstrip('/')}" + + # Check if endpoint matches any allowed prefix + if not any(normalized_endpoint.startswith(prefix) for prefix in allowed_endpoints): + raise ValueError(f"Invalid endpoint: {endpoint}. Only Raindrop.io API endpoints are allowed.") + + url = f"{self.base_url}{normalized_endpoint}" + + response = requests.request(method=method, url=url, headers=self.headers, params=params, json=data) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + try: + error_data = response.json() + error_message = error_data.get("error", error_data.get("message", str(e))) + except (ValueError, KeyError): + error_message = str(e) + raise Exception(f"Raindrop API error: {error_message}") + return response.json() + + def get_user_stats(self) -> Dict[str, Any]: + """Get user statistics""" + return self._make_request("GET", "/user/stats") + + # Raindrops (Bookmarks) methods + def get_raindrops( + self, + collection_id: int = 0, + search: str = None, + sort: str = None, + page: int = 0, + per_page: int = 50, + max_results: int = None, + ) -> Dict[str, Any]: + """Get raindrops from a collection with optimized pagination""" + all_items = [] + current_page = page + + # Optimize page size based on max_results to minimize API calls + if max_results and max_results <= 10: + # For small limits, use smaller page sizes to avoid wasting requests + per_page_limit = max(5, min(per_page, max_results)) + elif max_results and max_results <= 25: + per_page_limit = max(10, min(per_page, max_results)) + else: + per_page_limit = min(per_page, 50) # API limit is 50 + + while True: + params = {"page": current_page, "perpage": per_page_limit} + + if search: + params["search"] = search + if sort: + params["sort"] = sort + + response = self._make_request("GET", f"/raindrops/{collection_id}", params=params) + + if not response.get("result", False): + break + + items = response.get("items", []) + if not items: + break + + all_items.extend(items) + + # Check if we've reached max_results limit + if max_results and len(all_items) >= max_results: + all_items = all_items[:max_results] + break + + # Check if we got fewer items than requested (last page) + if len(items) < per_page_limit: + break + + current_page += 1 + + # Safety check: don't fetch more than 100 pages to prevent infinite loops and excessive API calls + # This allows fetching up to 5,000 bookmarks (100 pages * 50 per page) + if current_page > 100: + logger.warning("Stopping pagination after 100 pages to prevent excessive API usage") + break + + # Return response in same format as original API + return {"result": True, "items": all_items, "count": len(all_items)} + + def get_raindrop(self, raindrop_id: int) -> Dict[str, Any]: + """Get a single raindrop""" + return self._make_request("GET", f"/raindrop/{raindrop_id}") + + def create_raindrop(self, raindrop_data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new raindrop""" + return self._make_request("POST", "/raindrop", data=raindrop_data) + + def update_raindrop(self, raindrop_id: int, raindrop_data: Dict[str, Any]) -> Dict[str, Any]: + """Update an existing raindrop""" + return self._make_request("PUT", f"/raindrop/{raindrop_id}", data=raindrop_data) + + def delete_raindrop(self, raindrop_id: int) -> Dict[str, Any]: + """Delete a raindrop""" + return self._make_request("DELETE", f"/raindrop/{raindrop_id}") + + def create_multiple_raindrops(self, raindrops_data: list) -> Dict[str, Any]: + """Create multiple raindrops""" + return self._make_request("POST", "/raindrops", data={"items": raindrops_data}) + + def update_multiple_raindrops( + self, collection_id: int, update_data: Dict[str, Any], search: str = None, ids: list = None + ) -> Dict[str, Any]: + """Update multiple raindrops""" + data = update_data.copy() + if search: + data["search"] = search + if ids: + data["ids"] = ids + return self._make_request("PUT", f"/raindrops/{collection_id}", data=data) + + def delete_multiple_raindrops(self, collection_id: int, search: str = None, ids: list = None) -> Dict[str, Any]: + """Delete multiple raindrops""" + data = {} + if search: + data["search"] = search + if ids: + data["ids"] = ids + return self._make_request("DELETE", f"/raindrops/{collection_id}", data=data) + + def move_raindrops_to_collection( + self, target_collection_id: int, source_collection_id: int = None, search: str = None, ids: list = None + ) -> Dict[str, Any]: + """Move raindrops to a different collection""" + update_data = {"collection": {"$id": target_collection_id}} + data = update_data.copy() + if search: + data["search"] = search + if ids: + data["ids"] = ids + + endpoint = f"/raindrops/{source_collection_id}" if source_collection_id else "/raindrops/0" + return self._make_request("PUT", endpoint, data=data) + + # Collections methods + def get_collections(self) -> Dict[str, Any]: + """Get all collections (root and nested)""" + return self._make_request("GET", "/collections") + + def get_collection(self, collection_id: int) -> Dict[str, Any]: + """Get a single collection""" + return self._make_request("GET", f"/collection/{collection_id}") + + def create_collection(self, collection_data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new collection""" + return self._make_request("POST", "/collection", data=collection_data) + + def update_collection(self, collection_id: int, collection_data: Dict[str, Any]) -> Dict[str, Any]: + """Update an existing collection""" + return self._make_request("PUT", f"/collection/{collection_id}", data=collection_data) + + def delete_collection(self, collection_id: int) -> Dict[str, Any]: + """Delete a collection""" + return self._make_request("DELETE", f"/collection/{collection_id}") + + def delete_multiple_collections(self, collection_ids: list) -> Dict[str, Any]: + """Delete multiple collections""" + return self._make_request("DELETE", "/collections", data={"ids": collection_ids}) + + # Advanced filtering methods + def get_raindrops_with_filters(self, collection_id: int = 0, filters: Dict[str, Any] = None) -> Dict[str, Any]: + """Get raindrops using advanced filters endpoint""" + endpoint = f"/filters/{collection_id}" + return self._make_request("POST", endpoint, data=filters or {}) + + def get_tags(self) -> Dict[str, Any]: + """Get all tags with usage statistics""" + return self._make_request("GET", "/tags") + + def parse_url(self, url: str) -> Dict[str, Any]: + """Parse URL to extract metadata""" + return self._make_request("POST", "/parse", data={"url": url}) + + def search_raindrops_advanced( + self, + collection_id: int = 0, + search: str = None, + tags: List[str] = None, + important: bool = None, + sort: str = None, + page: int = 0, + per_page: int = 50, + ) -> Dict[str, Any]: + """Advanced search with multiple filter criteria""" + filters = {} + + if search: + filters["search"] = search + if tags: + filters["tags"] = tags + if important is not None: + filters["important"] = important + if sort: + filters["sort"] = sort + + # Add pagination parameters to filters if provided + if page is not None: + filters["page"] = page + if per_page is not None: + filters["perpage"] = per_page + + response = self.get_raindrops_with_filters(collection_id, filters) + return response diff --git a/mindsdb/integrations/handlers/raindrop_handler/raindrop_tables.py b/mindsdb/integrations/handlers/raindrop_handler/raindrop_tables.py new file mode 100644 index 00000000000..c667dd5fa29 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/raindrop_tables.py @@ -0,0 +1,1784 @@ +import pandas as pd +from typing import List, Dict, Any + +from mindsdb_sql_parser import ast +from mindsdb.integrations.libs.api_handler import APITable + +from mindsdb.integrations.utilities.handlers.query_utilities.select_query_utilities import ( + SELECTQueryParser, + SELECTQueryExecutor, +) +from mindsdb.integrations.utilities.handlers.query_utilities.delete_query_utilities import ( + DELETEQueryParser, + DELETEQueryExecutor, +) +from mindsdb.integrations.utilities.handlers.query_utilities.update_query_utilities import ( + UPDATEQueryParser, + UPDATEQueryExecutor, +) +from mindsdb.integrations.utilities.handlers.query_utilities import INSERTQueryParser + +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +class RaindropsTable(APITable): + """The Raindrop.io Raindrops (Bookmarks) Table implementation""" + + def select(self, query: ast.Select) -> pd.DataFrame: + """ + Pulls Raindrop.io raindrops data. + + Parameters + ---------- + query : ast.Select + Given SQL SELECT query + + Returns + ------- + pd.DataFrame + Raindrop.io raindrops matching the query + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + select_statement_parser = SELECTQueryParser(query, "raindrops", self.get_columns()) + ( + selected_columns, + where_conditions, + order_by_conditions, + result_limit, + ) = select_statement_parser.parse_query() + + # Parse WHERE conditions for Raindrop.io specific filters + collection_id = 0 # Default to All bookmarks + search_query = None + sort_order = None + raindrop_ids = [] + api_supported_conditions = [] # Conditions that can be handled by Raindrop.io API + local_filter_conditions = [] # Conditions that need local filtering + + # Parse conditions and categorize them + parsed_conditions = self._parse_where_conditions(where_conditions) + collection_id = parsed_conditions.get("collection_id", 0) + search_query = parsed_conditions.get("search") + sort_order = parsed_conditions.get("sort") + raindrop_ids = parsed_conditions.get("raindrop_ids", []) + api_supported_conditions = parsed_conditions.get("api_supported", []) + local_filter_conditions = parsed_conditions.get("local_filters", []) + complex_filters = parsed_conditions.get("complex_filters", {}) + + # Handle sorting + if order_by_conditions: + for order_condition in order_by_conditions: + if order_condition.column in ["created", "lastUpdate", "sort", "title"]: + if order_condition.ascending: + sort_order = order_condition.column + else: + sort_order = f"-{order_condition.column}" + break + + # If specific IDs are requested, try to fetch efficiently + if raindrop_ids: + raindrops_data = [] + # Process IDs individually with rate limiting + # Raindrop.io doesn't have bulk get endpoints, so we need to be careful with rate limits + for raindrop_id in raindrop_ids: + try: + response = self.handler.connection.get_raindrop(raindrop_id) + if response.get("result") and response.get("item"): + raindrops_data.append(response["item"]) + except Exception as e: + logger.warning(f"Failed to fetch raindrop {raindrop_id}: {e}") + continue + else: + # Check if we can use advanced filtering endpoint + if complex_filters and self._can_use_advanced_filters(complex_filters): + # Use advanced filtering endpoint + try: + response = self.handler.connection.get_raindrops_with_filters( + collection_id=collection_id, filters=complex_filters + ) + raindrops_data = response.get("items", []) + + # If advanced filtering worked, we might still need to apply local filters + # for conditions not supported by the advanced endpoint + if local_filter_conditions: + # Convert to DataFrame for local filtering + if raindrops_data: + temp_df = pd.json_normalize(raindrops_data) + temp_df = self._normalize_raindrop_data(temp_df) + temp_df = self._apply_local_filters(temp_df, local_filter_conditions) + raindrops_data = temp_df.to_dict("records") + + except Exception as e: + logger.warning(f"Advanced filtering failed, falling back to standard endpoint: {e}") + # Fall back to standard endpoint + raindrops_data = self._fetch_with_standard_endpoint( + collection_id, search_query, sort_order, result_limit, local_filter_conditions + ) + else: + # Use standard endpoint + raindrops_data = self._fetch_with_standard_endpoint( + collection_id, search_query, sort_order, result_limit, local_filter_conditions + ) + + # Convert to DataFrame + if raindrops_data: + raindrops_df = pd.json_normalize(raindrops_data) + raindrops_df = self._normalize_raindrop_data(raindrops_df) + else: + # Create empty DataFrame with all expected columns + raindrops_df = pd.DataFrame(columns=self.get_columns()) + + # Ensure all expected columns exist (defensive check) + expected_columns = self.get_columns() + for col in expected_columns: + if col not in raindrops_df.columns: + logger.warning(f"Missing column after normalization: {col}, adding as None") + raindrops_df[col] = None + + # Apply local filtering for advanced conditions + if local_filter_conditions: + raindrops_df = self._apply_local_filters(raindrops_df, local_filter_conditions) + + # Apply additional filtering and ordering using the executor (for any remaining conditions) + remaining_conditions = [ + cond + for cond in where_conditions + if cond not in api_supported_conditions and cond not in local_filter_conditions + ] + if remaining_conditions: + select_statement_executor = SELECTQueryExecutor( + raindrops_df, selected_columns, remaining_conditions, order_by_conditions + ) + raindrops_df = select_statement_executor.execute_query() + else: + # Apply ordering and column selection manually if no remaining conditions + if order_by_conditions: + raindrops_df = self._apply_ordering(raindrops_df, order_by_conditions) + if selected_columns and selected_columns != self.get_columns(): + available_columns = [col for col in selected_columns if col in raindrops_df.columns] + if available_columns: + raindrops_df = raindrops_df[available_columns] + + # Apply limit if needed + # Don't apply the default limit (20) when local filters were used, as this would + # artificially limit results when the user didn't specify a LIMIT + should_apply_limit = result_limit and ( + result_limit != 20 # Not the default limit + or not local_filter_conditions # No local filters were applied + ) + if should_apply_limit and len(raindrops_df) > result_limit: + raindrops_df = raindrops_df.head(result_limit) + + return raindrops_df + + def insert(self, query: ast.Insert) -> None: + """ + Inserts data into the Raindrop.io raindrops. + + Parameters + ---------- + query : ast.Insert + Given SQL INSERT query + + Returns + ------- + None + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + insert_statement_parser = INSERTQueryParser(query) + values_to_insert = insert_statement_parser.parse_query() + + # Process multiple or single inserts + if isinstance(values_to_insert, list): + # Multiple inserts + raindrops_data = [] + for row in values_to_insert: + raindrop_data = self._prepare_raindrop_data(row) + raindrops_data.append(raindrop_data) + + # Use batch insert if more than one item + if len(raindrops_data) > 1: + self.handler.connection.create_multiple_raindrops(raindrops_data) + else: + self.handler.connection.create_raindrop(raindrops_data[0]) + else: + # Single insert + raindrop_data = self._prepare_raindrop_data(values_to_insert) + self.handler.connection.create_raindrop(raindrop_data) + + def update(self, query: ast.Update) -> None: + """ + Updates data in the Raindrop.io raindrops. + + Parameters + ---------- + query : ast.Update + Given SQL UPDATE query + + Returns + ------- + None + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + update_statement_parser = UPDATEQueryParser(query) + values_to_update, where_conditions = update_statement_parser.parse_query() + + # Extract specific IDs and collection filters from WHERE conditions to avoid loading all data + raindrop_ids = [] + collection_id = None + search_query = None + + for condition in where_conditions: + if condition.column in ["_id", "id"]: + if isinstance(condition.value, list): + raindrop_ids.extend(condition.value) + else: + raindrop_ids.append(condition.value) + elif condition.column == "collection_id": + collection_id = condition.value + elif condition.column in ["search", "title"]: + search_query = condition.value + + # If we have specific IDs, update them directly without loading all data + if raindrop_ids: + for raindrop_id in raindrop_ids: + try: + update_data = self._prepare_raindrop_data(values_to_update) + self.handler.connection.update_raindrop(raindrop_id, update_data) + except Exception as e: + logger.error(f"Failed to update raindrop {raindrop_id}: {e}") + return + + # For complex filters, fetch only relevant data based on conditions + fetch_params = {} + if collection_id is not None: + fetch_params["collection_id"] = collection_id + if search_query: + fetch_params["search"] = search_query + + # Fetch only the relevant subset of data + raindrops_data = self.get_raindrops(**fetch_params) + + if not raindrops_data: + logger.warning("No raindrops found matching the WHERE conditions") + return + + raindrops_df = pd.json_normalize(raindrops_data) + raindrops_df = self._normalize_raindrop_data(raindrops_df) + + # Apply remaining filters + update_query_executor = UPDATEQueryExecutor(raindrops_df, where_conditions) + raindrops_df = update_query_executor.execute_query() + + if raindrops_df.empty: + logger.warning("No raindrops found matching the WHERE conditions") + return + + raindrop_ids = raindrops_df["_id"].tolist() + + # Check if we should do bulk update or individual updates + if len(raindrop_ids) > 1: + # Try bulk update first + collection_id = raindrops_df["collection.$id"].iloc[0] if "collection.$id" in raindrops_df.columns else 0 + + try: + update_data = self._prepare_raindrop_data(values_to_update) + self.handler.connection.update_multiple_raindrops( + collection_id=collection_id, update_data=update_data, ids=raindrop_ids + ) + except Exception as e: + logger.warning(f"Bulk update failed, falling back to individual updates: {e}") + # Fall back to individual updates + for raindrop_id in raindrop_ids: + try: + update_data = self._prepare_raindrop_data(values_to_update) + self.handler.connection.update_raindrop(raindrop_id, update_data) + except Exception as e: + logger.error(f"Failed to update raindrop {raindrop_id}: {e}") + else: + # Single update + raindrop_id = raindrop_ids[0] + update_data = self._prepare_raindrop_data(values_to_update) + self.handler.connection.update_raindrop(raindrop_id, update_data) + + def delete(self, query: ast.Delete) -> None: + """ + Deletes data from the Raindrop.io raindrops. + + Parameters + ---------- + query : ast.Delete + Given SQL DELETE query + + Returns + ------- + None + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + delete_statement_parser = DELETEQueryParser(query) + where_conditions = delete_statement_parser.parse_query() + + # Extract specific IDs and collection filters from WHERE conditions to avoid loading all data + raindrop_ids = [] + collection_id = None + search_query = None + + for condition in where_conditions: + if condition.column in ["_id", "id"]: + if isinstance(condition.value, list): + raindrop_ids.extend(condition.value) + else: + raindrop_ids.append(condition.value) + elif condition.column == "collection_id": + collection_id = condition.value + elif condition.column in ["search", "title"]: + search_query = condition.value + + # If we have specific IDs, delete them directly without loading all data + if raindrop_ids: + if len(raindrop_ids) > 1 and collection_id is not None: + # Try bulk delete if we know the collection + try: + self.handler.connection.delete_multiple_raindrops(collection_id=collection_id, ids=raindrop_ids) + return + except Exception as e: + logger.warning(f"Bulk delete failed, falling back to individual deletes: {e}") + + # Individual deletes + for raindrop_id in raindrop_ids: + try: + self.handler.connection.delete_raindrop(raindrop_id) + except Exception as e: + logger.error(f"Failed to delete raindrop {raindrop_id}: {e}") + return + + # For complex filters, fetch only relevant data based on conditions + fetch_params = {} + if collection_id is not None: + fetch_params["collection_id"] = collection_id + if search_query: + fetch_params["search"] = search_query + + # Fetch only the relevant subset of data + raindrops_data = self.get_raindrops(**fetch_params) + + if not raindrops_data: + logger.warning("No raindrops found matching the WHERE conditions") + return + + raindrops_df = pd.json_normalize(raindrops_data) + raindrops_df = self._normalize_raindrop_data(raindrops_df) + + # Apply remaining filters + delete_query_executor = DELETEQueryExecutor(raindrops_df, where_conditions) + raindrops_df = delete_query_executor.execute_query() + + if raindrops_df.empty: + logger.warning("No raindrops found matching the WHERE conditions") + return + + raindrop_ids = raindrops_df["_id"].tolist() + + # Check if we should do bulk delete or individual deletes + if len(raindrop_ids) > 1: + # Try bulk delete first + collection_id = raindrops_df["collection.$id"].iloc[0] if "collection.$id" in raindrops_df.columns else 0 + + try: + self.handler.connection.delete_multiple_raindrops(collection_id=collection_id, ids=raindrop_ids) + except Exception as e: + logger.warning(f"Bulk delete failed, falling back to individual deletes: {e}") + # Fall back to individual deletes + for raindrop_id in raindrop_ids: + try: + self.handler.connection.delete_raindrop(raindrop_id) + except Exception as e: + logger.error(f"Failed to delete raindrop {raindrop_id}: {e}") + else: + # Single delete + raindrop_id = raindrop_ids[0] + self.handler.connection.delete_raindrop(raindrop_id) + + def get_columns(self) -> List[str]: + """Get the column names for the raindrops table""" + return [ + "_id", + "link", + "title", + "excerpt", + "note", + "type", + "cover", + "tags", + "important", + "reminder", + "removed", + "created", + "lastUpdate", + "domain", + "collection.id", + "collection.title", + "user.id", + "broken", + "cache", + "file.name", + "file.size", + "file.type", + ] + + def get_raindrops(self, **kwargs) -> List[Dict]: + """Get raindrops data""" + if not self.handler.connection: + self.handler.connect() + + # Get from all collections by default + response = self.handler.connection.get_raindrops(**kwargs) + return response.get("items", []) + + def _normalize_raindrop_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Normalize raindrop data for consistent column structure""" + if df.empty: + return df + + # Process nested data first to extract flattened columns + try: + # Handle nested collection data + if "collection" in df.columns: + df["collection.id"] = df["collection"].apply(lambda x: x.get("$id") if isinstance(x, dict) else None) + df["collection.$id"] = df["collection"].apply(lambda x: x.get("$id") if isinstance(x, dict) else None) + df["collection.title"] = df["collection"].apply( + lambda x: x.get("title") if isinstance(x, dict) else None + ) + except Exception as e: + logger.warning(f"Error processing collection data: {e}") + + try: + # Handle nested user data + if "user" in df.columns: + df["user.id"] = df["user"].apply(lambda x: x.get("$id") if isinstance(x, dict) else None) + except Exception as e: + logger.warning(f"Error processing user data: {e}") + + try: + # Handle nested file data + if "file" in df.columns: + df["file.name"] = df["file"].apply(lambda x: x.get("name") if isinstance(x, dict) else None) + df["file.size"] = df["file"].apply(lambda x: x.get("size") if isinstance(x, dict) else None) + df["file.type"] = df["file"].apply(lambda x: x.get("type") if isinstance(x, dict) else None) + except Exception as e: + logger.warning(f"Error processing file data: {e}") + + # Convert tags list to string + try: + if "tags" in df.columns: + df["tags"] = df["tags"].apply(lambda x: ",".join(x) if isinstance(x, list) else x) + except Exception as e: + logger.warning(f"Error processing tags data: {e}") + + # Convert dates + for date_col in ["created", "lastUpdate"]: + try: + if date_col in df.columns: + df[date_col] = pd.to_datetime(df[date_col], errors="coerce") + except Exception as e: + logger.warning(f"Error processing date column {date_col}: {e}") + + # Ensure ALL expected columns exist, even if empty + # This must happen LAST to ensure any newly created columns are preserved + expected_columns = self.get_columns() + for col in expected_columns: + if col not in df.columns: + df[col] = None + + return df + + def _apply_local_filters(self, df: pd.DataFrame, conditions: List) -> pd.DataFrame: + """Apply local filtering for conditions not supported by Raindrop.io API""" + if df.empty or not conditions: + return df + + for condition in conditions: + # Handle different condition formats + if isinstance(condition, list) and len(condition) >= 3: + op, column, value = condition[0], condition[1], condition[2] + elif hasattr(condition, "op") and hasattr(condition, "column"): + op = getattr(condition, "op", "=") + column = condition.column + value = getattr(condition, "value", None) + else: + # Skip malformed conditions + logger.warning(f"Skipping malformed condition in local filter: {condition}") + continue + + if column not in df.columns: + logger.warning(f"Column '{column}' not found in DataFrame, skipping filter") + continue + + try: + if op == "=": + if isinstance(value, bool): + df = df[df[column] == value] + else: + df = df[df[column].astype(str).str.lower() == str(value).lower()] + elif op == "!=": + df = df[df[column] != value] + elif op == ">": + if column in ["created", "lastUpdate"]: + # Convert string dates to datetime for comparison + df[column] = pd.to_datetime(df[column], errors="coerce") + value = pd.to_datetime(value) + df = df[df[column] > value] + elif op == "<": + if column in ["created", "lastUpdate"]: + df[column] = pd.to_datetime(df[column], errors="coerce") + value = pd.to_datetime(value) + df = df[df[column] < value] + elif op == ">=": + if column in ["created", "lastUpdate"]: + df[column] = pd.to_datetime(df[column], errors="coerce") + value = pd.to_datetime(value) + df = df[df[column] >= value] + elif op == "<=": + if column in ["created", "lastUpdate"]: + df[column] = pd.to_datetime(df[column], errors="coerce") + value = pd.to_datetime(value) + df = df[df[column] <= value] + elif op == "between": + if column in ["created", "lastUpdate"]: + df[column] = pd.to_datetime(df[column], errors="coerce") + start_val, end_val = pd.to_datetime(value[0]), pd.to_datetime(value[1]) + else: + start_val, end_val = value + df = df[(df[column] >= start_val) & (df[column] <= end_val)] + elif op == "like": + # Simple LIKE implementation + pattern = str(value).replace("%", ".*").replace("_", ".") + df = df[df[column].astype(str).str.contains(pattern, case=False, regex=True, na=False)] + elif op == "in": + if isinstance(value, list): + df = df[df[column].isin(value)] + else: + df = df[df[column] == value] + else: + logger.warning(f"Unsupported operator '{op}' for column '{column}', skipping filter") + + except Exception as e: + logger.warning(f"Error applying filter {op} on column '{column}': {e}") + continue + + return df + + def _parse_where_conditions(self, conditions: List) -> Dict[str, Any]: + """Parse WHERE conditions and categorize them for different handling strategies""" + parsed = { + "collection_id": 0, + "search": None, + "sort": None, + "raindrop_ids": [], + "api_supported": [], + "local_filters": [], + "complex_filters": {}, + } + + # Collect all search-related conditions for potential optimization + search_conditions = [] + + for condition in conditions: + # Handle different condition formats + if isinstance(condition, list) and len(condition) >= 3: + op, column, value = condition[0], condition[1], condition[2] + elif hasattr(condition, "op") and hasattr(condition, "column"): + op = getattr(condition, "op", "=") + column = condition.column + value = getattr(condition, "value", None) + else: + # Skip malformed conditions + logger.warning(f"Skipping malformed condition: {condition}") + continue + + # Collect search-related conditions for optimization + if self._is_search_condition(column, op): + search_conditions.append((column, op, value, condition)) + + # Categorize conditions based on API support and complexity + # Defer search-related conditions until after optimization + if column == "collection_id" and op == "=": + parsed["collection_id"] = value + parsed["api_supported"].append(condition) + elif column == "search" and op == "=": + # Only handle direct search conditions, defer field-specific searches + parsed["search"] = value + parsed["api_supported"].append(condition) + elif (column in ["_id", "id"]) and op in ["=", "in"]: + if isinstance(value, list): + parsed["raindrop_ids"].extend(value) + else: + parsed["raindrop_ids"].append(value) + parsed["api_supported"].append(condition) + # Handle advanced conditions that need local filtering + elif column in ["created", "lastUpdate", "sort"] and op in [">", "<", ">=", "<=", "between"]: + parsed["local_filters"].append(condition) + elif column == "important" and op == "=": + parsed["local_filters"].append(condition) + elif column in ["domain"] and op in ["=", "like", "in"]: + # Only handle domain, defer other text fields until optimization + parsed["local_filters"].append(condition) + elif not self._is_search_condition(column, op): + # For non-search conditions, add to local filtering immediately + parsed["local_filters"].append(condition) + # Search-related conditions (title, excerpt, note, tags with = or like) are deferred + + # Optimize search conditions before final categorization + self._optimize_search_conditions(search_conditions, parsed) + + # Now categorize any remaining search conditions that weren't optimized + for column, op, value, original_condition in search_conditions: + if original_condition not in parsed["api_supported"] and original_condition not in parsed["local_filters"]: + # This condition wasn't optimized, add it to local filters + parsed["local_filters"].append(original_condition) + + # Build complex filters for advanced API endpoint if we have multiple criteria + if parsed["search"] or parsed["local_filters"]: + complex_filters = {} + if parsed["search"]: + complex_filters["search"] = parsed["search"] + + # Extract important flag if present in local filters + for condition in parsed["local_filters"]: + if isinstance(condition, list) and len(condition) >= 3: + op, column, value = condition[0], condition[1], condition[2] + elif hasattr(condition, "op") and hasattr(condition, "column"): + op = getattr(condition, "op", "=") + column = condition.column + value = getattr(condition, "value", None) + else: + continue + + if column == "important" and op == "=": + complex_filters["important"] = value + elif column == "tags" and op in ["=", "in"]: + if isinstance(value, list): + complex_filters["tags"] = value + else: + complex_filters["tags"] = [value] + + if complex_filters: + parsed["complex_filters"] = complex_filters + + return parsed + + def _is_search_condition(self, column: str, op: str) -> bool: + """Check if a condition is search-related""" + return ((column in ["search", "title", "excerpt", "note", "tags"]) and op in ["=", "like"]) or ( + column == "search" and op == "=" + ) + + def _optimize_search_conditions(self, search_conditions: List, parsed: Dict[str, Any]) -> None: + """Optimize multiple search conditions into a single API search query""" + if not search_conditions: + return + + # Check if we have a direct search condition (user explicitly specified search) + has_direct_search = any(column == "search" and op == "=" for column, op, _, _ in search_conditions) + + # If user specified a direct search, still process other conditions but don't combine them + # into the search query - just mark them as API supported if they can be optimized + + # Collect all text-based search terms + search_terms = [] + like_conditions = [] + + for column, op, value, original_condition in search_conditions: + if op == "=" and column in ["title", "excerpt", "note"]: + # If we have a direct search, don't combine field-specific searches + # but still mark them as API supported if they can be optimized + if not has_direct_search: + # Convert field-specific searches to general search terms + if column == "title": + search_terms.append(f"title:{value}") + elif column == "excerpt": + search_terms.append(f"excerpt:{value}") + elif column == "note": + search_terms.append(f"note:{value}") + + # Remove from local filters and mark as API supported + if original_condition in parsed["local_filters"]: + parsed["local_filters"].remove(original_condition) + if original_condition not in parsed["api_supported"]: + parsed["api_supported"].append(original_condition) + + elif op == "like" and column in ["title", "excerpt", "note", "tags"]: + like_conditions.append((column, op, value, original_condition)) + + # If we have multiple field-specific searches and no direct search, combine them + if search_terms and not has_direct_search: + combined_search = " ".join(search_terms) + if len(search_terms) > 1: + # For multiple terms, use AND logic + combined_search = f"({' AND '.join(search_terms)})" + + parsed["search"] = combined_search + + # Optimize simple LIKE patterns that can use API search + for column, op, value, original_condition in like_conditions: + if self._can_use_api_search_for_like(column, value): + # Convert simple LIKE patterns to API search + api_search_term = self._convert_like_to_api_search(column, value) + if api_search_term: + if not has_direct_search: + if parsed["search"]: + parsed["search"] += f" {api_search_term}" + else: + parsed["search"] = api_search_term + + # Remove from local filters and mark as API supported + if original_condition in parsed["local_filters"]: + parsed["local_filters"].remove(original_condition) + if original_condition not in parsed["api_supported"]: + parsed["api_supported"].append(original_condition) + + def _can_use_api_search_for_like(self, column: str, value: str) -> bool: + """Check if a LIKE pattern can be efficiently handled by API search""" + if not isinstance(value, str): + return False + + # Only optimize simple patterns that start and end with % + if not (value.startswith("%") and value.endswith("%")): + return False + + # Remove % and check if it's a simple word/pattern + pattern = value.strip("%") + + # Don't optimize if pattern contains regex special chars, % in middle, or is too short + if len(pattern) < 3 or any(char in pattern for char in ".*+?^$()[]{}|\\") or "%" in pattern: + return False + + return column in ["title", "excerpt", "note", "tags"] + + def _convert_like_to_api_search(self, column: str, value: str) -> str: + """Convert LIKE pattern to API search format""" + if not isinstance(value, str): + return None + + pattern = value.strip("%") + + # Create field-specific search term + if column == "title": + return f"title:{pattern}" + elif column == "excerpt": + return f"excerpt:{pattern}" + elif column == "note": + return f"note:{pattern}" + elif column == "tags": + return f"tag:{pattern}" + + return pattern + + def _can_use_advanced_filters(self, complex_filters: Dict[str, Any]) -> bool: + """Check if we can use the advanced filtering endpoint""" + # Use advanced filters if we have search, important, or tags criteria + return any(key in complex_filters for key in ["search", "important", "tags"]) + + def _fetch_with_standard_endpoint( + self, collection_id: int, search_query: str, sort_order: str, result_limit: int, local_filter_conditions: List + ) -> List[Dict]: + """Fetch data using the standard Raindrop.io endpoint""" + # If we have local filters, we may need to fetch more data than requested limit + # to ensure we have enough data to filter locally + fetch_limit = None + if local_filter_conditions: + # If we have local filters, fetch more data to account for filtering + # We'll apply the original limit after local filtering + if result_limit and result_limit != 20: # 20 is the default limit when no LIMIT is specified + fetch_limit = result_limit * 5 # Fetch more data for local filtering + else: + fetch_limit = None # Fetch all data when no limit specified or default limit + else: + fetch_limit = result_limit + + response = self.handler.connection.get_raindrops( + collection_id=collection_id, + search=search_query, + sort=sort_order, + page=0, + per_page=50, + max_results=fetch_limit, + ) + return response.get("items", []) + + def _apply_ordering(self, df: pd.DataFrame, order_by_conditions) -> pd.DataFrame: + """Apply ordering to DataFrame""" + if not order_by_conditions or df.empty: + return df + + sort_cols = [] + ascending = [] + + for order_condition in order_by_conditions: + column = getattr(order_condition, "column", getattr(order_condition, "field", None)) + if column and column in df.columns: + sort_cols.append(column) + ascending.append(getattr(order_condition, "ascending", True)) + + if sort_cols: + df = df.sort_values(by=sort_cols, ascending=ascending) + + return df + + def _prepare_raindrop_data(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Prepare raindrop data for API submission""" + raindrop_data = {} + + # Map common fields + field_mappings = { + "link": "link", + "title": "title", + "excerpt": "excerpt", + "note": "note", + "type": "type", + "cover": "cover", + "important": "important", + "collection_id": "collection", + "collection.id": "collection", + } + + for key, value in data.items(): + if key in field_mappings: + api_key = field_mappings[key] + if api_key == "collection" and value: + raindrop_data[api_key] = {"$id": int(value)} + elif key == "important" and value is not None: + raindrop_data[api_key] = bool(value) + elif value is not None: + raindrop_data[api_key] = value + + # Handle tags (convert string to list) + if "tags" in data and data["tags"]: + if isinstance(data["tags"], str): + raindrop_data["tags"] = [tag.strip() for tag in data["tags"].split(",")] + elif isinstance(data["tags"], list): + raindrop_data["tags"] = data["tags"] + + return raindrop_data + + +class CollectionsTable(APITable): + """The Raindrop.io Collections Table implementation""" + + def select(self, query: ast.Select) -> pd.DataFrame: + """ + Pulls Raindrop.io collections data. + + Parameters + ---------- + query : ast.Select + Given SQL SELECT query + + Returns + ------- + pd.DataFrame + Raindrop.io collections matching the query + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + select_statement_parser = SELECTQueryParser(query, "collections", self.get_columns()) + ( + selected_columns, + where_conditions, + order_by_conditions, + result_limit, + ) = select_statement_parser.parse_query() + + # Get collections data + collections_data = self.get_collections() + + # Convert to DataFrame + if collections_data: + collections_df = pd.json_normalize(collections_data) + collections_df = self._normalize_collection_data(collections_df) + else: + # Create empty DataFrame with all expected columns + collections_df = pd.DataFrame(columns=self.get_columns()) + + # Ensure all expected columns exist (defensive check) + expected_columns = self.get_columns() + for col in expected_columns: + if col not in collections_df.columns: + logger.warning(f"Missing column after normalization: {col}, adding as None") + collections_df[col] = None + + # Apply filtering and ordering + select_statement_executor = SELECTQueryExecutor( + collections_df, selected_columns, where_conditions, order_by_conditions + ) + collections_df = select_statement_executor.execute_query() + + # Apply limit if needed + if result_limit and len(collections_df) > result_limit: + collections_df = collections_df.head(result_limit) + + return collections_df + + def insert(self, query: ast.Insert) -> None: + """ + Inserts data into the Raindrop.io collections. + + Parameters + ---------- + query : ast.Insert + Given SQL INSERT query + + Returns + ------- + None + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + insert_statement_parser = INSERTQueryParser(query) + values_to_insert = insert_statement_parser.parse_query() + + if isinstance(values_to_insert, list): + # Multiple inserts + for row in values_to_insert: + collection_data = self._prepare_collection_data(row) + self.handler.connection.create_collection(collection_data) + else: + # Single insert + collection_data = self._prepare_collection_data(values_to_insert) + self.handler.connection.create_collection(collection_data) + + def update(self, query: ast.Update) -> None: + """ + Updates data in the Raindrop.io collections. + + Parameters + ---------- + query : ast.Update + Given SQL UPDATE query + + Returns + ------- + None + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + update_statement_parser = UPDATEQueryParser(query) + values_to_update, where_conditions = update_statement_parser.parse_query() + + # Extract specific IDs from WHERE conditions to avoid loading all data + collection_ids = [] + + for condition in where_conditions: + if condition.column in ["_id", "id"]: + if isinstance(condition.value, list): + collection_ids.extend(condition.value) + else: + collection_ids.append(condition.value) + + # If we have specific IDs, update them directly without loading all data + if collection_ids: + for collection_id in collection_ids: + try: + update_data = self._prepare_collection_data(values_to_update) + self.handler.connection.update_collection(collection_id, update_data) + except Exception as e: + logger.error(f"Failed to update collection {collection_id}: {e}") + return + + # For complex filters, we need to fetch and filter collections + # Since collections are typically fewer in number than raindrops, this is more acceptable + collections_data = self.get_collections() + + if not collections_data: + logger.warning("No collections found") + return + + collections_df = pd.json_normalize(collections_data) + collections_df = self._normalize_collection_data(collections_df) + + # Apply filters + update_query_executor = UPDATEQueryExecutor(collections_df, where_conditions) + collections_df = update_query_executor.execute_query() + + if collections_df.empty: + logger.warning("No collections found matching the WHERE conditions") + return + + collection_ids = collections_df["_id"].tolist() + + # Update each collection individually + for collection_id in collection_ids: + try: + update_data = self._prepare_collection_data(values_to_update) + self.handler.connection.update_collection(collection_id, update_data) + except Exception as e: + logger.error(f"Failed to update collection {collection_id}: {e}") + + def delete(self, query: ast.Delete) -> None: + """ + Deletes data from the Raindrop.io collections. + + Parameters + ---------- + query : ast.Delete + Given SQL DELETE query + + Returns + ------- + None + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + delete_statement_parser = DELETEQueryParser(query) + where_conditions = delete_statement_parser.parse_query() + + # Extract specific IDs from WHERE conditions to avoid loading all data + collection_ids = [] + + for condition in where_conditions: + if condition.column in ["_id", "id"]: + if isinstance(condition.value, list): + collection_ids.extend(condition.value) + else: + collection_ids.append(condition.value) + + # If we have specific IDs, delete them directly without loading all data + if collection_ids: + if len(collection_ids) > 1: + try: + self.handler.connection.delete_multiple_collections(collection_ids) + return + except Exception as e: + logger.warning(f"Bulk delete failed, falling back to individual deletes: {e}") + + # Individual deletes + for collection_id in collection_ids: + try: + self.handler.connection.delete_collection(collection_id) + except Exception as e: + logger.error(f"Failed to delete collection {collection_id}: {e}") + return + + # For complex filters, we need to fetch and filter collections + # Since collections are typically fewer in number than raindrops, this is more acceptable + collections_data = self.get_collections() + + if not collections_data: + logger.warning("No collections found") + return + + collections_df = pd.json_normalize(collections_data) + collections_df = self._normalize_collection_data(collections_df) + + # Apply filters + delete_query_executor = DELETEQueryExecutor(collections_df, where_conditions) + collections_df = delete_query_executor.execute_query() + + if collections_df.empty: + logger.warning("No collections found matching the WHERE conditions") + return + + collection_ids = collections_df["_id"].tolist() + + # Check if we should do bulk delete or individual deletes + if len(collection_ids) > 1: + try: + self.handler.connection.delete_multiple_collections(collection_ids) + except Exception as e: + logger.warning(f"Bulk delete failed, falling back to individual deletes: {e}") + # Fall back to individual deletes + for collection_id in collection_ids: + try: + self.handler.connection.delete_collection(collection_id) + except Exception as e: + logger.error(f"Failed to delete collection {collection_id}: {e}") + else: + # Single delete + collection_id = collection_ids[0] + self.handler.connection.delete_collection(collection_id) + + def get_columns(self) -> List[str]: + """Get the column names for the collections table""" + return [ + "_id", + "title", + "description", + "color", + "view", + "public", + "sort", + "count", + "created", + "lastUpdate", + "expanded", + "parent.id", + "user.id", + "cover", + "access.level", + "access.draggable", + ] + + def get_collections(self, **kwargs) -> List[Dict]: + """Get collections data""" + if not self.handler.connection: + self.handler.connect() + + # Get all collections (root and nested) from the main collections endpoint + response = self.handler.connection.get_collections() + return response.get("items", []) + + def _normalize_collection_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Normalize collection data for consistent column structure""" + if df.empty: + return df + + # Process nested data first to extract flattened columns + try: + # Handle nested parent data + if "parent" in df.columns: + df["parent.id"] = df["parent"].apply(lambda x: x.get("$id") if isinstance(x, dict) else None) + except Exception as e: + logger.warning(f"Error processing parent data: {e}") + + try: + # Handle nested user data + if "user" in df.columns: + df["user.id"] = df["user"].apply(lambda x: x.get("$id") if isinstance(x, dict) else None) + except Exception as e: + logger.warning(f"Error processing user data: {e}") + + try: + # Handle nested access data + if "access" in df.columns: + df["access.level"] = df["access"].apply(lambda x: x.get("level") if isinstance(x, dict) else None) + df["access.draggable"] = df["access"].apply( + lambda x: x.get("draggable") if isinstance(x, dict) else None + ) + except Exception as e: + logger.warning(f"Error processing access data: {e}") + + # Convert cover list to string + try: + if "cover" in df.columns: + df["cover"] = df["cover"].apply(lambda x: x[0] if isinstance(x, list) and x else x) + except Exception as e: + logger.warning(f"Error processing cover data: {e}") + + # Convert dates + for date_col in ["created", "lastUpdate"]: + try: + if date_col in df.columns: + df[date_col] = pd.to_datetime(df[date_col], errors="coerce") + except Exception as e: + logger.warning(f"Error processing date column {date_col}: {e}") + + # Ensure ALL expected columns exist, even if empty + # This must happen LAST to ensure any newly created columns are preserved + expected_columns = self.get_columns() + for col in expected_columns: + if col not in df.columns: + df[col] = None + + return df + + def _prepare_collection_data(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Prepare collection data for API submission""" + collection_data = {} + + # Map common fields + field_mappings = { + "title": "title", + "description": "description", + "color": "color", + "view": "view", + "public": "public", + "sort": "sort", + "parent_id": "parent", + "parent.id": "parent", + } + + for key, value in data.items(): + if key in field_mappings: + api_key = field_mappings[key] + if api_key == "parent" and value: + collection_data[api_key] = {"$id": int(value)} + elif key in ["public"] and value is not None: + collection_data[api_key] = bool(value) + elif value is not None: + collection_data[api_key] = value + + return collection_data + + +class TagsTable(APITable): + """The Raindrop.io Tags Table implementation""" + + def select(self, query: ast.Select) -> pd.DataFrame: + """ + Pulls Raindrop.io tags data. + + Parameters + ---------- + query : ast.Select + Given SQL SELECT query + + Returns + ------- + pd.DataFrame + Raindrop.io tags with usage statistics + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + select_statement_parser = SELECTQueryParser(query, "tags", self.get_columns()) + ( + selected_columns, + where_conditions, + order_by_conditions, + result_limit, + ) = select_statement_parser.parse_query() + + # Get tags data from API + tags_data = self.get_tags() + + # Convert to DataFrame + if tags_data: + tags_df = pd.json_normalize(tags_data) + tags_df = self._normalize_tags_data(tags_df) + else: + # Create empty DataFrame with all expected columns + tags_df = pd.DataFrame(columns=self.get_columns()) + + # Ensure all expected columns exist (defensive check) + expected_columns = self.get_columns() + for col in expected_columns: + if col not in tags_df.columns: + logger.warning(f"Missing column after normalization: {col}, adding as None") + tags_df[col] = None + + # Apply filtering and ordering using the executor + select_statement_executor = SELECTQueryExecutor( + tags_df, selected_columns, where_conditions, order_by_conditions + ) + tags_df = select_statement_executor.execute_query() + + # Apply limit if needed + if result_limit and len(tags_df) > result_limit: + tags_df = tags_df.head(result_limit) + + return tags_df + + def insert(self, query: ast.Insert) -> None: + """ + Tags are typically created automatically when bookmarks are tagged. + Direct tag creation is not supported by the Raindrop.io API. + + Parameters + ---------- + query : ast.Insert + Given SQL INSERT query + + Returns + ------- + None + + Raises + ------ + NotImplementedError + Direct tag creation is not supported + """ + raise NotImplementedError( + "Direct tag creation is not supported by Raindrop.io API. " + "Tags are created automatically when bookmarks are tagged." + ) + + def update(self, query: ast.Update) -> None: + """ + Tag updates are typically handled through bookmark updates. + Direct tag updates are not supported by the Raindrop.io API. + + Parameters + ---------- + query : ast.Update + Given SQL UPDATE query + + Returns + ------- + None + + Raises + ------ + NotImplementedError + Direct tag updates are not supported + """ + raise NotImplementedError( + "Direct tag updates are not supported by Raindrop.io API. Tag updates are handled through bookmark updates." + ) + + def delete(self, query: ast.Delete) -> None: + """ + Tag deletion removes the tag from all bookmarks. + This operation is not supported by the Raindrop.io API. + + Parameters + ---------- + query : ast.Delete + Given SQL DELETE query + + Returns + ------- + None + + Raises + ------ + NotImplementedError + Tag deletion is not supported + """ + raise NotImplementedError( + "Tag deletion is not supported by Raindrop.io API. " + "Tags are removed automatically when no bookmarks use them." + ) + + def get_columns(self) -> List[str]: + """Get the column names for the tags table""" + return [ + "_id", + "label", + "count", + "created", + "lastUpdate", + ] + + def get_tags(self) -> List[Dict]: + """Get tags data""" + if not self.handler.connection: + self.handler.connect() + + response = self.handler.connection.get_tags() + return response.get("items", []) + + def _normalize_tags_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Normalize tags data for consistent column structure""" + if df.empty: + return df + + # Convert dates + for date_col in ["created", "lastUpdate"]: + try: + if date_col in df.columns: + df[date_col] = pd.to_datetime(df[date_col], errors="coerce") + except Exception as e: + logger.warning(f"Error processing date column {date_col}: {e}") + + # Ensure ALL expected columns exist, even if empty + expected_columns = self.get_columns() + for col in expected_columns: + if col not in df.columns: + df[col] = None + + return df + + +class ParseTable(APITable): + """The Raindrop.io Parse Table implementation for URL metadata extraction""" + + def select(self, query: ast.Select) -> pd.DataFrame: + """ + Parse URLs to extract metadata. + + Parameters + ---------- + query : ast.Select + Given SQL SELECT query + + Returns + ------- + pd.DataFrame + URL metadata from parsed URLs + + Raises + ------ + ValueError + If the query contains an unsupported condition + """ + select_statement_parser = SELECTQueryParser(query, "parse", self.get_columns()) + ( + selected_columns, + where_conditions, + order_by_conditions, + result_limit, + ) = select_statement_parser.parse_query() + + # Extract URLs to parse from WHERE conditions + urls_to_parse = [] + + for condition in where_conditions: + # Handle different condition formats + if isinstance(condition, list) and len(condition) >= 3: + op, column, value = condition[0], condition[1], condition[2] + elif hasattr(condition, "op") and hasattr(condition, "column"): + op = getattr(condition, "op", "=") + column = condition.column + value = getattr(condition, "value", None) + else: + # Skip malformed conditions + logger.warning(f"Skipping malformed condition: {condition}") + continue + + if column == "url" and op == "=" and isinstance(value, str): + urls_to_parse.append(value) + elif column == "url" and op == "in" and isinstance(value, list): + urls_to_parse.extend(value) + + if not urls_to_parse: + raise ValueError( + "Please specify URL(s) to parse using WHERE url = 'https://...' or WHERE url IN ('url1', 'url2')" + ) + + # Parse URLs and collect results + parsed_results = [] + + for url in urls_to_parse: + try: + response = self.handler.connection.parse_url(url) + if response.get("result") and response.get("item"): + parsed_item = response["item"] + parsed_item["parsed_url"] = url # Add original URL for reference + parsed_results.append(parsed_item) + else: + logger.warning(f"Failed to parse URL: {url}") + # Add empty result for failed parsing + parsed_results.append( + { + "parsed_url": url, + "title": None, + "excerpt": None, + "domain": None, + "type": None, + "cover": None, + "error": "Failed to parse URL", + } + ) + except Exception as e: + logger.error(f"Error parsing URL {url}: {e}") + # Add error result + parsed_results.append( + { + "parsed_url": url, + "title": None, + "excerpt": None, + "domain": None, + "type": None, + "cover": None, + "error": str(e), + } + ) + + # Convert to DataFrame + if parsed_results: + parse_df = pd.json_normalize(parsed_results) + parse_df = self._normalize_parse_data(parse_df) + else: + # Create empty DataFrame with all expected columns + parse_df = pd.DataFrame(columns=self.get_columns()) + + # Ensure all expected columns exist (defensive check) + expected_columns = self.get_columns() + for col in expected_columns: + if col not in parse_df.columns: + logger.warning(f"Missing column after normalization: {col}, adding as None") + parse_df[col] = None + + # Apply filtering and ordering using the executor + select_statement_executor = SELECTQueryExecutor( + parse_df, + selected_columns, + [], + order_by_conditions, # No additional filtering needed + ) + parse_df = select_statement_executor.execute_query() + + # Apply limit if needed + if result_limit and len(parse_df) > result_limit: + parse_df = parse_df.head(result_limit) + + return parse_df + + def insert(self, query: ast.Insert) -> None: + """ + URL parsing is a read-only operation. + Use INSERT on the raindrops table to create bookmarks from parsed URLs. + + Parameters + ---------- + query : ast.Insert + Given SQL INSERT query + + Returns + ------- + None + + Raises + ------ + NotImplementedError + URL parsing is read-only + """ + raise NotImplementedError( + "URL parsing is a read-only operation. " + "Use INSERT on the raindrops table to create bookmarks from parsed URLs." + ) + + def update(self, query: ast.Update) -> None: + """ + URL parsing is a read-only operation. + + Parameters + ---------- + query : ast.Update + Given SQL UPDATE query + + Returns + ------- + None + + Raises + ------ + NotImplementedError + URL parsing is read-only + """ + raise NotImplementedError("URL parsing is a read-only operation. Cannot update parsed URL metadata.") + + def delete(self, query: ast.Delete) -> None: + """ + URL parsing is a read-only operation. + + Parameters + ---------- + query : ast.Delete + Given SQL DELETE query + + Returns + ------- + None + + Raises + ------ + NotImplementedError + URL parsing is read-only + """ + raise NotImplementedError("URL parsing is a read-only operation. Cannot delete parsed URL metadata.") + + def get_columns(self) -> List[str]: + """Get the column names for the parse table""" + return [ + "parsed_url", + "title", + "excerpt", + "domain", + "type", + "cover", + "media", + "lastUpdate", + "error", + ] + + def _normalize_parse_data(self, df: pd.DataFrame) -> pd.DataFrame: + """Normalize parsed URL data for consistent column structure""" + if df.empty: + return df + + # Convert dates + for date_col in ["lastUpdate"]: + try: + if date_col in df.columns: + df[date_col] = pd.to_datetime(df[date_col], errors="coerce") + except Exception as e: + logger.warning(f"Error processing date column {date_col}: {e}") + + # Ensure ALL expected columns exist, even if empty + expected_columns = self.get_columns() + for col in expected_columns: + if col not in df.columns: + df[col] = None + + return df + + +class BulkOperationsTable(APITable): + """The Raindrop.io Bulk Operations Table implementation for bulk move, update, and delete operations""" + + def select(self, query: ast.Select) -> pd.DataFrame: + """ + Bulk operations are not queryable. Use this table for bulk operations only. + + Parameters + ---------- + query : ast.Select + Given SQL SELECT query + + Returns + ------- + pd.DataFrame + Empty DataFrame with operation status information + + Raises + ------ + NotImplementedError + Bulk operations are not queryable + """ + raise NotImplementedError( + "Bulk operations table is not queryable. Use INSERT, UPDATE, or DELETE operations on this table for bulk operations." + ) + + def insert(self, query: ast.Insert) -> None: + """ + Bulk operations are initiated through UPDATE or DELETE operations, not INSERT. + + Parameters + ---------- + query : ast.Insert + Given SQL INSERT query + + Returns + ------- + None + + Raises + ------ + NotImplementedError + Bulk operations use UPDATE/DELETE + """ + raise NotImplementedError( + "Use UPDATE operations on the raindrops table for bulk updates, or DELETE operations for bulk deletions." + ) + + def update(self, query: ast.Update) -> None: + """ + Perform bulk move operations between collections. + + Parameters + ---------- + query : ast.Update + Given SQL UPDATE query + + Returns + ------- + None + + Raises + ------ + ValueError + If the query contains invalid conditions + """ + update_statement_parser = UPDATEQueryParser(query) + values_to_update, where_conditions = update_statement_parser.parse_query() + + # Check if this is a move operation (has collection_id in update values) + if "collection_id" not in values_to_update: + raise ValueError("Bulk operations table only supports collection moves. Use 'collection_id' in SET clause.") + + target_collection_id = values_to_update["collection_id"] + + # Extract conditions for the move operation + source_collection_id = None + raindrop_ids = [] + search_query = None + + for condition in where_conditions: + if condition.column == "source_collection_id": + source_collection_id = condition.value + elif condition.column in ["_id", "id"]: + if isinstance(condition.value, list): + raindrop_ids.extend(condition.value) + else: + raindrop_ids.append(condition.value) + elif condition.column in ["search", "title"]: + search_query = condition.value + + # Validate that we have at least one condition + if not source_collection_id and not raindrop_ids and not search_query: + raise ValueError( + "Please specify source conditions using one of: source_collection_id = X, _id = Y, search = 'text'" + ) + + # Perform the bulk move operation + try: + result = self.handler.connection.move_raindrops_to_collection( + target_collection_id=target_collection_id, + source_collection_id=source_collection_id, + search=search_query, + ids=raindrop_ids if raindrop_ids else None, + ) + + if result.get("result"): + logger.info(f"Successfully moved raindrops to collection {target_collection_id}") + else: + logger.warning(f"Bulk move operation may have failed: {result}") + + except Exception as e: + logger.error(f"Failed to perform bulk move operation: {e}") + raise + + def delete(self, query: ast.Delete) -> None: + """ + Bulk delete operations are handled by the raindrops table. + Use DELETE on the raindrops table for bulk deletions. + + Parameters + ---------- + query : ast.Delete + Given SQL DELETE query + + Returns + ------- + None + + Raises + ------ + NotImplementedError + Bulk operations use raindrops table + """ + raise NotImplementedError("Use DELETE operations on the raindrops table for bulk deletions.") + + def get_columns(self) -> List[str]: + """Get the column names for the bulk operations table""" + return [ + "operation", + "status", + "affected_count", + "target_collection_id", + "source_collection_id", + "error", + ] diff --git a/mindsdb/integrations/handlers/raindrop_handler/requirements.txt b/mindsdb/integrations/handlers/raindrop_handler/requirements.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindsdb/integrations/handlers/raindrop_handler/tests/__init__.py b/mindsdb/integrations/handlers/raindrop_handler/tests/__init__.py new file mode 100644 index 00000000000..dd46c4f54c3 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/tests/__init__.py @@ -0,0 +1 @@ +# Raindrop.io handler tests diff --git a/mindsdb/integrations/handlers/raindrop_handler/tests/test_raindrop_handler.py b/mindsdb/integrations/handlers/raindrop_handler/tests/test_raindrop_handler.py new file mode 100644 index 00000000000..131cbf7d0c5 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/tests/test_raindrop_handler.py @@ -0,0 +1,2068 @@ +import unittest +from unittest.mock import Mock, patch +import pandas as pd + +from mindsdb.integrations.handlers.raindrop_handler.raindrop_handler import RaindropHandler, RaindropAPIClient +from mindsdb.integrations.handlers.raindrop_handler.raindrop_tables import ( + RaindropsTable, + CollectionsTable, + TagsTable, + ParseTable, + BulkOperationsTable, +) + + +class TestRaindropHandler(unittest.TestCase): + """Test cases for RaindropHandler""" + + def setUp(self): + self.handler = RaindropHandler("test_raindrop_handler") + self.handler.connection_data = {"api_key": "test_api_key"} + + def test_init(self): + """Test handler initialization""" + self.assertEqual(self.handler.name, "test_raindrop_handler") + self.assertFalse(self.handler.is_connected) + self.assertIn("raindrops", self.handler._tables) + self.assertIn("bookmarks", self.handler._tables) + self.assertIn("collections", self.handler._tables) + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_handler.RaindropAPIClient") + def test_connect(self, mock_client): + """Test connection establishment""" + mock_instance = Mock() + mock_client.return_value = mock_instance + + result = self.handler.connect() + + mock_client.assert_called_once_with("test_api_key") + self.assertEqual(result, mock_instance) + self.assertTrue(self.handler.is_connected) + + def test_connect_missing_api_key(self): + """Test connection with missing API key""" + self.handler.connection_data = {} + + with self.assertRaises(ValueError) as context: + self.handler.connect() + + self.assertIn("API key is required", str(context.exception)) + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_handler.RaindropAPIClient") + def test_check_connection_success(self, mock_client): + """Test successful connection check""" + mock_instance = Mock() + mock_instance.get_user_stats.return_value = {"result": True} + mock_client.return_value = mock_instance + + result = self.handler.check_connection() + + self.assertTrue(result.success) + self.assertTrue(self.handler.is_connected) + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_handler.RaindropAPIClient") + def test_check_connection_failure(self, mock_client): + """Test failed connection check""" + mock_instance = Mock() + mock_instance.get_user_stats.return_value = {"result": False} + mock_client.return_value = mock_instance + + result = self.handler.check_connection() + + self.assertFalse(result.success) + self.assertFalse(self.handler.is_connected) + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_handler.RaindropAPIClient") + def test_check_connection_exception(self, mock_client): + """Test connection check with exception""" + mock_instance = Mock() + mock_instance.get_user_stats.side_effect = Exception("Connection error") + mock_client.return_value = mock_instance + + result = self.handler.check_connection() + + self.assertFalse(result.success) + self.assertIn("Connection error", result.error_message) + + +class TestRaindropAPIClient(unittest.TestCase): + """Test cases for RaindropAPIClient""" + + def setUp(self): + self.client = RaindropAPIClient("test_api_key") + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_handler.requests") + def test_make_request_get(self, mock_requests): + """Test GET request""" + mock_response = Mock() + mock_response.json.return_value = {"result": True, "items": []} + mock_response.raise_for_status.return_value = None + mock_requests.request.return_value = mock_response + + result = self.client._make_request("GET", "/user/stats") + + mock_requests.request.assert_called_once_with( + method="GET", + url="https://api.raindrop.io/rest/v1/user/stats", + headers={"Authorization": "Bearer test_api_key", "Content-Type": "application/json"}, + params=None, + json=None, + ) + self.assertEqual(result, {"result": True, "items": []}) + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_handler.requests.request") + def test_rate_limiting(self, mock_request): + """Test that rate limiting works correctly""" + import time + + # Mock response + mock_response = Mock() + mock_response.json.return_value = {"result": True, "items": []} + mock_response.raise_for_status.return_value = None + mock_request.return_value = mock_response + + # Reset request times to ensure clean state + self.client.request_times = [] + + # Make multiple rapid requests + start_time = time.time() + for i in range(3): + self.client._make_request("GET", "/user/stats") + end_time = time.time() + + # Should take at least 1 second due to rate limiting (2 requests/second limit) + total_time = end_time - start_time + self.assertGreaterEqual(total_time, 1.0, "Rate limiting should add delays between requests") + + # Should have tracked the requests (rate limiter may clean up old entries) + self.assertGreaterEqual(len(self.client.request_times), 1, "Should track at least the most recent request") + + @patch.object(RaindropAPIClient, "_make_request") + def test_get_raindrops_optimized_pagination(self, mock_request): + """Test that get_raindrops optimizes page sizes based on LIMIT""" + # Mock response with items + mock_response = {"result": True, "items": [{"_id": 1, "title": "Test"}] * 5, "count": 5} + mock_request.return_value = mock_response + + # Test small LIMIT - should use smaller page size + result = self.client.get_raindrops(max_results=5) + self.assertEqual(len(result["items"]), 5) + + # Verify the request was made with optimized page size + args, kwargs = mock_request.call_args + self.assertEqual(kwargs["params"]["perpage"], 5, "Should use small page size for small LIMIT") + + # Reset mock + mock_request.reset_mock() + + # Test larger LIMIT - should use larger page size + result = self.client.get_raindrops(max_results=100) + args, kwargs = mock_request.call_args + self.assertEqual(kwargs["params"]["perpage"], 50, "Should use larger page size for bigger LIMIT") + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_handler.requests") + def test_make_request_post(self, mock_requests): + """Test POST request with data""" + mock_response = Mock() + mock_response.json.return_value = {"result": True, "item": {}} + mock_response.raise_for_status.return_value = None + mock_requests.request.return_value = mock_response + + test_data = {"title": "Test"} + self.client._make_request("POST", "/raindrop", data=test_data) + + mock_requests.request.assert_called_once_with( + method="POST", + url="https://api.raindrop.io/rest/v1/raindrop", + headers={"Authorization": "Bearer test_api_key", "Content-Type": "application/json"}, + params=None, + json=test_data, + ) + + @patch.object(RaindropAPIClient, "_make_request") + def test_get_raindrops(self, mock_request): + """Test get_raindrops method""" + mock_request.return_value = {"result": True, "items": []} + + self.client.get_raindrops(collection_id=123, search="test", page=1) + + mock_request.assert_called_once_with( + "GET", "/raindrops/123", params={"page": 1, "perpage": 50, "search": "test"} + ) + + @patch.object(RaindropAPIClient, "_make_request") + def test_create_raindrop(self, mock_request): + """Test create_raindrop method""" + mock_request.return_value = {"result": True, "item": {}} + + raindrop_data = {"link": "https://example.com", "title": "Test"} + self.client.create_raindrop(raindrop_data) + + mock_request.assert_called_once_with("POST", "/raindrop", data=raindrop_data) + + @patch.object(RaindropAPIClient, "_make_request") + def test_update_raindrop(self, mock_request): + """Test update_raindrop method""" + mock_request.return_value = {"result": True, "item": {}} + + raindrop_data = {"title": "Updated Title"} + self.client.update_raindrop(123, raindrop_data) + + mock_request.assert_called_once_with("PUT", "/raindrop/123", data=raindrop_data) + + @patch.object(RaindropAPIClient, "_make_request") + def test_delete_raindrop(self, mock_request): + """Test delete_raindrop method""" + mock_request.return_value = {"result": True} + + self.client.delete_raindrop(123) + + mock_request.assert_called_once_with("DELETE", "/raindrop/123") + + def test_make_request_invalid_endpoint(self): + """Test that invalid endpoints are rejected""" + with self.assertRaises(ValueError) as context: + self.client._make_request("GET", "/invalid/endpoint") + + self.assertIn("Invalid endpoint", str(context.exception)) + self.assertIn("Only Raindrop.io API endpoints are allowed", str(context.exception)) + + def test_make_request_path_traversal_attempt(self): + """Test that path traversal attempts are rejected""" + with self.assertRaises(ValueError) as context: + self.client._make_request("GET", "../../../etc/passwd") + + self.assertIn("Invalid endpoint", str(context.exception)) + + +class TestRaindropsTable(unittest.TestCase): + """Test cases for RaindropsTable""" + + def setUp(self): + self.handler = Mock() + self.handler.connection = Mock() + self.table = RaindropsTable(self.handler) + + def test_apply_local_filters_greater_than(self): + """Test _apply_local_filters with greater than operator""" + test_data = pd.DataFrame( + [ + {"_id": 1, "created": "2024-01-01T00:00:00Z", "sort": 10}, + {"_id": 2, "created": "2024-01-15T00:00:00Z", "sort": 20}, + {"_id": 3, "created": "2024-01-30T00:00:00Z", "sort": 30}, + ] + ) + + # Test date comparison + conditions = [[">", "created", "2024-01-15T00:00:00Z"]] + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertEqual(len(result), 1) + self.assertEqual(result["_id"].iloc[0], 3) + + # Test numeric comparison + conditions = [[">", "sort", 15]] + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertEqual(len(result), 2) + self.assertListEqual(result["_id"].tolist(), [2, 3]) + + def test_apply_local_filters_less_than_equal(self): + """Test _apply_local_filters with less than or equal operator""" + test_data = pd.DataFrame( + [ + {"_id": 1, "sort": 10}, + {"_id": 2, "sort": 20}, + {"_id": 3, "sort": 30}, + ] + ) + + conditions = [["<=", "sort", 20]] + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertEqual(len(result), 2) + self.assertListEqual(result["_id"].tolist(), [1, 2]) + + def test_apply_local_filters_between(self): + """Test _apply_local_filters with BETWEEN operator""" + test_data = pd.DataFrame( + [ + {"_id": 1, "created": "2024-01-01T00:00:00Z"}, + {"_id": 2, "created": "2024-01-15T00:00:00Z"}, + {"_id": 3, "created": "2024-01-30T00:00:00Z"}, + ] + ) + + conditions = [["between", "created", ("2024-01-05T00:00:00Z", "2024-01-25T00:00:00Z")]] + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertEqual(len(result), 1) + self.assertEqual(result["_id"].iloc[0], 2) + + def test_apply_local_filters_like(self): + """Test _apply_local_filters with LIKE operator""" + test_data = pd.DataFrame( + [ + {"_id": 1, "title": "Python Tutorial"}, + {"_id": 2, "title": "JavaScript Guide"}, + {"_id": 3, "title": "Python Best Practices"}, + ] + ) + + conditions = [["like", "title", "%Python%"]] + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertEqual(len(result), 2) + self.assertListEqual(result["_id"].tolist(), [1, 3]) + + def test_apply_local_filters_in(self): + """Test _apply_local_filters with IN operator""" + test_data = pd.DataFrame( + [ + {"_id": 1, "tags": "python,javascript"}, + {"_id": 2, "tags": "java,ruby"}, + {"_id": 3, "tags": "python,django"}, + ] + ) + + conditions = [["in", "_id", [1, 3]]] + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertEqual(len(result), 2) + self.assertListEqual(result["_id"].tolist(), [1, 3]) + + def test_apply_local_filters_important_flag(self): + """Test _apply_local_filters with important flag""" + test_data = pd.DataFrame( + [ + {"_id": 1, "important": True}, + {"_id": 2, "important": False}, + {"_id": 3, "important": True}, + ] + ) + + conditions = [["=", "important", True]] + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertEqual(len(result), 2) + self.assertListEqual(result["_id"].tolist(), [1, 3]) + + def test_apply_local_filters_multiple_conditions(self): + """Test _apply_local_filters with multiple conditions""" + test_data = pd.DataFrame( + [ + {"_id": 1, "important": True, "sort": 10}, + {"_id": 2, "important": False, "sort": 20}, + {"_id": 3, "important": True, "sort": 30}, + ] + ) + + conditions = [["=", "important", True], [">", "sort", 15]] + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertEqual(len(result), 1) + self.assertEqual(result["_id"].iloc[0], 3) + + def test_apply_local_filters_unsupported_operator(self): + """Test _apply_local_filters with unsupported operator""" + test_data = pd.DataFrame( + [ + {"_id": 1, "title": "Test"}, + ] + ) + + conditions = [["regex", "title", ".*"]] + with self.assertLogs(level="WARNING") as log: + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertIn("Unsupported operator 'regex'", log.output[0]) + self.assertEqual(len(result), 1) # Original data should be returned + + def test_apply_local_filters_missing_column(self): + """Test _apply_local_filters with missing column""" + test_data = pd.DataFrame( + [ + {"_id": 1, "title": "Test"}, + ] + ) + + conditions = [["=", "missing_column", "value"]] + with self.assertLogs(level="WARNING") as log: + result = self.table._apply_local_filters(test_data.copy(), conditions) + self.assertIn("Column 'missing_column' not found", log.output[0]) + self.assertEqual(len(result), 1) # Original data should be returned + + def test_apply_ordering(self): + """Test _apply_ordering method""" + test_data = pd.DataFrame( + [ + {"_id": 1, "sort": 30, "title": "Z Title"}, + {"_id": 2, "sort": 10, "title": "A Title"}, + {"_id": 3, "sort": 20, "title": "B Title"}, + ] + ) + + # Mock order by conditions + order_conditions = [ + type("MockOrder", (), {"column": "sort", "ascending": True})(), + ] + + result = self.table._apply_ordering(test_data.copy(), order_conditions) + self.assertEqual(result["_id"].tolist(), [2, 3, 1]) # Sorted by sort ascending + + def test_apply_ordering_descending(self): + """Test _apply_ordering method with descending order""" + test_data = pd.DataFrame( + [ + {"_id": 1, "sort": 10}, + {"_id": 2, "sort": 30}, + {"_id": 3, "sort": 20}, + ] + ) + + # Mock order by conditions + order_conditions = [ + type("MockOrder", (), {"column": "sort", "ascending": False})(), + ] + + result = self.table._apply_ordering(test_data.copy(), order_conditions) + self.assertEqual(result["_id"].tolist(), [2, 3, 1]) # Sorted by sort descending + + def test_apply_ordering_multiple_columns(self): + """Test _apply_ordering method with multiple columns""" + test_data = pd.DataFrame( + [ + {"_id": 1, "sort": 10, "title": "B"}, + {"_id": 2, "sort": 20, "title": "A"}, + {"_id": 3, "sort": 10, "title": "A"}, + ] + ) + + # Mock order by conditions + order_conditions = [ + type("MockOrder", (), {"column": "sort", "ascending": True})(), + type("MockOrder", (), {"column": "title", "ascending": True})(), + ] + + result = self.table._apply_ordering(test_data.copy(), order_conditions) + self.assertEqual(result["_id"].tolist(), [3, 1, 2]) # Sort by sort then title + + def test_get_columns(self): + """Test get_columns method""" + columns = self.table.get_columns() + + expected_columns = [ + "_id", + "link", + "title", + "excerpt", + "note", + "type", + "cover", + "tags", + "important", + "reminder", + "removed", + "created", + "lastUpdate", + "domain", + "collection.id", + "collection.title", + "user.id", + "broken", + "cache", + "file.name", + "file.size", + "file.type", + ] + + self.assertEqual(columns, expected_columns) + + def test_normalize_raindrop_data(self): + """Test _normalize_raindrop_data method""" + test_data = pd.DataFrame( + [ + { + "_id": 123, + "title": "Test", + "collection": {"$id": 456, "title": "Test Collection"}, + "user": {"$id": 789}, + "file": {"name": "test.pdf", "size": 1024, "type": "pdf"}, + "tags": ["tag1", "tag2"], + "created": "2024-01-01T00:00:00Z", + "lastUpdate": "2024-01-02T00:00:00Z", + } + ] + ) + + result = self.table._normalize_raindrop_data(test_data) + + self.assertEqual(result["collection.id"].iloc[0], 456) + self.assertEqual(result["collection.title"].iloc[0], "Test Collection") + self.assertEqual(result["user.id"].iloc[0], 789) + self.assertEqual(result["file.name"].iloc[0], "test.pdf") + self.assertEqual(result["file.size"].iloc[0], 1024) + self.assertEqual(result["file.type"].iloc[0], "pdf") + self.assertEqual(result["tags"].iloc[0], "tag1,tag2") + + def test_prepare_raindrop_data(self): + """Test _prepare_raindrop_data method""" + input_data = { + "link": "https://example.com", + "title": "Test", + "collection_id": 123, + "tags": "tag1,tag2", + "important": True, + } + + result = self.table._prepare_raindrop_data(input_data) + + expected = { + "link": "https://example.com", + "title": "Test", + "collection": {"$id": 123}, + "tags": ["tag1", "tag2"], + "important": True, + } + + self.assertEqual(result, expected) + + def test_prepare_raindrop_data_with_list_tags(self): + """Test _prepare_raindrop_data method with list tags""" + input_data = {"link": "https://example.com", "tags": ["tag1", "tag2"]} + + result = self.table._prepare_raindrop_data(input_data) + + self.assertEqual(result["tags"], ["tag1", "tag2"]) + + def test_normalize_raindrop_data_missing_columns(self): + """Test _normalize_raindrop_data method with missing columns""" + # Test with minimal data that might come from API (missing some nested fields) + test_data = pd.DataFrame( + [ + { + "_id": 123, + "title": "Test Bookmark", + "link": "https://example.com", + "created": "2024-01-01T00:00:00Z", + # Note: missing collection, user, file, tags fields + } + ] + ) + + result = self.table._normalize_raindrop_data(test_data) + + # Check that all expected columns exist + expected_columns = self.table.get_columns() + for col in expected_columns: + self.assertIn(col, result.columns, f"Missing column: {col}") + + # Check that missing nested fields are handled gracefully + self.assertIsNone(result["collection.id"].iloc[0]) + self.assertIsNone(result["collection.title"].iloc[0]) + self.assertIsNone(result["user.id"].iloc[0]) + self.assertIsNone(result["file.name"].iloc[0]) + self.assertIsNone(result["tags"].iloc[0]) + + def test_normalize_raindrop_data_empty_dataframe(self): + """Test _normalize_raindrop_data method with empty DataFrame""" + empty_df = pd.DataFrame() + + result = self.table._normalize_raindrop_data(empty_df) + + # Should return the same empty DataFrame + self.assertTrue(result.empty) + + def test_select_with_empty_data(self): + """Test select method with empty data from API""" + # Mock empty response from API + self.handler.connection.get_raindrops.return_value = {"result": True, "items": []} + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ([], [], [], None) + mock_parser.return_value = mock_parser_instance + + # Mock executor to return DataFrame with columns (as it should after our fix) + mock_executor_instance = Mock() + empty_df_with_columns = pd.DataFrame(columns=self.table.get_columns()) + mock_executor_instance.execute_query.return_value = empty_df_with_columns + mock_executor.return_value = mock_executor_instance + + query = Mock() + result = self.table.select(query) + + # Should return DataFrame with all expected columns + expected_columns = self.table.get_columns() + for col in expected_columns: + self.assertIn(col, result.columns, f"Missing column in empty result: {col}") + + # Should be empty but have all columns + self.assertTrue(result.empty) + self.assertEqual(len(result.columns), len(expected_columns)) + + def test_select_optimized_for_limit(self): + """Test that SELECT with LIMIT uses optimized pagination""" + # Mock empty response from API (no items) + self.handler.connection.get_raindrops.return_value = {"result": True, "items": []} + + # Mock the SELECT query components with LIMIT 3 + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ([], [], [], 3) # LIMIT 3 + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + empty_df_with_columns = pd.DataFrame(columns=self.table.get_columns()) + mock_executor_instance.execute_query.return_value = empty_df_with_columns + mock_executor.return_value = mock_executor_instance + + query = Mock() + self.table.select(query) + + # Verify that get_raindrops was called with max_results=3 for optimization + self.handler.connection.get_raindrops.assert_called_once() + args, kwargs = self.handler.connection.get_raindrops.call_args + self.assertEqual(kwargs.get("max_results"), 3, "Should pass LIMIT to API for optimization") + + def test_normalize_raindrop_data_partial_nested_data(self): + """Test _normalize_raindrop_data with partial nested data""" + test_data = pd.DataFrame( + [ + { + "_id": 123, + "title": "Test", + "collection": {"$id": 456}, # Missing title + "user": None, # Explicitly None + "tags": [], # Empty list + "created": "2024-01-01T00:00:00Z", + } + ] + ) + + result = self.table._normalize_raindrop_data(test_data) + + # Check that partial data is handled correctly + self.assertEqual(result["collection.id"].iloc[0], 456) + self.assertIsNone(result["collection.title"].iloc[0]) # Missing field + self.assertIsNone(result["user.id"].iloc[0]) # None user + self.assertEqual(result["tags"].iloc[0], "") # Empty list becomes empty string + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor") + def test_select_basic(self, mock_executor, mock_parser): + """Test basic select operation""" + # Mock parser + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ([], [], [], None) + mock_parser.return_value = mock_parser_instance + + # Mock executor + mock_executor_instance = Mock() + mock_executor_instance.execute_query.return_value = pd.DataFrame() + mock_executor.return_value = mock_executor_instance + + # Mock API response + self.handler.connection.get_raindrops.return_value = {"result": True, "items": [{"_id": 123, "title": "Test"}]} + + query = Mock() + result = self.table.select(query) + + self.assertIsInstance(result, pd.DataFrame) + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.INSERTQueryParser") + def test_insert_single(self, mock_parser): + """Test single insert operation""" + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = {"link": "https://example.com", "title": "Test"} + mock_parser.return_value = mock_parser_instance + + self.handler.connection.create_raindrop.return_value = {"result": True} + + query = Mock() + self.table.insert(query) + + self.handler.connection.create_raindrop.assert_called_once() + + @patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.INSERTQueryParser") + def test_insert_multiple(self, mock_parser): + """Test multiple insert operation""" + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = [ + {"link": "https://example1.com", "title": "Test1"}, + {"link": "https://example2.com", "title": "Test2"}, + ] + mock_parser.return_value = mock_parser_instance + + self.handler.connection.create_multiple_raindrops.return_value = {"result": True} + + query = Mock() + self.table.insert(query) + + self.handler.connection.create_multiple_raindrops.assert_called_once() + + +class TestCollectionsTable(unittest.TestCase): + """Test cases for CollectionsTable""" + + def setUp(self): + self.handler = Mock() + self.handler.connection = Mock() + self.table = CollectionsTable(self.handler) + + def test_get_columns(self): + """Test get_columns method""" + columns = self.table.get_columns() + + expected_columns = [ + "_id", + "title", + "description", + "color", + "view", + "public", + "sort", + "count", + "created", + "lastUpdate", + "expanded", + "parent.id", + "user.id", + "cover", + "access.level", + "access.draggable", + ] + + self.assertEqual(columns, expected_columns) + + def test_normalize_collection_data(self): + """Test _normalize_collection_data method""" + test_data = pd.DataFrame( + [ + { + "_id": 123, + "title": "Test Collection", + "parent": {"$id": 456}, + "user": {"$id": 789}, + "access": {"level": 4, "draggable": True}, + "cover": ["https://example.com/cover.jpg"], + "created": "2024-01-01T00:00:00Z", + "lastUpdate": "2024-01-02T00:00:00Z", + } + ] + ) + + result = self.table._normalize_collection_data(test_data) + + self.assertEqual(result["parent.id"].iloc[0], 456) + self.assertEqual(result["user.id"].iloc[0], 789) + self.assertEqual(result["access.level"].iloc[0], 4) + self.assertEqual(result["access.draggable"].iloc[0], True) + self.assertEqual(result["cover"].iloc[0], "https://example.com/cover.jpg") + + def test_normalize_collection_data_missing_columns(self): + """Test _normalize_collection_data method with missing columns""" + # Test with minimal data that might come from API (missing some nested fields) + test_data = pd.DataFrame( + [ + { + "_id": 123, + "title": "Test Collection", + "created": "2024-01-01T00:00:00Z", + # Note: missing parent, user, access, cover fields + } + ] + ) + + result = self.table._normalize_collection_data(test_data) + + # Check that all expected columns exist + expected_columns = self.table.get_columns() + for col in expected_columns: + self.assertIn(col, result.columns, f"Missing column: {col}") + + # Check that missing nested fields are handled gracefully + self.assertIsNone(result["parent.id"].iloc[0]) + self.assertIsNone(result["user.id"].iloc[0]) + self.assertIsNone(result["access.level"].iloc[0]) + self.assertIsNone(result["access.draggable"].iloc[0]) + self.assertIsNone(result["cover"].iloc[0]) + + def test_normalize_collection_data_empty_dataframe(self): + """Test _normalize_collection_data method with empty DataFrame""" + empty_df = pd.DataFrame() + + result = self.table._normalize_collection_data(empty_df) + + # Should return the same empty DataFrame + self.assertTrue(result.empty) + + def test_normalize_collection_data_partial_nested_data(self): + """Test _normalize_collection_data with partial nested data""" + test_data = pd.DataFrame( + [ + { + "_id": 123, + "title": "Test Collection", + "parent": {"$id": 456}, # Missing other parent fields + "user": None, # Explicitly None + "access": {"level": 4}, # Missing draggable + "created": "2024-01-01T00:00:00Z", + } + ] + ) + + result = self.table._normalize_collection_data(test_data) + + # Check that partial data is handled correctly + self.assertEqual(result["parent.id"].iloc[0], 456) + self.assertIsNone(result["user.id"].iloc[0]) # None user + self.assertEqual(result["access.level"].iloc[0], 4) + self.assertIsNone(result["access.draggable"].iloc[0]) # Missing field + + def test_prepare_collection_data(self): + """Test _prepare_collection_data method""" + input_data = { + "title": "Test Collection", + "description": "Test Description", + "color": "#FF0000", + "public": True, + "parent_id": 123, + } + + result = self.table._prepare_collection_data(input_data) + + expected = { + "title": "Test Collection", + "description": "Test Description", + "color": "#FF0000", + "public": True, + "parent": {"$id": 123}, + } + + self.assertEqual(result, expected) + + def test_get_collections(self): + """Test get_collections method""" + # Mock get_collections to return both root and child collections + self.handler.connection.get_collections.return_value = { + "result": True, + "items": [{"_id": 123, "title": "Root Collection"}, {"_id": 456, "title": "Child Collection"}], + } + + result = self.table.get_collections() + + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["_id"], 123) + self.assertEqual(result[1]["_id"], 456) + + def test_select_with_simple_filters(self): + """Test select method with simple WHERE clause conditions for collections""" + # Mock response with sample collection data + sample_data = [ + {"_id": 123, "title": "Work Collection", "public": True}, + {"_id": 456, "title": "Personal Collection", "public": False}, + ] + + # Mock the API responses + self.handler.connection.get_collections.return_value = {"result": True, "items": sample_data} + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["_id", "title"], # selected_columns + [["=", "public", True]], # where_conditions - corrected format + [], # order_by_conditions + 10, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + filtered_df = pd.DataFrame([{"_id": 123, "title": "Work Collection", "public": True}]) + mock_executor_instance.execute_query.return_value = filtered_df + mock_executor.return_value = mock_executor_instance + + query = Mock() + result = self.table.select(query) + + # Should filter to only public collections + self.assertEqual(len(result), 1) + self.assertEqual(result["_id"].iloc[0], 123) + + def test_select_with_title_filter(self): + """Test select method with title filtering for collections""" + # Mock response with sample collection data + sample_data = [ + {"_id": 123, "title": "Work Collection"}, + {"_id": 456, "title": "Personal Collection"}, + ] + + # Mock the API responses + self.handler.connection.get_collections.return_value = {"result": True, "items": sample_data} + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["_id", "title"], # selected_columns + [["like", "title", "%Work%"]], # where_conditions - corrected format + [], # order_by_conditions + None, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + filtered_df = pd.DataFrame([{"_id": 123, "title": "Work Collection"}]) + mock_executor_instance.execute_query.return_value = filtered_df + mock_executor.return_value = mock_executor_instance + + query = Mock() + result = self.table.select(query) + + # Should filter to collections with "Work" in title + self.assertEqual(len(result), 1) + self.assertEqual(result["title"].iloc[0], "Work Collection") + + +class TestTagsTable(unittest.TestCase): + """Test cases for TagsTable""" + + def setUp(self): + self.handler = Mock() + self.handler.connection = Mock() + self.table = TagsTable(self.handler) + + def test_get_columns(self): + """Test get_columns method""" + columns = self.table.get_columns() + + expected_columns = [ + "_id", + "label", + "count", + "created", + "lastUpdate", + ] + + self.assertEqual(columns, expected_columns) + + def test_select_basic(self): + """Test basic select operation""" + # Mock API response + sample_data = [ + { + "_id": "tag1", + "label": "Python", + "count": 15, + "created": "2024-01-01T00:00:00Z", + "lastUpdate": "2024-01-01T00:00:00Z", + }, + { + "_id": "tag2", + "label": "JavaScript", + "count": 8, + "created": "2024-02-01T00:00:00Z", + "lastUpdate": "2024-02-01T00:00:00Z", + }, + { + "_id": "tag3", + "label": "Machine Learning", + "count": 3, + "created": "2024-03-01T00:00:00Z", + "lastUpdate": "2024-03-01T00:00:00Z", + }, + ] + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ([], [], [], None) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + # Create a DataFrame with the sample data for the executor to return + sample_df = pd.DataFrame(sample_data) + mock_executor_instance.execute_query.return_value = sample_df + mock_executor.return_value = mock_executor_instance + + # Mock the handler connection's get_tags method + self.handler.connection.get_tags.return_value = {"items": sample_data} + + query = Mock() + result = self.table.select(query) + + # Should return DataFrame with all expected columns + expected_columns = self.table.get_columns() + for col in expected_columns: + self.assertIn(col, result.columns, f"Missing column: {col}") + + # Should have the sample data + self.assertEqual(len(result), 3) + self.assertListEqual(result["label"].tolist(), ["Python", "JavaScript", "Machine Learning"]) + + def test_select_with_filters(self): + """Test select with filtering""" + sample_data = [ + { + "_id": "tag1", + "label": "Python", + "count": 15, + "created": "2024-01-01T00:00:00Z", + "lastUpdate": "2024-01-01T00:00:00Z", + }, + { + "_id": "tag2", + "label": "JavaScript", + "count": 8, + "created": "2024-02-01T00:00:00Z", + "lastUpdate": "2024-02-01T00:00:00Z", + }, + ] + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["label", "count"], # selected_columns + [["=", "count", 15]], # where_conditions + [], # order_by_conditions + None, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + filtered_df = pd.DataFrame([{"label": "Python", "count": 15}]) + mock_executor_instance.execute_query.return_value = filtered_df + mock_executor.return_value = mock_executor_instance + + # Mock the handler connection's get_tags method + self.handler.connection.get_tags.return_value = {"items": sample_data} + + query = Mock() + result = self.table.select(query) + + # Should filter to tags with count = 15 + self.assertEqual(len(result), 1) + self.assertEqual(result["label"].iloc[0], "Python") + + def test_select_with_limit(self): + """Test select with LIMIT clause""" + sample_data = [ + { + "_id": "tag1", + "label": "Python", + "count": 15, + "created": "2024-01-01T00:00:00Z", + "lastUpdate": "2024-01-01T00:00:00Z", + }, + { + "_id": "tag2", + "label": "JavaScript", + "count": 8, + "created": "2024-02-01T00:00:00Z", + "lastUpdate": "2024-02-01T00:00:00Z", + }, + { + "_id": "tag3", + "label": "Machine Learning", + "count": 3, + "created": "2024-03-01T00:00:00Z", + "lastUpdate": "2024-03-01T00:00:00Z", + }, + ] + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["label", "count"], # selected_columns + [], # where_conditions + [], # order_by_conditions + 2, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + limited_df = pd.DataFrame([{"label": "Python", "count": 15}, {"label": "JavaScript", "count": 8}]) + mock_executor_instance.execute_query.return_value = limited_df + mock_executor.return_value = mock_executor_instance + + # Mock the handler connection's get_tags method + self.handler.connection.get_tags.return_value = {"items": sample_data} + + query = Mock() + result = self.table.select(query) + + # Should limit to 2 results + self.assertEqual(len(result), 2) + self.assertListEqual(result["label"].tolist(), ["Python", "JavaScript"]) + + def test_insert_not_supported(self): + """Test that insert operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.insert(query) + + self.assertIn("Direct tag creation is not supported", str(context.exception)) + self.assertIn("Raindrop.io API", str(context.exception)) + + def test_update_not_supported(self): + """Test that update operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.update(query) + + self.assertIn("Direct tag updates are not supported", str(context.exception)) + self.assertIn("Raindrop.io API", str(context.exception)) + + def test_delete_not_supported(self): + """Test that delete operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.delete(query) + + self.assertIn("Tag deletion is not supported", str(context.exception)) + self.assertIn("Raindrop.io API", str(context.exception)) + + def test_normalize_tags_data(self): + """Test _normalize_tags_data method""" + test_data = pd.DataFrame( + [ + { + "_id": "tag1", + "label": "Python", + "count": 15, + "created": "2024-01-01T00:00:00Z", + "lastUpdate": "2024-01-02T00:00:00Z", + } + ] + ) + + result = self.table._normalize_tags_data(test_data) + + # Check that dates are converted to datetime + self.assertEqual(result["label"].iloc[0], "Python") + self.assertEqual(result["count"].iloc[0], 15) + # Note: Date conversion would require pandas datetime conversion, checking basic structure + self.assertIn("_id", result.columns) + self.assertIn("label", result.columns) + self.assertIn("count", result.columns) + self.assertIn("created", result.columns) + self.assertIn("lastUpdate", result.columns) + + def test_normalize_tags_data_empty(self): + """Test _normalize_tags_data with empty DataFrame""" + empty_df = pd.DataFrame() + + result = self.table._normalize_tags_data(empty_df) + + # Should return the same empty DataFrame + self.assertTrue(result.empty) + + def test_get_tags_calls_api(self): + """Test that get_tags calls the API correctly""" + expected_response = { + "items": [ + { + "_id": "tag1", + "label": "Python", + "count": 15, + "created": "2024-01-01T00:00:00Z", + "lastUpdate": "2024-01-01T00:00:00Z", + }, + { + "_id": "tag2", + "label": "JavaScript", + "count": 8, + "created": "2024-02-01T00:00:00Z", + "lastUpdate": "2024-02-01T00:00:00Z", + }, + ] + } + + self.handler.connection.get_tags.return_value = expected_response + + result = self.table.get_tags() + + # Should have called get_tags on the connection + self.handler.connection.get_tags.assert_called_once() + # Should return the items from the response + self.assertEqual(result, expected_response["items"]) + + +class TestParseTable(unittest.TestCase): + """Test cases for ParseTable""" + + def setUp(self): + self.handler = Mock() + self.handler.connection = Mock() + self.table = ParseTable(self.handler) + + def test_get_columns(self): + """Test get_columns method""" + columns = self.table.get_columns() + + expected_columns = [ + "parsed_url", + "title", + "excerpt", + "domain", + "type", + "cover", + "media", + "lastUpdate", + "error", + ] + + self.assertEqual(columns, expected_columns) + + def test_select_single_url(self): + """Test select with single URL to parse""" + # Mock API response + mock_parsed_data = { + "title": "Test Article", + "excerpt": "This is a test article excerpt", + "domain": "example.com", + "type": "article", + "cover": "https://example.com/cover.jpg", + "media": [{"link": "https://example.com/image.jpg"}], + "lastUpdate": "2024-01-01T00:00:00Z", + } + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["parsed_url", "title", "excerpt"], # selected_columns + [["=", "url", "https://example.com/test"]], # where_conditions + [], # order_by_conditions + None, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + # Create DataFrame with expected parsed data + expected_df = pd.DataFrame( + [ + { + "parsed_url": "https://example.com/test", + "title": "Test Article", + "excerpt": "This is a test article excerpt", + "domain": "example.com", + "type": "article", + "cover": "https://example.com/cover.jpg", + "media": [{"link": "https://example.com/image.jpg"}], + "lastUpdate": "2024-01-01T00:00:00Z", + "error": None, + } + ] + ) + mock_executor_instance.execute_query.return_value = expected_df + mock_executor.return_value = mock_executor_instance + + # Mock the API call + self.handler.connection.parse_url.return_value = {"result": True, "item": mock_parsed_data} + + query = Mock() + result = self.table.select(query) + + # Verify API was called with correct URL + self.handler.connection.parse_url.assert_called_once_with("https://example.com/test") + + # Should return DataFrame with parsed data + self.assertEqual(len(result), 1) + self.assertEqual(result["parsed_url"].iloc[0], "https://example.com/test") + self.assertEqual(result["title"].iloc[0], "Test Article") + + def test_select_multiple_urls(self): + """Test select with multiple URLs using IN operator""" + urls = ["https://example1.com", "https://example2.com"] + + # Mock API responses for each URL + mock_responses = [ + { + "result": True, + "item": {"title": "Article 1", "excerpt": "Excerpt 1", "domain": "example1.com", "type": "article"}, + }, + { + "result": True, + "item": {"title": "Article 2", "excerpt": "Excerpt 2", "domain": "example2.com", "type": "article"}, + }, + ] + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["parsed_url", "title"], # selected_columns + [["in", "url", urls]], # where_conditions + [], # order_by_conditions + None, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + expected_df = pd.DataFrame( + [ + {"parsed_url": "https://example1.com", "title": "Article 1", "error": None}, + {"parsed_url": "https://example2.com", "title": "Article 2", "error": None}, + ] + ) + mock_executor_instance.execute_query.return_value = expected_df + mock_executor.return_value = mock_executor_instance + + # Mock the API calls + self.handler.connection.parse_url.side_effect = mock_responses + + query = Mock() + result = self.table.select(query) + + # Verify API was called for each URL + self.assertEqual(self.handler.connection.parse_url.call_count, 2) + calls = self.handler.connection.parse_url.call_args_list + self.assertEqual(calls[0][0][0], "https://example1.com") + self.assertEqual(calls[1][0][0], "https://example2.com") + + # Should return DataFrame with both parsed URLs + self.assertEqual(len(result), 2) + + def test_select_no_url_specified(self): + """Test select without URL specification raises error""" + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["parsed_url", "title"], # selected_columns + [], # where_conditions - no URL specified + [], # order_by_conditions + None, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + query = Mock() + with self.assertRaises(ValueError) as context: + self.table.select(query) + + self.assertIn("Please specify URL(s) to parse", str(context.exception)) + self.assertIn("WHERE url =", str(context.exception)) + + def test_select_api_error_handling(self): + """Test select handles API errors gracefully""" + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["parsed_url", "title", "error"], # selected_columns + [["=", "url", "https://invalid-url.com"]], # where_conditions + [], # order_by_conditions + None, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + expected_df = pd.DataFrame([{"parsed_url": "https://invalid-url.com", "title": None, "error": "API Error"}]) + mock_executor_instance.execute_query.return_value = expected_df + mock_executor.return_value = mock_executor_instance + + # Mock API to raise exception + self.handler.connection.parse_url.side_effect = Exception("API Error") + + query = Mock() + result = self.table.select(query) + + # Should handle error gracefully and return error info + self.assertEqual(len(result), 1) + self.assertEqual(result["parsed_url"].iloc[0], "https://invalid-url.com") + self.assertEqual(result["error"].iloc[0], "API Error") + + def test_select_with_limit(self): + """Test select with LIMIT clause""" + urls = ["https://example1.com", "https://example2.com", "https://example3.com"] + + # Mock the SELECT query components + with ( + patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryParser") as mock_parser, + patch( + "mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.SELECTQueryExecutor" + ) as mock_executor, + ): + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + ["parsed_url", "title"], # selected_columns + [["in", "url", urls]], # where_conditions + [], # order_by_conditions + 2, # result_limit + ) + mock_parser.return_value = mock_parser_instance + + mock_executor_instance = Mock() + expected_df = pd.DataFrame( + [ + {"parsed_url": "https://example1.com", "title": "Article 1"}, + {"parsed_url": "https://example2.com", "title": "Article 2"}, + ] + ) + mock_executor_instance.execute_query.return_value = expected_df + mock_executor.return_value = mock_executor_instance + + # Mock API calls + mock_responses = [ + {"result": True, "item": {"title": "Article 1", "excerpt": "Excerpt 1"}}, + {"result": True, "item": {"title": "Article 2", "excerpt": "Excerpt 2"}}, + {"result": True, "item": {"title": "Article 3", "excerpt": "Excerpt 3"}}, + ] + self.handler.connection.parse_url.side_effect = mock_responses + + query = Mock() + result = self.table.select(query) + + # Should limit to 2 results + self.assertEqual(len(result), 2) + + def test_insert_not_supported(self): + """Test that insert operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.insert(query) + + self.assertIn("URL parsing is a read-only operation", str(context.exception)) + + def test_update_not_supported(self): + """Test that update operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.update(query) + + self.assertIn("URL parsing is a read-only operation", str(context.exception)) + + def test_delete_not_supported(self): + """Test that delete operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.delete(query) + + self.assertIn("URL parsing is a read-only operation", str(context.exception)) + + def test_normalize_parse_data(self): + """Test _normalize_parse_data method""" + test_data = pd.DataFrame( + [ + { + "parsed_url": "https://example.com", + "title": "Test Article", + "excerpt": "Test excerpt", + "domain": "example.com", + "lastUpdate": "2024-01-01T00:00:00Z", + } + ] + ) + + result = self.table._normalize_parse_data(test_data) + + # Check that all expected columns exist + expected_columns = self.table.get_columns() + for col in expected_columns: + self.assertIn(col, result.columns, f"Missing column: {col}") + + # Check specific values + self.assertEqual(result["parsed_url"].iloc[0], "https://example.com") + self.assertEqual(result["title"].iloc[0], "Test Article") + + def test_normalize_parse_data_empty(self): + """Test _normalize_parse_data with empty DataFrame""" + empty_df = pd.DataFrame() + + result = self.table._normalize_parse_data(empty_df) + + # Should return the same empty DataFrame + self.assertTrue(result.empty) + + +class TestBulkOperationsTable(unittest.TestCase): + """Test cases for BulkOperationsTable""" + + def setUp(self): + self.handler = Mock() + self.handler.connection = Mock() + self.table = BulkOperationsTable(self.handler) + + def test_get_columns(self): + """Test get_columns method""" + columns = self.table.get_columns() + + expected_columns = [ + "operation", + "status", + "affected_count", + "target_collection_id", + "source_collection_id", + "error", + ] + + self.assertEqual(columns, expected_columns) + + def test_select_not_supported(self): + """Test that select operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.select(query) + + self.assertIn("Bulk operations table is not queryable", str(context.exception)) + + def test_insert_not_supported(self): + """Test that insert operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.insert(query) + + self.assertIn("Use UPDATE operations on the raindrops table", str(context.exception)) + + def test_delete_not_supported(self): + """Test that delete operation raises NotImplementedError""" + with self.assertRaises(NotImplementedError) as context: + query = Mock() + self.table.delete(query) + + self.assertIn("Use DELETE operations on the raindrops table", str(context.exception)) + + def test_update_bulk_move_by_collection(self): + """Test bulk move operation by source collection""" + # Mock the UPDATE query components + with patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.UPDATEQueryParser") as mock_parser: + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + {"collection_id": 456}, # values_to_update + [Mock(column="source_collection_id", value=123)], # where_conditions + ) + mock_parser.return_value = mock_parser_instance + + # Mock API call + self.handler.connection.move_raindrops_to_collection.return_value = {"result": True} + + query = Mock() + self.table.update(query) + + # Verify API was called with correct parameters + self.handler.connection.move_raindrops_to_collection.assert_called_once_with( + target_collection_id=456, source_collection_id=123, search=None, ids=None + ) + + def test_update_bulk_move_by_ids(self): + """Test bulk move operation by specific raindrop IDs""" + # Mock the UPDATE query components + with patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.UPDATEQueryParser") as mock_parser: + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + {"collection_id": 789}, # values_to_update + [Mock(column="_id", value=[1, 2, 3])], # where_conditions + ) + mock_parser.return_value = mock_parser_instance + + # Mock API call + self.handler.connection.move_raindrops_to_collection.return_value = {"result": True} + + query = Mock() + self.table.update(query) + + # Verify API was called with correct parameters + self.handler.connection.move_raindrops_to_collection.assert_called_once_with( + target_collection_id=789, source_collection_id=None, search=None, ids=[1, 2, 3] + ) + + def test_update_bulk_move_by_search(self): + """Test bulk move operation by search query""" + # Mock the UPDATE query components + with patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.UPDATEQueryParser") as mock_parser: + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + {"collection_id": 999}, # values_to_update + [Mock(column="search", value="python tutorial")], # where_conditions + ) + mock_parser.return_value = mock_parser_instance + + # Mock API call + self.handler.connection.move_raindrops_to_collection.return_value = {"result": True} + + query = Mock() + self.table.update(query) + + # Verify API was called with correct parameters + self.handler.connection.move_raindrops_to_collection.assert_called_once_with( + target_collection_id=999, source_collection_id=None, search="python tutorial", ids=None + ) + + def test_update_no_collection_id_error(self): + """Test update operation without collection_id raises error""" + # Mock the UPDATE query components + with patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.UPDATEQueryParser") as mock_parser: + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + {"title": "New Title"}, # values_to_update - no collection_id + [Mock(column="source_collection_id", value=123)], # where_conditions + ) + mock_parser.return_value = mock_parser_instance + + query = Mock() + with self.assertRaises(ValueError) as context: + self.table.update(query) + + self.assertIn("Bulk operations table only supports collection moves", str(context.exception)) + + def test_update_no_conditions_error(self): + """Test update operation without any valid conditions raises error""" + # Mock the UPDATE query components + with patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.UPDATEQueryParser") as mock_parser: + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + {"collection_id": 456}, # values_to_update + [Mock(column="invalid_column", value="invalid")], # where_conditions - no valid conditions + ) + mock_parser.return_value = mock_parser_instance + + query = Mock() + with self.assertRaises(ValueError) as context: + self.table.update(query) + + self.assertIn("Please specify source conditions", str(context.exception)) + + def test_update_api_error_handling(self): + """Test update operation handles API errors gracefully""" + # Mock the UPDATE query components + with patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_tables.UPDATEQueryParser") as mock_parser: + mock_parser_instance = Mock() + mock_parser_instance.parse_query.return_value = ( + {"collection_id": 456}, # values_to_update + [Mock(column="source_collection_id", value=123)], # where_conditions + ) + mock_parser.return_value = mock_parser_instance + + # Mock API to raise exception + self.handler.connection.move_raindrops_to_collection.side_effect = Exception("API Error") + + query = Mock() + with self.assertRaises(Exception) as context: + self.table.update(query) + + self.assertEqual(str(context.exception), "API Error") + + +class TestAPICompatibility(unittest.TestCase): + """Test cases for Raindrop API compatibility""" + + def setUp(self): + self.client = RaindropAPIClient("test_api_key") + + def test_endpoint_format_compatibility(self): + """Test that all endpoints match official Raindrop API specification""" + # Test all endpoints used in the handler + test_endpoints = [ + ("/user/stats", "GET"), + ("/raindrops/0", "GET"), + ("/raindrops/123", "GET"), + ("/raindrop/456", "GET"), + ("/raindrop", "POST"), + ("/raindrop/456", "PUT"), + ("/raindrop/456", "DELETE"), + ("/raindrops", "POST"), + ("/raindrops/123", "PUT"), + ("/raindrops/123", "DELETE"), + ("/collections", "GET"), + ("/collection/789", "GET"), + ("/collection", "POST"), + ("/collection/789", "PUT"), + ("/collection/789", "DELETE"), + ("/collections", "DELETE"), + ("/filters/0", "POST"), + ("/tags", "GET"), + ("/parse", "POST"), + ] + + for endpoint, method in test_endpoints: + with self.subTest(endpoint=endpoint, method=method): + try: + # This should not raise a ValueError for invalid endpoints + self.client._make_request(method, endpoint) + except ValueError as e: + if "Invalid endpoint" in str(e): + self.fail(f"Endpoint {endpoint} not recognized as valid") + except Exception: + # Other exceptions (like 401 unauthorized) are expected without real API + pass + + def test_parameter_names_compatibility(self): + """Test that parameter names match official API specification""" + # Test get_raindrops parameters + with patch.object(self.client, "_make_request") as mock_request: + mock_request.return_value = {"result": True, "items": []} + + self.client.get_raindrops(collection_id=123, search="test query", sort="-created", page=1, per_page=25) + + # Verify the call was made with correct parameter names + args, kwargs = mock_request.call_args + params = kwargs.get("params", {}) + + # Official API uses 'perpage' (lowercase, no underscore) + self.assertIn("perpage", params) + self.assertNotIn("per_page", params) + self.assertEqual(params["perpage"], 25) + + # Other parameters should be lowercase + self.assertIn("page", params) + self.assertIn("search", params) + self.assertIn("sort", params) + + def test_sort_parameter_format(self): + """Test that sort parameter format matches API specification""" + # Test ascending sort (just field name) + with patch.object(self.client, "_make_request") as mock_request: + mock_request.return_value = {"result": True, "items": []} + + self.client.get_raindrops(sort="created") + args, kwargs = mock_request.call_args + params = kwargs.get("params", {}) + self.assertEqual(params["sort"], "created") + + # Test descending sort (field with minus prefix) + with patch.object(self.client, "_make_request") as mock_request: + mock_request.return_value = {"result": True, "items": []} + + self.client.get_raindrops(sort="-created") + args, kwargs = mock_request.call_args + params = kwargs.get("params", {}) + self.assertEqual(params["sort"], "-created") + + def test_filters_endpoint_compatibility(self): + """Test /filters endpoint parameter compatibility""" + with patch.object(self.client, "_make_request") as mock_request: + mock_request.return_value = {"result": True, "items": []} + + filters = {"search": "test query", "important": True, "tags": ["tag1", "tag2"], "page": 0, "perpage": 50} + + self.client.get_raindrops_with_filters(collection_id=123, filters=filters) + + # Verify the call was made correctly + args, kwargs = mock_request.call_args + self.assertEqual(args[0], "POST") # Should be POST request + self.assertEqual(args[1], "/filters/123") # Correct endpoint format + self.assertEqual(kwargs["data"], filters) # Data should match filters + + def test_bulk_operations_compatibility(self): + """Test bulk operations compatibility""" + with patch.object(self.client, "_make_request") as mock_request: + mock_request.return_value = {"result": True} + + # Test move operation with search + self.client.move_raindrops_to_collection( + target_collection_id=456, source_collection_id=123, search="test query", ids=[1, 2, 3] + ) + + args, kwargs = mock_request.call_args + self.assertEqual(args[0], "PUT") + self.assertEqual(args[1], "/raindrops/123") # Should use source collection + + data = kwargs["data"] + expected_data = {"collection": {"$id": 456}, "search": "test query", "ids": [1, 2, 3]} + self.assertEqual(data, expected_data) + + # Test move operation without source collection (uses collection 0) + with patch.object(self.client, "_make_request") as mock_request: + mock_request.return_value = {"result": True} + + self.client.move_raindrops_to_collection(target_collection_id=456, search="test query") + + args, kwargs = mock_request.call_args + self.assertEqual(args[1], "/raindrops/0") # Should default to collection 0 + + def test_response_format_expectations(self): + """Test that response format expectations match API""" + # Test successful response format + with patch.object(self.client, "_make_request") as mock_request: + mock_request.return_value = { + "result": True, + "items": [{"_id": 123, "title": "Test Bookmark"}, {"_id": 456, "title": "Another Bookmark"}], + "count": 2, + } + + response = self.client.get_collections() + + # Verify response structure matches what our code expects + self.assertIn("result", response) + self.assertIn("items", response) + self.assertEqual(response["result"], True) + self.assertEqual(len(response["items"]), 2) + + def test_error_handling_compatibility(self): + """Test error handling matches API error formats""" + with patch.object(self.client, "_make_request") as mock_request: + # Simulate API error response + mock_request.side_effect = Exception("Raindrop API error: Invalid collection ID") + + with self.assertRaises(Exception) as context: + self.client.get_collection(999) + + self.assertIn("Raindrop API error", str(context.exception)) + + def test_rate_limiting_compatibility(self): + """Test rate limiting implementation matches API limits""" + # Raindrop API allows 120 requests per minute + self.assertEqual(self.client.rate_limit_per_second, 2) # 120/60 = 2 per second + + # Test that rate limiting tracks requests properly + # Reset request times to ensure clean state + self.client.request_times = [] + + # Test the rate limiting method directly + self.client._apply_rate_limit() + self.assertEqual(len(self.client.request_times), 1) + + # Test rate limit configuration + self.assertEqual(self.client.rate_limit_per_second, 2) + self.assertIsInstance(self.client.request_times, list) + + def test_authentication_header_format(self): + """Test authentication header format matches API requirements""" + # Raindrop API uses Bearer token authentication + expected_auth = f"Bearer {self.client.api_key}" + self.assertEqual(self.client.headers["Authorization"], expected_auth) + self.assertEqual(self.client.headers["Content-Type"], "application/json") + + def test_collections_endpoint_fix(self): + """Test that collections endpoint works correctly without children endpoint""" + # Mock the get_collections to return all collections + self.client._make_request = Mock( + return_value={ + "result": True, + "items": [ + {"_id": 123, "title": "Root Collection"}, + {"_id": 456, "title": "Child Collection", "parent": {"$id": 123}}, + ], + } + ) + + # Test that get_collections works without calling children endpoint + response = self.client.get_collections() + + # Verify the call was made correctly + self.client._make_request.assert_called_once_with("GET", "/collections") + + # Verify response structure + self.assertIn("result", response) + self.assertIn("items", response) + self.assertEqual(len(response["items"]), 2) + + def test_collections_table_integration(self): + """Test that collections table works correctly with the fix""" + # Mock handler and connection + mock_handler = Mock() + mock_connection = Mock() + mock_handler.connection = mock_connection + + # Mock get_collections to return all collections + mock_connection.get_collections.return_value = { + "items": [ + {"_id": 123, "title": "Root Collection"}, + {"_id": 456, "title": "Child Collection", "parent": {"$id": 123}}, + ] + } + + # Create collections table and test get_collections method + collections_table = CollectionsTable(mock_handler) + result = collections_table.get_collections() + + # Verify that get_collections was called (not get_child_collections) + mock_connection.get_collections.assert_called_once() + + # Verify that get_child_collections was NOT called + mock_connection.get_child_collections.assert_not_called() + + # Verify result + self.assertEqual(len(result), 2) + self.assertEqual(result[0]["_id"], 123) + self.assertEqual(result[1]["_id"], 456) + + +class TestSearchOptimizations(unittest.TestCase): + """Test cases for enhanced search capabilities""" + + def setUp(self): + self.table = RaindropsTable(None) + + def test_enhanced_search_parsing_single_field(self): + """Test enhanced search parsing with single field search""" + conditions = [["=", "title", "Python Tutorial"]] + + parsed = self.table._parse_where_conditions(conditions) + + # Should convert to API search + self.assertEqual(parsed["search"], "title:Python Tutorial") + self.assertEqual(len(parsed["api_supported"]), 1) + self.assertEqual(len(parsed["local_filters"]), 0) + + def test_enhanced_search_parsing_multiple_fields(self): + """Test enhanced search parsing with multiple field searches""" + conditions = [["=", "title", "Python"], ["=", "excerpt", "Tutorial"], ["=", "note", "Advanced"]] + + parsed = self.table._parse_where_conditions(conditions) + + # Should combine into complex search query + expected_search = "(title:Python AND excerpt:Tutorial AND note:Advanced)" + self.assertEqual(parsed["search"], expected_search) + self.assertEqual(len(parsed["api_supported"]), 3) + self.assertEqual(len(parsed["local_filters"]), 0) + + def test_enhanced_search_like_optimization(self): + """Test LIKE pattern optimization to API search""" + conditions = [["like", "title", "%python%"], ["like", "excerpt", "%tutorial%"]] + + parsed = self.table._parse_where_conditions(conditions) + + # Should convert LIKE patterns to API search + self.assertIn("title:python", parsed["search"]) + self.assertIn("excerpt:tutorial", parsed["search"]) + self.assertEqual(len(parsed["api_supported"]), 2) + self.assertEqual(len(parsed["local_filters"]), 0) + + def test_enhanced_search_mixed_conditions(self): + """Test mixed search conditions (API and local)""" + conditions = [ + ["=", "title", "Python"], # Should be optimized to API + ["=", "important", True], # Should remain local + ["like", "tags", "%web%"], # Should be optimized to API + ] + + parsed = self.table._parse_where_conditions(conditions) + + # Should have both API search and local filters + self.assertIsNotNone(parsed["search"]) + self.assertIn("title:Python", parsed["search"]) + self.assertIn("tag:web", parsed["search"]) + self.assertEqual(len(parsed["local_filters"]), 1) # important flag + + def test_like_pattern_not_optimized_complex(self): + """Test that complex LIKE patterns are not optimized""" + conditions = [ + ["like", "title", "python%"], # Only starts with %, not optimized + ["like", "title", "%p%t%"], # Contains regex chars, not optimized + ["like", "title", "%ab%"], # Too short, not optimized + ] + + parsed = self.table._parse_where_conditions(conditions) + + # Should keep complex patterns as local filters + self.assertIsNone(parsed["search"]) + self.assertEqual(len(parsed["local_filters"]), 3) + + def test_like_pattern_optimized_simple(self): + """Test that simple LIKE patterns are optimized""" + conditions = [ + ["like", "title", "%python%"], # Should be optimized + ["like", "excerpt", "%tutorial%"], # Should be optimized + ["like", "note", "%advanced%"], # Should be optimized + ] + + parsed = self.table._parse_where_conditions(conditions) + + # Should convert to API search + self.assertIsNotNone(parsed["search"]) + self.assertIn("title:python", parsed["search"]) + self.assertIn("excerpt:tutorial", parsed["search"]) + self.assertIn("note:advanced", parsed["search"]) + self.assertEqual(len(parsed["api_supported"]), 3) + self.assertEqual(len(parsed["local_filters"]), 0) + + def test_existing_search_not_overridden(self): + """Test that existing search conditions are not overridden by optimizations""" + conditions = [ + ["=", "search", "original query"], + ["=", "title", "Python"], # This should not override the search + ] + + parsed = self.table._parse_where_conditions(conditions) + + # Should keep original search + self.assertEqual(parsed["search"], "original query") + self.assertEqual(len(parsed["api_supported"]), 2) + + def test_is_search_condition_detection(self): + """Test search condition detection logic""" + # Should detect search conditions + self.assertTrue(self.table._is_search_condition("title", "=")) + self.assertTrue(self.table._is_search_condition("search", "=")) + self.assertTrue(self.table._is_search_condition("excerpt", "like")) + self.assertTrue(self.table._is_search_condition("tags", "like")) + + # Should not detect non-search conditions + self.assertFalse(self.table._is_search_condition("created", ">")) + self.assertFalse(self.table._is_search_condition("_id", "=")) + self.assertFalse(self.table._is_search_condition("collection_id", "=")) + + def test_can_use_api_search_for_like(self): + """Test LIKE pattern optimization detection""" + # Should optimize simple patterns + self.assertTrue(self.table._can_use_api_search_for_like("title", "%python%")) + self.assertTrue(self.table._can_use_api_search_for_like("excerpt", "%tutorial%")) + + # Should not optimize complex patterns + self.assertFalse(self.table._can_use_api_search_for_like("title", "python%")) + self.assertFalse(self.table._can_use_api_search_for_like("title", "%p%t%")) + self.assertFalse(self.table._can_use_api_search_for_like("title", "%ab%")) + self.assertFalse(self.table._can_use_api_search_for_like("title", "%test*ing%")) + + def test_convert_like_to_api_search(self): + """Test LIKE to API search conversion""" + # Should convert properly + self.assertEqual(self.table._convert_like_to_api_search("title", "%python%"), "title:python") + self.assertEqual(self.table._convert_like_to_api_search("excerpt", "%tutorial%"), "excerpt:tutorial") + self.assertEqual(self.table._convert_like_to_api_search("note", "%advanced%"), "note:advanced") + self.assertEqual(self.table._convert_like_to_api_search("tags", "%web%"), "tag:web") + + # Should handle non-string values + self.assertIsNone(self.table._convert_like_to_api_search("title", 123)) + + +if __name__ == "__main__": + unittest.main() diff --git a/mindsdb/integrations/handlers/raindrop_handler/tests/test_raindrop_integration.py b/mindsdb/integrations/handlers/raindrop_handler/tests/test_raindrop_integration.py new file mode 100644 index 00000000000..ee8fd863c21 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/tests/test_raindrop_integration.py @@ -0,0 +1,178 @@ +import unittest +import os +import pandas as pd +from mindsdb.integrations.handlers.raindrop_handler.raindrop_handler import RaindropHandler + + +class TestRaindropHandlerIntegration(unittest.TestCase): + """Integration tests for RaindropHandler (requires valid API key)""" + + @classmethod + def setUpClass(cls): + """Set up the test environment""" + cls.api_key = os.environ.get("RAINDROP_API_KEY") + if not cls.api_key: + raise unittest.SkipTest("RAINDROP_API_KEY environment variable not set") + + cls.handler = RaindropHandler("test_raindrop_handler") + cls.handler.connection_data = {"api_key": cls.api_key} + + def test_check_connection(self): + """Test that we can connect to the Raindrop.io API""" + response = self.handler.check_connection() + self.assertTrue(response.success, f"Connection failed: {response.error_message}") + + def test_get_tables(self): + """Test that tables are properly registered""" + tables = self.handler.get_tables() + table_names = [table.data[0] for table in tables.data] + + self.assertIn("raindrops", table_names) + self.assertIn("bookmarks", table_names) + self.assertIn("collections", table_names) + + def test_raindrops_table_select(self): + """Test selecting from raindrops table""" + # Test basic select + query = "SELECT * FROM raindrops LIMIT 5" + result = self.handler.native_query(query) + self.assertTrue(result.success, f"Query failed: {result.error_message}") + + # Check that we get a DataFrame + if hasattr(result, "data_frame") and result.data_frame is not None: + self.assertIsInstance(result.data_frame, pd.DataFrame) + + def test_collections_table_select(self): + """Test selecting from collections table""" + query = "SELECT * FROM collections LIMIT 5" + result = self.handler.native_query(query) + self.assertTrue(result.success, f"Query failed: {result.error_message}") + + # Check that we get a DataFrame + if hasattr(result, "data_frame") and result.data_frame is not None: + self.assertIsInstance(result.data_frame, pd.DataFrame) + + def test_raindrops_table_columns(self): + """Test that raindrops table has expected columns""" + raindrops_table = self.handler.get_table("raindrops") + columns = raindrops_table.get_columns() + + expected_columns = [ + "_id", + "link", + "title", + "excerpt", + "note", + "type", + "cover", + "tags", + "important", + "reminder", + "removed", + "created", + "lastUpdate", + "domain", + "collection.id", + "collection.title", + "user.id", + "broken", + "cache", + "file.name", + "file.size", + "file.type", + ] + + for col in expected_columns: + self.assertIn(col, columns, f"Column {col} not found in raindrops table") + + def test_collections_table_columns(self): + """Test that collections table has expected columns""" + collections_table = self.handler.get_table("collections") + columns = collections_table.get_columns() + + expected_columns = [ + "_id", + "title", + "description", + "color", + "view", + "public", + "sort", + "count", + "created", + "lastUpdate", + "expanded", + "parent.id", + "user.id", + "cover", + "access.level", + "access.draggable", + ] + + for col in expected_columns: + self.assertIn(col, columns, f"Column {col} not found in collections table") + + def test_create_and_delete_bookmark(self): + """Test creating and deleting a bookmark (if API key has write permissions)""" + try: + # Create a test bookmark + insert_query = """ + INSERT INTO raindrops (link, title, note, tags) + VALUES ('https://example.com/test', 'Test Bookmark', 'Test note', 'test,automated') + """ + result = self.handler.native_query(insert_query) + + if not result.success: + # Skip if we don't have write permissions + self.skipTest(f"Cannot create bookmarks: {result.error_message}") + + # Try to find the bookmark we just created + select_query = "SELECT * FROM raindrops WHERE title = 'Test Bookmark' LIMIT 1" + result = self.handler.native_query(select_query) + self.assertTrue(result.success) + + if hasattr(result, "data_frame") and result.data_frame is not None and not result.data_frame.empty: + bookmark_id = result.data_frame["_id"].iloc[0] + + # Delete the test bookmark + delete_query = f"DELETE FROM raindrops WHERE _id = {bookmark_id}" + result = self.handler.native_query(delete_query) + self.assertTrue(result.success) + + except Exception as e: + self.fail(f"Create/delete test failed: {e}") + + def test_create_and_delete_collection(self): + """Test creating and deleting a collection (if API key has write permissions)""" + try: + # Create a test collection + insert_query = """ + INSERT INTO collections (title, description, color) + VALUES ('Test Collection', 'Automated test collection', '#FF0000') + """ + result = self.handler.native_query(insert_query) + + if not result.success: + # Skip if we don't have write permissions + self.skipTest(f"Cannot create collections: {result.error_message}") + + # Try to find the collection we just created + select_query = "SELECT * FROM collections WHERE title = 'Test Collection' LIMIT 1" + result = self.handler.native_query(select_query) + self.assertTrue(result.success) + + if hasattr(result, "data_frame") and result.data_frame is not None and not result.data_frame.empty: + collection_id = result.data_frame["_id"].iloc[0] + + # Delete the test collection + delete_query = f"DELETE FROM collections WHERE _id = {collection_id}" + result = self.handler.native_query(delete_query) + self.assertTrue(result.success) + + except Exception as e: + self.fail(f"Create/delete collection test failed: {e}") + + +if __name__ == "__main__": + # Run integration tests only if API key is available + unittest.main() diff --git a/mindsdb/integrations/handlers/raindrop_handler/verify_implementation.py b/mindsdb/integrations/handlers/raindrop_handler/verify_implementation.py new file mode 100644 index 00000000000..f220ba740c5 --- /dev/null +++ b/mindsdb/integrations/handlers/raindrop_handler/verify_implementation.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 + +""" +Verification script for the Raindrop.io handler implementation. +This script checks all the key functionality without requiring a real API key. + +Recent improvements: +- Uses logging instead of print statements for better integration with MindsDB logging +- Tests robustness of data normalization with missing columns +- Validates error handling for various edge cases +- Implements rate limiting to prevent API quota exhaustion +- Optimizes pagination for small LIMIT queries +""" + +import sys +import logging +from unittest.mock import Mock, patch + +# Set up logging +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +logger = logging.getLogger(__name__) + + +def test_handler_loading(): + """Test that the handler can be loaded and instantiated""" + try: + from mindsdb.integrations.handlers.raindrop_handler import Handler, name, type, title, connection_args + + logger.info("[PASS] Handler module loaded successfully") + logger.info(f" Name: {name}") + logger.info(f" Type: {type}") + logger.info(f" Title: {title}") + logger.info(f" Connection args: {list(connection_args.keys())}") + + # Test instantiation + handler = Handler("test_handler") + logger.info("[PASS] Handler instantiated successfully") + logger.info(f" Tables: {list(handler._tables.keys())}") + + return True + except Exception as e: + logger.error(f"[FAIL] Handler loading failed: {e}") + return False + + +def test_api_client(): + """Test the API client functionality""" + try: + from mindsdb.integrations.handlers.raindrop_handler.raindrop_handler import RaindropAPIClient + + client = RaindropAPIClient("test_key") + logger.info("[PASS] API client instantiated successfully") + logger.info(f" Base URL: {client.base_url}") + logger.info(f" Headers configured: {'Authorization' in client.headers}") + + return True + except Exception as e: + logger.error(f"[FAIL] API client test failed: {e}") + return False + + +def test_table_functionality(): + """Test table functionality with mocked data""" + try: + from mindsdb.integrations.handlers.raindrop_handler.raindrop_tables import RaindropsTable, CollectionsTable + import pandas as pd + + # Test RaindropsTable + handler_mock = Mock() + raindrops_table = RaindropsTable(handler_mock) + + columns = raindrops_table.get_columns() + logger.info(f"[PASS] RaindropsTable columns: {len(columns)} columns") + + # Test data normalization + test_data = pd.DataFrame( + [ + { + "_id": 123, + "title": "Test", + "collection": {"$id": 456, "title": "Test Collection"}, + "tags": ["tag1", "tag2"], + "created": "2024-01-01T00:00:00Z", + } + ] + ) + + raindrops_table._normalize_raindrop_data(test_data) + logger.info("[PASS] RaindropsTable data normalization works") + + # Test data preparation + raindrops_table._prepare_raindrop_data( + {"link": "https://example.com", "title": "Test", "tags": "tag1,tag2", "collection_id": 123} + ) + logger.info("[PASS] RaindropsTable data preparation works") + + # Test CollectionsTable + collections_table = CollectionsTable(handler_mock) + columns = collections_table.get_columns() + logger.info(f"[PASS] CollectionsTable columns: {len(columns)} columns") + + return True + except Exception as e: + logger.error(f"[FAIL] Table functionality test failed: {e}") + return False + + +def test_connection_handling(): + """Test connection handling""" + try: + from mindsdb.integrations.handlers.raindrop_handler import Handler + + # Test with missing API key + handler = Handler("test") + try: + handler.connect() + logger.error("[FAIL] Should have failed with missing API key") + return False + except ValueError as e: + if "API key is required" in str(e): + logger.info("[PASS] Properly validates missing API key") + else: + logger.error(f"[FAIL] Unexpected error: {e}") + return False + + # Test with API key + handler.connection_data = {"api_key": "test_key"} + + with patch("mindsdb.integrations.handlers.raindrop_handler.raindrop_handler.RaindropAPIClient") as mock_client: + mock_instance = Mock() + mock_client.return_value = mock_instance + + handler.connect() + logger.info("[PASS] Connection with API key works") + + # Test connection check + mock_instance.get_user_stats.return_value = {"result": True} + status = handler.check_connection() + logger.info(f"[PASS] Connection check works: {status.success}") + + return True + except Exception as e: + logger.error(f"[FAIL] Connection handling test failed: {e}") + return False + + +def main(): + """Run all verification tests""" + logger.info("[VERIFY] Verifying Raindrop.io Handler Implementation") + logger.info("=" * 50) + + tests = [ + ("Handler Loading", test_handler_loading), + ("API Client", test_api_client), + ("Table Functionality", test_table_functionality), + ("Connection Handling", test_connection_handling), + ] + + passed = 0 + total = len(tests) + + for test_name, test_func in tests: + logger.info(f"\n[TEST] {test_name}") + logger.info("-" * 30) + if test_func(): + passed += 1 + else: + logger.error(f"[FAILED] {test_name} failed") + + logger.info("\n" + "=" * 50) + logger.info(f"[RESULTS] Test Results: {passed}/{total} tests passed") + + if passed == total: + logger.info("[SUCCESS] All tests passed! The Raindrop.io handler is ready for use.") + return 0 + else: + logger.error("[FAILED] Some tests failed. Please check the implementation.") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) From eef5ad36a7c2f3abeadf92235e7ab11c00ff4aba Mon Sep 17 00:00:00 2001 From: RITWICK RAJ MAKHAL Date: Fri, 13 Mar 2026 15:00:22 +0530 Subject: [PATCH 045/125] Add Denodo integration handler (#10432) --- .../handlers/denodo_handler/README.md | 57 +++++++ .../handlers/denodo_handler/__about__.py | 9 + .../handlers/denodo_handler/__init__.py | 30 ++++ .../denodo_handler/connection_args.py | 46 ++++++ .../handlers/denodo_handler/denodo_handler.py | 155 ++++++++++++++++++ .../handlers/denodo_handler/icon.png | Bin 0 -> 3331 bytes .../handlers/hubspot_handler/README.md | 5 + 7 files changed, 302 insertions(+) create mode 100644 mindsdb/integrations/handlers/denodo_handler/README.md create mode 100644 mindsdb/integrations/handlers/denodo_handler/__about__.py create mode 100644 mindsdb/integrations/handlers/denodo_handler/__init__.py create mode 100644 mindsdb/integrations/handlers/denodo_handler/connection_args.py create mode 100644 mindsdb/integrations/handlers/denodo_handler/denodo_handler.py create mode 100644 mindsdb/integrations/handlers/denodo_handler/icon.png diff --git a/mindsdb/integrations/handlers/denodo_handler/README.md b/mindsdb/integrations/handlers/denodo_handler/README.md new file mode 100644 index 00000000000..87dcf6b3eda --- /dev/null +++ b/mindsdb/integrations/handlers/denodo_handler/README.md @@ -0,0 +1,57 @@ +--- +title: Denodo +sidebarTitle: Denodo +--- + +This documentation describes the integration of MindsDB with [Denodo](https://www.denodo.com/), a powerful data virtualization platform that enables real-time access and integration of multiple data sources. +The integration allows MindsDB to query Denodo views and enhance them with AI capabilities. + +## Prerequisites + +Before proceeding, ensure the following prerequisites are met: + +1. Install MindsDB locally via [Docker](https://docs.mindsdb.com/setup/self-hosted/docker) or [Docker Desktop](https://docs.mindsdb.com/setup/self-hosted/docker-desktop). + +## Connection + +Establish a connection to Denodo from MindsDB by executing the following SQL command and providing its [handler name](https://github.com/mindsdb/mindsdb/tree/main/mindsdb/integrations/handlers/denodo_handler) as an engine. + +```sql +CREATE DATABASE denodo_conn +WITH ENGINE = 'denodo', +PARAMETERS = { + "host": "host-name", + "port": 9996, + "database": "db-name", + "user": "user-name", + "password": "password" +}; +``` + +Required connection parameters include the following: + +- `user`: The username for the Denodo database. +- `password`: The password for the Denodo database. +- `host`: The hostname, IP address, or URL of the Denodo server. +- `port`: The port number for connecting to the Denodo server (default is `9999`). +- `database`: The name of the Denodo virtual database to connect to. + +## Usage + +The following usage examples utilize the connection to Denodo made via the `CREATE DATABASE` statement and named `denodo_conn`. + +Retrieve data from a specified Denodo view by providing the integration and view name. + +```sql +SELECT * +FROM denodo_conn.view_name +LIMIT 10; +``` + +Running native SQL queries on Denodo views is also supported. + +```sql +SELECT * FROM denodno_conn ( + DESC VIEW view_name +); +``` \ No newline at end of file diff --git a/mindsdb/integrations/handlers/denodo_handler/__about__.py b/mindsdb/integrations/handlers/denodo_handler/__about__.py new file mode 100644 index 00000000000..c780673cdff --- /dev/null +++ b/mindsdb/integrations/handlers/denodo_handler/__about__.py @@ -0,0 +1,9 @@ +__title__ = "MindsDB Denodo handler" +__package_name__ = "mindsdb_denodo_handler" +__version__ = "0.0.1" +__description__ = "MindsDB handler for Denodo" +__author__ = "Ritwick Raj Makhal" +__github__ = "https://github.com/mindsdb/mindsdb" +__pypi__ = "https://pypi.org/project/mindsdb/" +__license__ = "MIT" +__copyright__ = "Copyright 2022- mindsdb" diff --git a/mindsdb/integrations/handlers/denodo_handler/__init__.py b/mindsdb/integrations/handlers/denodo_handler/__init__.py new file mode 100644 index 00000000000..20293b54d55 --- /dev/null +++ b/mindsdb/integrations/handlers/denodo_handler/__init__.py @@ -0,0 +1,30 @@ +from mindsdb.integrations.libs.const import HANDLER_TYPE + +from .__about__ import __version__ as version, __description__ as description +from .connection_args import connection_args, connection_args_example + +try: + from .denodo_handler import DenodoHandler as Handler + + import_error = None +except Exception as e: + Handler = None + import_error = e + +title = "Denodo" +name = "denodo" +type = HANDLER_TYPE.DATA +icon_path = "icon.png" + +__all__ = [ + "Handler", + "version", + "name", + "type", + "title", + "description", + "connection_args", + "connection_args_example", + "import_error", + "icon_path", +] diff --git a/mindsdb/integrations/handlers/denodo_handler/connection_args.py b/mindsdb/integrations/handlers/denodo_handler/connection_args.py new file mode 100644 index 00000000000..384bfa677c7 --- /dev/null +++ b/mindsdb/integrations/handlers/denodo_handler/connection_args.py @@ -0,0 +1,46 @@ +from collections import OrderedDict + +from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE + + +connection_args = OrderedDict( + user={ + "type": ARG_TYPE.STR, + "description": "The user name used to authenticate with the Denodo server.", + "required": True, + "label": "User", + }, + password={ + "type": ARG_TYPE.PWD, + "description": "The password to authenticate the user with the Denodo server.", + "required": True, + "label": "Password", + "secret": True, + }, + database={ + "type": ARG_TYPE.STR, + "description": "The database name to use when connecting with the Denodo server.", + "required": True, + "label": "Database", + }, + host={ + "type": ARG_TYPE.STR, + "description": "The host name or IP address of the Denodo server.", + "required": True, + "label": "Host", + }, + port={ + "type": ARG_TYPE.INT, + "description": "The TCP/IP port of the Denodo server. Must be an integer.", + "required": True, + "label": "Port", + }, +) + +connection_args_example = { + "host": "localhost", + "port": 9996, + "user": "admin", + "password": "password", + "database": "database", +} diff --git a/mindsdb/integrations/handlers/denodo_handler/denodo_handler.py b/mindsdb/integrations/handlers/denodo_handler/denodo_handler.py new file mode 100644 index 00000000000..4c0c2643aa5 --- /dev/null +++ b/mindsdb/integrations/handlers/denodo_handler/denodo_handler.py @@ -0,0 +1,155 @@ +import pandas as pd +from typing import Optional +import psycopg2 as dbdriver +from psycopg2 import OperationalError, InterfaceError, ProgrammingError + +from mindsdb_sql_parser import parse_sql +from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender +from mindsdb_sql_parser.ast.base import ASTNode + +from mindsdb.utilities import log +from mindsdb.integrations.libs.base import DatabaseHandler +from mindsdb.integrations.libs.response import ( + HandlerStatusResponse as StatusResponse, + HandlerResponse as Response, + RESPONSE_TYPE, +) + +logger = log.getLogger(__name__) + + +class DenodoHandler(DatabaseHandler): + """ + This handler handles connection and execution of the Denodo statements. + """ + + name = "denodo" + + def __init__(self, name: str, **kwargs) -> None: + super().__init__(name) + self.parser = parse_sql + self.dialect = "mysql" + self.connection_data = kwargs.get("connection_data", {}) + self.database = self.connection_data.get("database") + + self.connection = None + + def connect(self) -> Optional[dbdriver.extensions.connection]: + """ + Connect to the Denodo database using the connection data provided. + + Returns: + Optional[dbdriver.extensions.connection]: A connection object if successful, None otherwise. + """ + if self.connection is not None: + return self.connection + + try: + self.connection = dbdriver.connect( + host=self.connection_data.get("host"), + port=self.connection_data.get("port"), + user=self.connection_data.get("user"), + password=self.connection_data.get("password"), + database=self.connection_data.get("database"), + ) + return self.connection + except (OperationalError, InterfaceError) as e: + logger.error(f"Error connecting to Denodo: {str(e)}") + raise ConnectionError(f"Failed to connect to Denodo: {str(e)}") + + def disconnect(self) -> None: + """ + Safely close the database connection. + """ + if self.connection is not None: + self.connection.close() + self.connection = None + + def _validate_connection(self) -> None: + """ + Check if the connection is still active and reconnect if necessary. + """ + if not self.connection: + self.connect() + try: + with self.connection.cursor() as cursor: + cursor.execute("SELECT 1") + except (OperationalError, InterfaceError): + self.connect() + + def check_connection(self) -> StatusResponse: + """ + Check if the connection is still active. + + Returns: + StatusResponse: A response object containing the status of the connection. + """ + try: + self._validate_connection() + return StatusResponse(True) + except Exception as e: + logger.error(f"Connection check failed: {str(e)}") + return StatusResponse(False, str(e)) + + def native_query(self, query: str) -> Response: + """ + Executes a VQL query on the Denodo database and returns the result. + + Args: + query (str): The VQL query to be executed. + + Returns: + Response: A response object containing the result of the query or an error message. + """ + self._validate_connection() + + try: + connection = self.connect() + with connection.cursor() as cur: + cur.execute(query) + if cur.description is not None: + columns = [desc[0] for desc in cur.description] + result = cur.fetchall() + response = Response( + resp_type=RESPONSE_TYPE.TABLE, + query=query, + data_frame=pd.DataFrame(result, columns=columns), + ) + else: + response = Response(RESPONSE_TYPE.OK) + + except (OperationalError, InterfaceError, ProgrammingError) as e: + logger.error(f"Error running query: {query} on {self.database}!") + response = Response(RESPONSE_TYPE.ERROR, error_message=str(e)) + + return response + + def query(self, query: ASTNode) -> Response: + """ + Execute a SQL query and return results. + """ + renderer = SqlalchemyRender(self.dialect) + query_str = renderer.get_string(query, with_failback=True) + return self.native_query(query_str) + + def get_tables(self) -> Response: + """ + Get all tables in current schema. + """ + query = "SELECT name FROM GET_VIEWS();" + result = self.native_query(query) + df = result.data_frame.rename(columns={"name": "TABLE_NAME"}) + result.data_frame = df + return result + + def get_columns(self, table_name: str) -> Response: + """ + Get columns for specified table using parameterized query. + """ + query = f"CALL GET_VIEW_COLUMNS('{self.database}', '{table_name}');" + result = self.native_query(query) + df = result.data_frame.rename( + columns={"column_name": "COLUMN_NAME", "data_type": "DATA_TYPE"} + ) + result.data_frame = df + return result diff --git a/mindsdb/integrations/handlers/denodo_handler/icon.png b/mindsdb/integrations/handlers/denodo_handler/icon.png new file mode 100644 index 0000000000000000000000000000000000000000..f7d20018201fa60e5a028c0f3c22f0df528fad21 GIT binary patch literal 3331 zcmb_fX*3jU8y=~YEMX`VB0{!g>{85FVk~7}vn7V{3K_-{UP~FKnJh6fZ!@-pjEF3g znzDQ)W(XlWGnNcvH3zH_eoxz6*P>%Px(p5>hDeq>{1%F8Xr4FCXm z&CS5K4teJ=I!+uv96xyM;tuKLA7;*B0Kh51{{#o^XWT9Tz*A=qzF~(L`t`gkVzNW* zy5V{hvLGQ!;2F=?;hHPp9__>unS0QqNit{8aUiCGzX+hx?o9L899g5Y`W0go@XV@P`k)yCBo}7|lnk zKc?r;Q$n-53RgtW(FA6FJc zBjV)VzkQ;+EXDpFsuvjk5RO&0s|k^fm*@zX?ah_Nrl%-eCfB7f0C91D(GF$rHJq--t`4^?Jy7UK3%!%Xyuhg9%o%CgJ+d+@l{dfzFnW31oB01wrC zaeOu$0G|_eSP^uWQ4+xfKn|u;J?e!q#bNM>#G?iOiu0_(rNAH?xcOkP?+kmXDU(kN z1H6h7(O)g!dVKg0HrQz8rkkFK{)U$Uq7d7vWR{$-7{pd1x1M=7*MPl@dsDdDW;C@w z^hZla9I_-k4Bl(=s&qSuB{jnt+P+x(7q5Mpa*w zn^{_5u*CDQRk0s5qx|<}<1cT}KJ{=TMJw>_DaZ`FEy~{8r1h}57XqGv(o7)M6_$BT zkY%#FF$ZsSJOi)$Oaq?zQb?OI0SgB2yX1G^__10hv@yH>fs~cJWX*N|r1e$pX)j)f z;|p}VQ-9{A6xRr|N=ybmplX?NdBr9n0yz}ThHG92p6>XsPQ=vlP93e_dxiUpIy`ky z2i|s*p!Or{9z;{ytmE>&dNXK|uBAJmMt(oy5-C;QRrtzN&tAm~GNZIhhCI8$d{FDk z2$8n9Af=gJ)zaB-r_WqVx|z3V8W2mHb_GIk<*3T!411Q_s);mPU#eR`EU>@W>_kpvMdiGkaT;?)DoC7HTLp<+y(wE)UnaM6N|OGwq@J z=UWl_wnpC;jz0x}46>%R9{b13|k%q~tSB#up5~vEQ zmi!qp-{ZVStxSCzqF)6Yc9eXs&3es7`O@787Rt!<^pY9eF+pQqP@qDzHxI1b!n=h% zKQttM^lNE9d2?`TDE;(L-Fw*)Cvs(#3@?t%SEEs~DsKxb(Q;HSTcG=S&<8G#)wYSD z;lfEuX-_9AdD;H2H8;_ZIqM3Ibq|I_Ut8$)c~edS-`exu^`$F3vo8a~sJ8qQ9T&&L zT}?fByhM``JqqYx#b3mDwJKGcwYR;Il-9&L=hWEml#n1-SO$`O0AbC1gi$TVb*t~- zk7Q0(i9Ao!@vKTq93xc>A6OQw29#Uli{$j253XDf8!NRRboc}gN~%7{ymB2)$gQkm zb%1leKCh9C35newb3sUOzPd(8Erc=>WtIHhU&gHLU62SH=m%RPEekB9PMteb60R^e zHJbpbcWZ))O3rGpJj)%6Riu8tHm~zbsTsU`OzRpV)Jay zPh+MKa>QSTy-6!xBJRD-Uax0Fy0m|$#$x$G12fVqlq7r&Kep725~o7`aACXKe>zYB z7etTv+Y98pGu_YEDrTA;9K>P7a!$)&RKa&ccwJVE--ovmNX zO0sS1pW8nX^2%6oiI(hc&|bvHX*YNkiR&!)G*SQ3MNDXdoW;N5R z3w$cc3pX~JArrVjZ4b_rDTqEFM!P)gZo;X;p%r2}gP6OP+@|{%tr-8*$>cMymuNIf z!nK_{hnILDNkEU@#<$#Vyk{6CD#bf_JxXToY>wqd`|Zx+STTNQCDeThE7kQM@=WuM zQvc(>=vxIPE?1A>%)C^$Qm8CB2PlY8Sff54jmILK?&@wE4OvTOxJs3I^iN<*Y59p{&X30^B<8ocq7n@_%s6zi-QbKO@F_pZ^V6 WrPSOC+j-d40L&p);3}g#iT?&+w<6L2 literal 0 HcmV?d00001 diff --git a/mindsdb/integrations/handlers/hubspot_handler/README.md b/mindsdb/integrations/handlers/hubspot_handler/README.md index 032024df64e..2bd51529968 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/README.md +++ b/mindsdb/integrations/handlers/hubspot_handler/README.md @@ -111,6 +111,11 @@ Association tables are read-only and support `SELECT` only. They expose relation The handler provides `SHOW TABLES` and `information_schema.columns` support for all tables. Column statistics are sampled for core CRM and engagement tables. +**Important Notes on Field Values:** +- **Industry codes**: HubSpot uses predefined industry values (e.g., `COMPUTER_SOFTWARE`, `BIOTECHNOLOGY`, `FINANCIAL_SERVICES`). See [HubSpot's industry list](https://knowledge.hubspot.com/properties/hubspots-default-company-properties#industry) for all valid options. +- **Deal stages**: Each HubSpot account has custom pipeline stages. Use the stage IDs from your account (e.g., `presentationscheduled`, `closedwon`, `closedlost`, or numeric IDs like `110382973`). +- **Email validation**: Contact email addresses must be valid email formats (e.g., `user@example.com`). + ## Example Usage ### Basic Connection From 336545293f76af1842dadc61d6d3c0d17caeccd7 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Fri, 13 Mar 2026 12:32:48 +0300 Subject: [PATCH 046/125] Fix hubspot readme 2 (#12289) --- .../integrations/handlers/denodo_handler/denodo_handler.py | 4 +--- mindsdb/integrations/handlers/hubspot_handler/README.md | 5 ----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/mindsdb/integrations/handlers/denodo_handler/denodo_handler.py b/mindsdb/integrations/handlers/denodo_handler/denodo_handler.py index 4c0c2643aa5..14af2d18372 100644 --- a/mindsdb/integrations/handlers/denodo_handler/denodo_handler.py +++ b/mindsdb/integrations/handlers/denodo_handler/denodo_handler.py @@ -148,8 +148,6 @@ def get_columns(self, table_name: str) -> Response: """ query = f"CALL GET_VIEW_COLUMNS('{self.database}', '{table_name}');" result = self.native_query(query) - df = result.data_frame.rename( - columns={"column_name": "COLUMN_NAME", "data_type": "DATA_TYPE"} - ) + df = result.data_frame.rename(columns={"column_name": "COLUMN_NAME", "data_type": "DATA_TYPE"}) result.data_frame = df return result diff --git a/mindsdb/integrations/handlers/hubspot_handler/README.md b/mindsdb/integrations/handlers/hubspot_handler/README.md index 2bd51529968..032024df64e 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/README.md +++ b/mindsdb/integrations/handlers/hubspot_handler/README.md @@ -111,11 +111,6 @@ Association tables are read-only and support `SELECT` only. They expose relation The handler provides `SHOW TABLES` and `information_schema.columns` support for all tables. Column statistics are sampled for core CRM and engagement tables. -**Important Notes on Field Values:** -- **Industry codes**: HubSpot uses predefined industry values (e.g., `COMPUTER_SOFTWARE`, `BIOTECHNOLOGY`, `FINANCIAL_SERVICES`). See [HubSpot's industry list](https://knowledge.hubspot.com/properties/hubspots-default-company-properties#industry) for all valid options. -- **Deal stages**: Each HubSpot account has custom pipeline stages. Use the stage IDs from your account (e.g., `presentationscheduled`, `closedwon`, `closedlost`, or numeric IDs like `110382973`). -- **Email validation**: Contact email addresses must be valid email formats (e.g., `user@example.com`). - ## Example Usage ### Basic Connection From e3daacc6a99241ed42afe9c90e47f30ca8f583c0 Mon Sep 17 00:00:00 2001 From: Parthiv Makwana <75653580+parthiv11@users.noreply.github.com> Date: Fri, 13 Mar 2026 17:22:15 +0530 Subject: [PATCH 047/125] added motherduck support in duckDB (#10385) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michael Olayemi Olawepo <154475559+sejubar@users.noreply.github.com> Co-authored-by: andrew Co-authored-by: April I. Murphy <36110273+aimurphy@users.noreply.github.com> Co-authored-by: Minura Punchihewa <49385643+MinuraPunchihewa@users.noreply.github.com> Co-authored-by: Konstantin Sivakov Co-authored-by: martyna-mindsdb <109554435+martyna-mindsdb@users.noreply.github.com> Co-authored-by: Sebastián Tobón Hernández --- .../handlers/duckdb_handler/README.md | 49 +++++++++++++------ .../duckdb_handler/connection_args.py | 22 ++++++--- .../handlers/duckdb_handler/duckdb_handler.py | 37 +++++++------- 3 files changed, 70 insertions(+), 38 deletions(-) diff --git a/mindsdb/integrations/handlers/duckdb_handler/README.md b/mindsdb/integrations/handlers/duckdb_handler/README.md index 54c1040a42c..5fa9125b940 100644 --- a/mindsdb/integrations/handlers/duckdb_handler/README.md +++ b/mindsdb/integrations/handlers/duckdb_handler/README.md @@ -1,41 +1,62 @@ -# DuckDB Handler +# DuckDB Handler This is the implementation of the DuckDB handler for MindsDB. ## DuckDB DuckDB is an open-source analytical database system. DuckDB is designed for fast execution of analytical queries. -There are no external dependencies and the DBMS runs completly embedded within a host process, similar to SQLite. +There are no external dependencies, and the DBMS runs completely embedded within a host process, similar to SQLite. DuckDB provides a rich SQL dialect with support for complex queries with transactional guarantees (ACID). -## Implementation -This handler was implemented using the `duckdb` python client library. +## Implementation +This handler was implemented using the `duckdb` Python client library. ### DuckDB version -The DuckDB handler is currently using the `0.7.1.dev187` pre-relase version of the python client library. In case of issues, make sure your DuckDB database is compatible with this version. See the DuckDB handler [requirements.txt](requirements.txt) for details. - +The DuckDB handler is currently using the `1.1.3` release version of the Python client library. In case of issues, make sure your DuckDB or MotherDuck database is compatible with this version. See the DuckDB handler [requirements.txt](requirements.txt) for details. The required arguments to establish a connection are: -* `database`: the name of the DuckDB database file. May also be set to `:memory:`, which will create an in-memory database. +* `database`: the name of the DuckDB or MotherDuck database file. + - Set to `:memory:` to create an in-memory database. + - For MotherDuck, specify the database and motherduck_token. -The optional arguments are: +Additional optional arguments include: +* `motherduck_token`: a token to authenticate with MotherDuck. * `read_only`: a flag that specifies if the connection should be made in read-only mode. -This is required if multiple processes want to access the same database file at the same time. - + - This is required if multiple processes want to access the same database file simultaneously. ## Usage -In order to make use of this handler and connect to a DuckDB database in MindsDB, the following syntax can be used: +To connect to a DuckDB or MotherDuck database in MindsDB, the following syntax can be used: +### DuckDB Example ```sql CREATE DATABASE duckdb_datasource WITH engine='duckdb', parameters={ - "database":"db.duckdb" + "database": "db.duckdb" }; ``` -Now, you can use this established connection to query your database as follows: +### MotherDuck Example +```sql +CREATE DATABASE md_datasource +WITH +engine='duckdb', +parameters={ + "database": "sample_data", + "motherduck_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." +}; +``` + +Once the connection is established, you can query the database: + ```sql SELECT * FROM duckdb_datasource.my_table; -``` \ No newline at end of file +``` + +For MotherDuck: +```sql +SELECT * FROM md_datasource.movies; +``` + +By leveraging these features, MindsDB provides powerful integrations with DuckDB and MotherDuck for scalable analytics. \ No newline at end of file diff --git a/mindsdb/integrations/handlers/duckdb_handler/connection_args.py b/mindsdb/integrations/handlers/duckdb_handler/connection_args.py index e5a372f9e88..4d9591e5eb6 100644 --- a/mindsdb/integrations/handlers/duckdb_handler/connection_args.py +++ b/mindsdb/integrations/handlers/duckdb_handler/connection_args.py @@ -2,16 +2,26 @@ from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE - connection_args = OrderedDict( database={ - 'type': ARG_TYPE.STR, - 'description': 'The database file to read and write from. The special value :memory: (default) can be used to create an in-memory database.', + "type": ARG_TYPE.STR, + "description": ( + "The database file to read and write from. The special value :memory: (default) " + "can be used to create an in-memory database." + ), + }, + motherduck_token={ + "type": ARG_TYPE.STR, + "description": "Motherduck access token if want to connect motherduck database.", }, read_only={ - 'type': ARG_TYPE.BOOL, - 'description': 'A flag that specifies if the connection should be made in read-only mode.', + "type": ARG_TYPE.BOOL, + "description": ("A flag that specifies if the connection should be made in read-only mode."), }, ) -connection_args_example = OrderedDict(database='db.duckdb', read_only=True) +connection_args_example = OrderedDict( + database="sample_data", + read_only=True, + motherduck_token="ey...enKoT.SsEcCa......", +) diff --git a/mindsdb/integrations/handlers/duckdb_handler/duckdb_handler.py b/mindsdb/integrations/handlers/duckdb_handler/duckdb_handler.py index 7ae5423859c..bc407ef0575 100644 --- a/mindsdb/integrations/handlers/duckdb_handler/duckdb_handler.py +++ b/mindsdb/integrations/handlers/duckdb_handler/duckdb_handler.py @@ -19,14 +19,14 @@ class DuckDBHandler(DatabaseHandler): """This handler handles connection and execution of the DuckDB statements.""" - name = 'duckdb' + name = "duckdb" def __init__(self, name: str, **kwargs): super().__init__(name) self.parser = parse_sql - self.dialect = 'postgresql' - self.connection_data = kwargs.get('connection_data') - self.renderer = SqlalchemyRender('postgres') + self.dialect = "postgresql" + self.connection_data = kwargs.get("connection_data") + self.renderer = SqlalchemyRender("postgres") self.connection = None self.is_connected = False @@ -44,10 +44,17 @@ def connect(self) -> DuckDBPyConnection: if self.is_connected is True: return self.connection + motherduck_token = self.connection_data.get("motherduck_token") + if motherduck_token: + database = ( + f"md:{self.connection_data.get('database')}?motherduck_token={motherduck_token}&attach_mode=single" + ) + else: + database = self.connection_data.get("database") args = { - 'database': self.connection_data.get('database'), - 'read_only': self.connection_data.get('read_only'), + "database": database, + "read_only": self.connection_data.get("read_only"), } self.connection = duckdb.connect(**args) @@ -78,9 +85,7 @@ def check_connection(self) -> StatusResponse: self.connect() response.success = True except Exception as e: - logger.error( - f'Error connecting to DuckDB {self.connection_data["database"]}, {e}!' - ) + logger.error(f"Error connecting to DuckDB {self.connection_data['database']}, {e}!") response.error_message = str(e) finally: if response.success is True and need_to_close: @@ -111,17 +116,13 @@ def native_query(self, query: str) -> Response: if result: response = Response( RESPONSE_TYPE.TABLE, - data_frame=pd.DataFrame( - result, columns=[x[0] for x in cursor.description] - ), + data_frame=pd.DataFrame(result, columns=[x[0] for x in cursor.description]), ) else: connection.commit() response = Response(RESPONSE_TYPE.OK) except Exception as e: - logger.error( - f'Error running query: {query} on {self.connection_data["database"]}!' - ) + logger.error(f"Error running query: {query} on {self.connection_data['database']}!") response = Response(RESPONSE_TYPE.ERROR, error_message=str(e)) cursor.close() @@ -150,10 +151,10 @@ def get_tables(self) -> Response: Response: Names of the tables in the database. """ - q = 'SHOW TABLES;' + q = "SHOW TABLES;" result = self.native_query(q) df = result.data_frame - result.data_frame = df.rename(columns={df.columns[0]: 'table_name'}) + result.data_frame = df.rename(columns={df.columns[0]: "table_name"}) return result def get_columns(self, table_name: str) -> Response: @@ -166,5 +167,5 @@ def get_columns(self, table_name: str) -> Response: Response: Details of the table. """ - query = f'DESCRIBE {table_name};' + query = f"DESCRIBE {table_name};" return self.native_query(query) From 5e1e09256df16523e6f973c9c75346ecc5a01b24 Mon Sep 17 00:00:00 2001 From: Andrey Date: Mon, 16 Mar 2026 15:58:26 +0300 Subject: [PATCH 048/125] Support of xls files (#12282) --- mindsdb/integrations/utilities/files/file_reader.py | 13 ++++++++++++- requirements/requirements.txt | 1 + tests/scripts/check_requirements.py | 1 + 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mindsdb/integrations/utilities/files/file_reader.py b/mindsdb/integrations/utilities/files/file_reader.py index 460ddd5137e..f3d013bf200 100644 --- a/mindsdb/integrations/utilities/files/file_reader.py +++ b/mindsdb/integrations/utilities/files/file_reader.py @@ -37,6 +37,7 @@ class _SINGLE_PAGE_FORMAT: @dataclass(frozen=True, slots=True) class _MULTI_PAGE_FORMAT: XLSX: str = "xlsx" + XLS: str = "xls" MULTI_PAGE_FORMAT = _MULTI_PAGE_FORMAT() @@ -155,9 +156,10 @@ def get_format_by_content(self): if file_type is not None: if file_type.mime in { "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - "application/vnd.ms-excel", }: return MULTI_PAGE_FORMAT.XLSX + if file_type.mime == "application/vnd.ms-excel": + return MULTI_PAGE_FORMAT.XLS if file_type.mime == "application/pdf": return SINGLE_PAGE_FORMAT.PDF @@ -381,3 +383,12 @@ def read_xlsx( else: df = pd.read_excel(xls, sheet_name=page_name) yield page_name, df + + @staticmethod + def read_xls( + file_obj: BytesIO, + page_name: str | None = None, + only_names: bool = False, + **kwargs, + ): + return FileReader.read_xlsx(file_obj, page_name=page_name, only_names=only_names, **kwargs) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index ee0a73eabb5..231f8e681c0 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -41,6 +41,7 @@ pymupdf==1.25.2 filetype charset-normalizer openpyxl # used by pandas to read txt and xlsx files +xlrd>=2.0.1 # used by pandas to read legacy .xls files aipdf==0.0.7.0 pyarrow<=19.0.0 # used by pandas to read feather files in Files handler orjson==3.11.3 diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index 800a7fa1d4c..be5c2172baf 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -104,6 +104,7 @@ def get_requirements_with_DEP002(path): "langchain-experimental", "lxml", "openpyxl", + "xlrd", "onnxruntime", "litellm", "numba", # required in a few files for the hierarchicalforecast. Otherwise, uv may install an old version. From 5f2a5b6c34edc98e9f93f499e6f40766d6b4b48f Mon Sep 17 00:00:00 2001 From: ianu82 <86010258+ianu82@users.noreply.github.com> Date: Tue, 17 Mar 2026 09:53:38 +0000 Subject: [PATCH 049/125] Fix wrong JSON exception type in Postgres subscribe path (#12254) --- .../integrations/handlers/postgres_handler/postgres_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py b/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py index 64afe0913aa..249c31a842f 100644 --- a/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +++ b/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py @@ -616,7 +616,7 @@ def subscribe(self, stop_event, callback, table_name, columns=None, **kwargs): def process_event(event): try: row = json.loads(event.payload) - except json.JSONDecoder: + except json.JSONDecodeError: return # check column in input data From 2a76277f67ea9ba85c3bf9a1564679e350184f19 Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Thu, 19 Mar 2026 12:31:53 +0000 Subject: [PATCH 050/125] Fix MySQL get_columns schema scoping (#12260) --- .../integrations/handlers/mysql_handler/mysql_handler.py | 3 ++- tests/unit/handlers/test_mysql.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py b/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py index 73d0450954d..86882d03563 100644 --- a/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py +++ b/mindsdb/integrations/handlers/mysql_handler/mysql_handler.py @@ -382,7 +382,8 @@ def get_columns(self, table_name: str) -> Response: from information_schema.columns where - table_name = '{table_name}'; + table_name = '{table_name}' + and table_schema = DATABASE(); """ result = self.native_query(q) result.to_columns_table_response(map_type_fn=_map_type) diff --git a/tests/unit/handlers/test_mysql.py b/tests/unit/handlers/test_mysql.py index 065cc66896f..a506e0ba844 100644 --- a/tests/unit/handlers/test_mysql.py +++ b/tests/unit/handlers/test_mysql.py @@ -73,7 +73,8 @@ def get_columns_query(self): from information_schema.columns where - table_name = '{self.mock_table}'; + table_name = '{self.mock_table}' + and table_schema = DATABASE(); """ def create_handler(self): @@ -457,7 +458,8 @@ def test_get_columns(self): from information_schema.columns where - table_name = '{table_name}'; + table_name = '{table_name}' + and table_schema = DATABASE(); """ self.assertEqual(call_args, expected_sql) self.assertEqual(response, expected_response) From 0cdc14d85b15838208741ebeec0cb57de938eaca Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Thu, 19 Mar 2026 15:34:59 +0300 Subject: [PATCH 051/125] Fix MariaDB test (#12305) --- tests/unit/handlers/test_mariadb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/handlers/test_mariadb.py b/tests/unit/handlers/test_mariadb.py index ecc5fdbfd46..9d75a8dce72 100644 --- a/tests/unit/handlers/test_mariadb.py +++ b/tests/unit/handlers/test_mariadb.py @@ -59,7 +59,8 @@ def get_columns_query(self): from information_schema.columns where - table_name = '{self.mock_table}'; + table_name = '{self.mock_table}' + and table_schema = DATABASE(); """ def create_handler(self): From 706868721d70677f1487097dceb602e0458163c2 Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:40:33 +0000 Subject: [PATCH 052/125] Fix malformed BigQuery PK metadata SQL (#12252) --- .../integrations/handlers/bigquery_handler/bigquery_handler.py | 2 +- tests/unit/handlers/test_bigquery.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py b/mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py index 6e5cc215dad..6c08d2aecb8 100644 --- a/mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py +++ b/mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py @@ -396,7 +396,7 @@ def meta_get_primary_keys(self, table_names: Optional[list] = None) -> Response: tc.table_name, kcu.column_name, kcu.ordinal_position, - tc.constraint_name, + tc.constraint_name FROM `{self.connection_data["project_id"]}.{self.connection_data["dataset"]}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS` AS tc JOIN diff --git a/tests/unit/handlers/test_bigquery.py b/tests/unit/handlers/test_bigquery.py index 2c48c87428a..026a90c1146 100644 --- a/tests/unit/handlers/test_bigquery.py +++ b/tests/unit/handlers/test_bigquery.py @@ -202,6 +202,7 @@ def test_meta_get_primary_keys_filters(self): query = self.handler.native_query.call_args[0][0] self.assertIn("AND tc.table_name IN ('orders')", query) + self.assertNotIn("tc.constraint_name,", query) def test_meta_get_foreign_keys_filters(self): self.handler.native_query = MagicMock(return_value=TableResponse(data=pd.DataFrame())) From 2801ba3d448ffbd466e782b940a1404800e1f2be Mon Sep 17 00:00:00 2001 From: Andrey Date: Thu, 19 Mar 2026 19:44:52 +0300 Subject: [PATCH 053/125] FAISS: Support several vector tables (#12246) --- mindsdb/api/executor/sql_query/sql_query.py | 14 + .../steps/fetch_dataframe_partition.py | 12 +- .../duckdb_faiss_handler.py | 577 ++++++------------ .../duckdb_faiss_table.py | 362 +++++++++++ .../duckdb_faiss_handler/faiss_index.py | 164 ++++- .../duckdb_faiss_handler/requirements.txt | 3 +- .../test_faiss_handler.py | 4 +- .../interfaces/knowledge_base/controller.py | 23 +- 8 files changed, 748 insertions(+), 411 deletions(-) create mode 100644 mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py diff --git a/mindsdb/api/executor/sql_query/sql_query.py b/mindsdb/api/executor/sql_query/sql_query.py index 763db9cceac..4e3c2a55097 100644 --- a/mindsdb/api/executor/sql_query/sql_query.py +++ b/mindsdb/api/executor/sql_query/sql_query.py @@ -20,6 +20,8 @@ ApplyTimeseriesPredictorStep, ApplyPredictorRowStep, ApplyPredictorStep, + InsertToTable, + FetchDataframeStepPartition, ) from mindsdb.api.executor.planner.exceptions import PlanningException @@ -276,6 +278,9 @@ def execute_query(self): ) if self.planner.plan.is_async and ctx.task_id is None: + # release KB locks before inserting in background + self.release_kb_lock(steps) + # add to task self.run_query.add_to_task() # return query info @@ -340,5 +345,14 @@ def execute_step(self, step, steps_data=None): return handler(self, steps_data=steps_data).call(step) + def release_kb_lock(self, steps): + # find knowledge bases that are used as tables to insert. + # then release locks of vector for these knowledge bases + for step in steps: + if isinstance(step, InsertToTable): + self.session.kb_controller.release_lock(step.table, project_name=self.database) + if isinstance(step, FetchDataframeStepPartition): + self.release_kb_lock(step.steps) + SQLQuery.register_steps() diff --git a/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py b/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py index 77f53fc8bd3..9775a2867e9 100644 --- a/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py +++ b/mindsdb/api/executor/sql_query/steps/fetch_dataframe_partition.py @@ -91,10 +91,14 @@ def call(self, step: FetchDataframeStepPartition) -> ResultSet: use_threads = False on_error = step.params.get("error", "raise") - if use_threads: - return self.fetch_threads(run_query, query, thread_count=thread_count, on_error=on_error) - else: - return self.fetch_iterate(run_query, query, on_error=on_error) + try: + if use_threads: + return self.fetch_threads(run_query, query, thread_count=thread_count, on_error=on_error) + else: + return self.fetch_iterate(run_query, query, on_error=on_error) + finally: + # release KB locks after inserting in background + self.sql_query.release_kb_lock(self.substeps) def repeat_till_reach_limit(self, step, limit): first_table_limit = limit * 2 diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py index fc413f14d68..1cabd09ae79 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py @@ -1,21 +1,15 @@ import os -from typing import List +import re +import shutil +import threading +import time +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import List, Iterator import pandas as pd -import orjson -import duckdb -from mindsdb_sql_parser.ast import ( - Select, - Delete, - Identifier, - BinaryOperation, - Constant, - NullConstant, - Star, - Tuple as AstTuple, - Function, - TypeCast, -) + from mindsdb.integrations.libs.response import ( RESPONSE_TYPE, @@ -25,7 +19,6 @@ from mindsdb.integrations.libs.vectordatabase_handler import ( FilterCondition, VectorStoreHandler, - FilterOperator, ) from mindsdb.integrations.libs.keyword_search_base import KeywordSearchBase from mindsdb.integrations.utilities.sql_utils import KeywordSearchArgs @@ -33,11 +26,21 @@ from mindsdb.utilities import log from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender -from .faiss_index import FaissIVFIndex +from .duckdb_faiss_table import DuckDBFaissTable logger = log.getLogger(__name__) +TABLE_CACHE_TTL_SECONDS = 60 + + +@dataclass +class TableCacheEntry: + table: DuckDBFaissTable + last_used_ts: float + in_use_count: int = 0 + + class DuckDBFaissHandler(VectorStoreHandler, KeywordSearchBase): """This handler handles connection and execution of DuckDB with Faiss vector indexing.""" @@ -61,173 +64,206 @@ def __init__(self, name: str, **kwargs): raise ValueError(f"Persist directory {self.persist_directory} does not exist") else: # Use default handler storage - self.persist_directory = self.handler_storage.folder_get("data") + self.persist_directory = self.handler_storage.folder_get("") self._use_handler_storage = True - # DuckDB connection - self.connection = None - self.is_connected = False - - # Initialize storage paths - self.duckdb_path = os.path.join(self.persist_directory, "duckdb.db") - self.faiss_index_path = self.persist_directory - self.connect() - - # check keyword index - self.is_kw_index_enabled = False - with self.connection.cursor() as cur: - # check index exists - df = cur.execute( - "SELECT * FROM information_schema.schemata WHERE schema_name = 'fts_main_meta_data'" - ).fetchdf() - if len(df) > 0: - self.is_kw_index_enabled = True - - def connect(self) -> duckdb.DuckDBPyConnection: - """Connect to DuckDB database.""" - if self.is_connected: - return self.connection + Path(self.persist_directory).mkdir(parents=True, exist_ok=True) - try: - self.connection = duckdb.connect(self.duckdb_path) - self.faiss_index = FaissIVFIndex(self.faiss_index_path, self.connection_data) - self.is_connected = True + self.tables_cache = {} + self.tables_cache_lock = threading.Lock() - logger.info("Connected to DuckDB database") - return self.connection + def connect(self): + """ + Handler readiness check. + Must not open long-lived DuckDB/FAISS resources; tables are opened per operation. + """ - except Exception as e: - logger.error(f"Error connecting to DuckDB: {e}") - raise + self.is_connected = True + return True def disconnect(self): - """Close DuckDB connection.""" - if self.is_connected and self.connection: - self.connection.close() - self.faiss_index.close() - self.is_connected = False + with self.tables_cache_lock: + for item in self.tables_cache.values(): + item.table.close() - def create_table(self, table_name: str, if_not_exists=True): - with self.connection.cursor() as cur: - cur.execute("CREATE SEQUENCE IF NOT EXISTS faiss_id_sequence START 1") - - cur.execute(""" - CREATE TABLE IF NOT EXISTS meta_data ( - faiss_id INTEGER PRIMARY KEY DEFAULT nextval('faiss_id_sequence'), -- id in FAISS index - id TEXT NOT NULL, -- chunk id - content TEXT, - metadata JSON - ) - """) - - def drop_table(self, table_name: str, if_exists=True): - """Drop table from both DuckDB and Faiss.""" - with self.connection.cursor() as cur: - drop_sql = f"DROP TABLE {'IF EXISTS' if if_exists else ''} meta_data" - cur.execute(drop_sql) + self.tables_cache = {} - if self.faiss_index: - self.faiss_index.drop() - - def create_index(self, table_name: str, type: str = "ivf_file", nlist: int = None, train_count: int = None): - if type not in ("ivf", "ivf_file"): - raise NotImplementedError("Only ivf or ivf_file indexes are supported") - - self.faiss_index.create_index(type, nlist=nlist, train_count=train_count) + def check_connection(self) -> Response: + """Check the connection to the database.""" + try: + if not self.is_connected: + self.connect() + return StatusResponse(RESPONSE_TYPE.OK) + except Exception as e: + logger.error(f"Connection check failed: {e}") + return StatusResponse(RESPONSE_TYPE.ERROR, error_message=str(e)) - def insert(self, table_name: str, data: pd.DataFrame): - """Insert data into both DuckDB and Faiss.""" + def __del__(self): + """Cleanup on deletion.""" + self.disconnect() + + # -- manage tables -- + + @staticmethod + def _validate_table_name(table_name: str) -> None: + if table_name in (".", ".."): + raise ValueError("Invalid table_name") + if "/" in table_name or "\\" in table_name: + raise ValueError("table_name must not contain path separators") + if not re.fullmatch(r"[A-Za-z0-9_-]+", table_name): + raise ValueError( + "Invalid table_name: only letters, digits, '_' and '-' are allowed (no spaces, dots, or other symbols)" + ) - if self.is_kw_index_enabled: - # drop index, it will be created before a first keyword search - self.drop_kw_index() + def get_table_dir(self, table_name: str) -> Path: + """ + Get folder for a table name + Prevent path traversal by requiring the resolved path to stay within persist_directory. + """ + root = Path(self.persist_directory).resolve() + table_dir = (Path(self.persist_directory) / table_name).resolve() + if table_dir == root or root not in table_dir.parents: + raise ValueError("Invalid table_name path") + return table_dir + + def _close_cached_table(self, table_name: str) -> None: + entry = self.tables_cache.pop(table_name, None) + if entry is None: + return + try: + entry.table.close() + except Exception: + logger.exception("Failed to close cached table '%s'", table_name) - with self.connection.cursor() as cur: - df_ids = cur.execute(""" - insert into meta_data (id, content, metadata) ( - select id, content, metadata from data - ) - RETURNING faiss_id, id - """).fetchdf() + def _close_old_tables_cache(self): + """ + Close stale cached tables that have not been used for more than TTL. + Tables that are currently in use are never closed by pruning. + """ + if not self.tables_cache: + return + + with self.tables_cache_lock: + now_ts = time.time() + to_close: List[str] = [] + for table_name, entry in self.tables_cache.items(): + if entry.in_use_count > 0: + continue + if now_ts - entry.last_used_ts > TABLE_CACHE_TTL_SECONDS: + to_close.append(table_name) + + for table_name in to_close: + self._close_cached_table(table_name) + + @contextmanager + def open_table(self, table_name: str) -> Iterator[DuckDBFaissTable]: + """ + Open DuckDB and Faiss resources scoped to one vector table. + Must always be closed after use to avoid long-lived locks / RAM usage. - data = data.merge(df_ids, on="id") + If `use_cache=True` and `table.cache_required` is True, the opened table is cached + in `self.tables_cache` and re-used across calls. Cached tables are pruned if they + haven't been used for more than TABLE_CACHE_TTL_SECONDS. + """ + table_dir = self.get_table_dir(table_name) + if not table_dir.exists(): + raise ValueError(f"Table '{table_name}' does not exist") - vectors = data["embeddings"] - ids = data["faiss_id"] + with self.tables_cache_lock: + entry = self.tables_cache.get(table_name) - self.faiss_index.insert(list(vectors), list(ids)) - self._sync() + if entry is not None: + table = entry.table + else: + table = DuckDBFaissTable(table_name=table_name, table_dir=table_dir, handler=self).open() - # def upsert(self, table_name: str, data: pd.DataFrame): - # # delete by ids and insert - # ids = list(data['id']) - # self.delete(table_name, [FilterCondition(column='id', op=FilterOperator.IN, value=ids)]) - # self.insert(table_name, data) + if table.cache_required: + entry = TableCacheEntry(table=table, last_used_ts=time.time()) + self.tables_cache[table_name] = entry - def select( - self, - table_name: str, - columns: List[str] = None, - conditions: List[FilterCondition] = None, - offset: int = None, - limit: int = None, - ) -> pd.DataFrame: - """Select data with hybrid search logic.""" - - vector_filter = None - meta_filters = [] - if conditions is None: - conditions = [] - for condition in conditions: - if condition.column == "embeddings": - vector_filter = condition + try: + if entry: + with self.tables_cache_lock: + entry.in_use_count += 1 + + yield table + finally: + if entry: + entry.in_use_count -= 1 + entry.last_used_ts = time.time() else: - meta_filters.append(condition) + table.close() - if vector_filter is None: - # If only metadata in filter: - # query duckdb only - return self._select_from_metadata(meta_filters=meta_filters, limit=limit).drop("faiss_id", axis=1) + self._close_old_tables_cache() - # vector_filter is not None - if not meta_filters: - # If only content in filter: query faiss and attach to metadata - return self._select_with_vector(vector_filter=vector_filter, limit=limit) + def create_table(self, table_name: str, if_not_exists=True): + self._validate_table_name(table_name) + table_dir = self.get_table_dir(table_name) + if table_dir.exists() and not if_not_exists: + raise ValueError(f"Vector table '{table_name}' already exists") + table_dir.mkdir(parents=True, exist_ok=True) + + with self.open_table(table_name) as table: + with table.connection.cursor() as cur: + cur.execute("CREATE SEQUENCE IF NOT EXISTS faiss_id_sequence START 1") + cur.execute(""" + CREATE TABLE IF NOT EXISTS meta_data ( + faiss_id INTEGER PRIMARY KEY DEFAULT nextval('faiss_id_sequence'), -- id in FAISS index + id TEXT NOT NULL, -- chunk id + content TEXT, + metadata JSON + ) + """) - """ - If metadata + content: - Query faiss, use limit = 1000 - Query duckdb with `id in (...)` - If count of results is less than input LIMIT value - Repeat the search with increased limit value - Limit value for step = 1000 * 5^i (1000, 2000, 25000, 125000 …) - """ + def drop_table(self, table_name: str, if_exists=True): + """Drop table from both DuckDB and Faiss.""" + table_dir = self.get_table_dir(table_name) - df = pd.DataFrame() + if not table_dir.exists(): + if if_exists: + return + raise ValueError(f"Vector table '{table_name}' does not exist") - total_size = self.get_total_size() + with self.tables_cache_lock: + self._close_cached_table(table_name) - for i in range(10): - batch_size = 1000 * 5**i + shutil.rmtree(table_dir, ignore_errors=False) - # TODO implement reverse search: - # if batch_size > 25% of db: search metadata first and then in faiss by list of ids + if self._use_handler_storage: + self.handler_storage.folder_sync(table_name) - df = self._select_with_vector(vector_filter=vector_filter, meta_filters=meta_filters, limit=batch_size) - if batch_size >= total_size or len(df) >= limit: - break + def get_tables(self) -> Response: + """Get list of tables.""" + rows = [] + root = Path(self.persist_directory) + if root.exists(): + for item in root.iterdir(): + if not item.is_dir(): + continue + rows.append({"table_name": item.name}) + df = pd.DataFrame(rows, columns=["table_name"]) + return Response(RESPONSE_TYPE.TABLE, data_frame=df) - return df[:limit] + # -- table methods -- - def create_kw_index(self): - with self.connection.cursor() as cur: - cur.execute("PRAGMA create_fts_index('meta_data', 'id', 'content')") - self.is_kw_index_enabled = True + def create_index(self, table_name: str, type: str = "ivf_file", nlist: int = None, train_count: int = None): + with self.open_table(table_name) as table: + table.create_index(type=type, nlist=nlist, train_count=train_count) - def drop_kw_index(self): - with self.connection.cursor() as cur: - cur.execute("pragma drop_fts_index('meta_data')") - self.is_kw_index_enabled = False + def insert(self, table_name: str, data: pd.DataFrame): + with self.open_table(table_name) as table: + table.insert(data) + + def select( + self, + table_name: str, + columns: List[str] = None, + conditions: List[FilterCondition] = None, + offset: int = None, + limit: int = None, + ) -> pd.DataFrame: + with self.open_table(table_name) as table: + return table.select(conditions=conditions, offset=offset, limit=limit) def keyword_select( self, @@ -238,229 +274,20 @@ def keyword_select( limit: int = None, keyword_search_args: KeywordSearchArgs = None, ) -> pd.DataFrame: - if not self.is_kw_index_enabled: - # keyword search is used for first time: create index - self.create_kw_index() - - with self.connection.cursor() as cur: - where_clause = self._translate_filters(conditions) - - score = Function( - namespace="fts_main_meta_data", - op="match_bm25", - args=[ - Identifier("id"), - Constant(keyword_search_args.query), - BinaryOperation(op=":=", args=[Identifier("fields"), Constant(keyword_search_args.column)]), - ], + with self.open_table(table_name) as table: + return table.keyword_select( + conditions=conditions, + offset=offset, + limit=limit, + keyword_search_args=keyword_search_args, ) - no_emtpy_score = BinaryOperation(op="is not", args=[score, NullConstant()]) - if where_clause: - where_clause = BinaryOperation(op="and", args=[where_clause, no_emtpy_score]) - else: - where_clause = no_emtpy_score - - query = Select( - targets=[Star(), BinaryOperation(op="-", args=[Constant(1), score], alias=Identifier("distance"))], - from_table=Identifier("meta_data"), - where=where_clause, - ) - - sql = self.renderer.get_string(query, with_failback=True) - cur.execute(sql) - df = cur.fetchdf() - df["metadata"] = df["metadata"].apply(orjson.loads) - return df - - def get_total_size(self): - with self.connection.cursor() as cur: - cur.execute("select count(1) size from meta_data") - df = cur.fetchdf() - return df["size"].iloc[0] - - def _select_with_vector(self, vector_filter: FilterCondition, meta_filters=None, limit=None) -> pd.DataFrame: - embedding = vector_filter.value - if isinstance(embedding, str): - embedding = orjson.loads(embedding) - - distances, faiss_ids = self.faiss_index.search(embedding, limit or 100) - - # Fetch full data from DuckDB - if len(faiss_ids) > 0: - # ids = [str(idx) for idx in faiss_ids] - meta_df = self._select_from_metadata(faiss_ids=faiss_ids, meta_filters=meta_filters) - vector_df = pd.DataFrame({"faiss_id": faiss_ids, "distance": distances}) - return vector_df.merge(meta_df, on="faiss_id").drop("faiss_id", axis=1).sort_values(by="distance") - - return pd.DataFrame([], columns=["id", "content", "metadata", "distance"]) - - def _select_from_metadata(self, faiss_ids=None, meta_filters=None, limit=None): - query = Select( - targets=[Star()], - from_table=Identifier("meta_data"), - ) - - where_clause = self._translate_filters(meta_filters) - - if faiss_ids: - # TODO what if ids list is too long - split search into batches - in_filter = BinaryOperation( - op="IN", args=[Identifier("faiss_id"), AstTuple([Constant(i) for i in faiss_ids])] - ) - # split into chunks - chunk_size = 10000 - if len(faiss_ids) > chunk_size: - dfs = [] - chunk = 0 - total = 0 - while chunk * chunk_size < len(faiss_ids): - # create results with partition - ids = faiss_ids[chunk * chunk_size : (chunk + 1) * chunk_size] - chunk += 1 - df = self._select_from_metadata(faiss_ids=ids, meta_filters=meta_filters, limit=limit) - total += len(df) - if limit is not None and limit <= total: - # cut the extra from the end - df = df[: -(total - limit)] - dfs.append(df) - break - if len(df) > 0: - dfs.append(df) - if len(dfs) == 0: - return pd.DataFrame([], columns=["faiss_id", "id", "content", "metadata"]) - return pd.concat(dfs) - - if where_clause is None: - where_clause = in_filter - else: - where_clause = BinaryOperation(op="AND", args=[where_clause, in_filter]) - - if limit is not None: - query.limit = Constant(limit) - - query.where = where_clause - - with self.connection.cursor() as cur: - sql = self.renderer.get_string(query, with_failback=True) - cur.execute(sql) - df = cur.fetchdf() - df["metadata"] = df["metadata"].apply(orjson.loads) - return df - - def _translate_filters(self, meta_filters): - if not meta_filters: - return None - - where_clause = None - for item in meta_filters: - parts = item.column.split(".") - key = Identifier(parts[0]) - - # converts 'col.el1.el2' to col->'el1'->>'el2' - if len(parts) > 1: - # intermediate elements - for el in parts[1:-1]: - key = BinaryOperation(op="->", args=[key, Constant(el)]) - - # last element - key = BinaryOperation(op="->>", args=[key, Constant(parts[-1])]) - - is_orig_id = item.column == "metadata._original_doc_id" - - type_cast = None - value = item.value - - if isinstance(value, list) and len(value) > 0 and item.op in (FilterOperator.IN, FilterOperator.NOT_IN): - if is_orig_id: - # convert to str - item.value = [str(i) for i in value] - value = item.value[0] - elif is_orig_id: - if not isinstance(value, str): - value = item.value = str(item.value) - - if isinstance(value, int): - type_cast = "int" - elif isinstance(value, float): - type_cast = "float" - - if type_cast is not None: - key = TypeCast(type_cast, key) - - if item.op in (FilterOperator.NOT_IN, FilterOperator.IN): - values = [Constant(i) for i in item.value] - value = AstTuple(values) - else: - value = Constant(item.value) - - condition = BinaryOperation(op=item.op.value, args=[key, value]) - - if where_clause is None: - where_clause = condition - else: - where_clause = BinaryOperation(op="AND", args=[where_clause, condition]) - return where_clause - - def delete(self, table_name: str, conditions: List[FilterCondition] = None) -> Response: + def delete(self, table_name: str, conditions: List[FilterCondition] = None): """Delete data from both DuckDB and Faiss.""" - with self.connection.cursor() as cur: - where_clause = self._translate_filters(conditions) - - query = Select(targets=[Identifier("faiss_id")], from_table=Identifier("meta_data"), where=where_clause) - cur.execute(self.renderer.get_string(query, with_failback=True)) - df = cur.fetchdf() - ids = list(df["faiss_id"]) - - self.faiss_index.delete_ids(ids) - - query = Delete(table=Identifier("meta_data"), where=where_clause) - cur.execute(self.renderer.get_string(query, with_failback=True)) - - self._sync() + with self.open_table(table_name) as table: + table.delete(conditions) def get_dimension(self, table_name: str) -> int: - if self.faiss_index and self.faiss_index.index is not None: - return self.faiss_index.dim - - def _sync(self): - """Sync the database to disk if using persistent storage""" - self.faiss_index.dump() - if self._use_handler_storage: - self.handler_storage.folder_sync(self.persist_directory) - - def get_tables(self) -> Response: - """Get list of tables.""" - with self.connection.cursor() as cur: - df = cur.execute("show tables").fetchdf() - df = df.rename(columns={"name": "table_name"}) - - return Response(RESPONSE_TYPE.TABLE, data_frame=df) - - def check_connection(self) -> Response: - """Check the connection to the database.""" - try: - if not self.is_connected: - self.connect() - return StatusResponse(RESPONSE_TYPE.OK) - except Exception as e: - logger.error(f"Connection check failed: {e}") - return StatusResponse(RESPONSE_TYPE.ERROR, error_message=str(e)) - - def native_query(self, query: str) -> Response: - """Execute a native SQL query.""" - try: - with self.connection.cursor() as cur: - cur.execute(query) - result = cur.fetchdf() - return Response(RESPONSE_TYPE.TABLE, data_frame=result) - except Exception as e: - logger.error(f"Error executing native query: {e}") - return Response(RESPONSE_TYPE.ERROR, error_message=str(e)) - - def __del__(self): - """Cleanup on deletion.""" - if self.is_connected: - self._sync() - self.disconnect() + with self.open_table(table_name) as table: + return table.get_dimension() diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py new file mode 100644 index 00000000000..b8a5324ad63 --- /dev/null +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py @@ -0,0 +1,362 @@ +from pathlib import Path +from typing import List + +import pandas as pd +import orjson +import duckdb +from mindsdb_sql_parser.ast import ( + Select, + Delete, + Identifier, + BinaryOperation, + Constant, + NullConstant, + Star, + Tuple as AstTuple, + Function, + TypeCast, +) + + +from mindsdb.integrations.libs.vectordatabase_handler import ( + FilterCondition, + FilterOperator, +) +from mindsdb.integrations.utilities.sql_utils import KeywordSearchArgs + +from mindsdb.utilities import log + +from .faiss_index import FaissIVFIndex + +logger = log.getLogger(__name__) + + +class DuckDBFaissTable: + def __init__(self, table_name: str, table_dir: Path, handler): + self.table_name = table_name + self.handler = handler + self.connection: duckdb.DuckDBPyConnection | None = None + self.faiss_index: FaissIVFIndex | None = None + self.table_dir = table_dir + self.is_kw_index_enabled = False + self.cache_required = False + + def open(self) -> "DuckDBFaissTable": + duckdb_path = self.table_dir / "duckdb.db" + self.connection = duckdb.connect(str(duckdb_path)) + self.faiss_index = FaissIVFIndex(str(self.table_dir), self.handler.connection_data) + + self.cache_required = self.faiss_index.lock_required and self.faiss_index.get_size() > 100_000 + + # check keyword index + with self.connection.cursor() as cur: + # check index exists + df = cur.execute( + "SELECT * FROM information_schema.schemata WHERE schema_name = 'fts_main_meta_data'" + ).fetchdf() + if len(df) > 0: + self.is_kw_index_enabled = True + + return self + + def close(self) -> None: + self.faiss_index.close() + self.connection.close() + + def _create_kw_index(self): + with self.connection.cursor() as cur: + cur.execute("PRAGMA create_fts_index('meta_data', 'id', 'content')") + self.is_kw_index_enabled = True + + def _drop_kw_index(self): + with self.connection.cursor() as cur: + cur.execute("pragma drop_fts_index('meta_data')") + self.is_kw_index_enabled = False + + def _sync(self, dump_faiss=True): + if dump_faiss: + self.faiss_index.dump() + + if self.handler._use_handler_storage: + self.handler.handler_storage.folder_sync(self.table_name) + + def create_index(self, type: str = "ivf_file", nlist: int = None, train_count: int = None): + if type not in ("ivf", "ivf_file"): + raise NotImplementedError("Only ivf or ivf_file indexes are supported") + self.faiss_index.create_index(type, nlist=nlist, train_count=train_count) + # index was already saved. don't dump it twice + self._sync(dump_faiss=False) + + def insert(self, data: pd.DataFrame): + """Insert data into both DuckDB and Faiss.""" + + if self.is_kw_index_enabled: + # drop index, it will be created before a first keyword search + self._drop_kw_index() + + with self.connection.cursor() as cur: + df_ids = cur.execute(""" + insert into meta_data (id, content, metadata) ( + select id, content, metadata from data + ) + RETURNING faiss_id, id + """).fetchdf() + + data = data.merge(df_ids, on="id") + + vectors = data["embeddings"] + ids = data["faiss_id"] + + self.faiss_index.insert(list(vectors), list(ids)) + self._sync() + + def select( + self, + conditions: List[FilterCondition] = None, + offset: int = None, + limit: int = None, + ) -> pd.DataFrame: + """Select data with hybrid search logic.""" + + vector_filter = None + meta_filters = [] + if conditions is None: + conditions = [] + for condition in conditions: + if condition.column == "embeddings": + vector_filter = condition + else: + meta_filters.append(condition) + + if vector_filter is None: + # If only metadata in filter: + # query duckdb only + return self._select_from_metadata(meta_filters=meta_filters, limit=limit).drop("faiss_id", axis=1) + + # vector_filter is not None + if not meta_filters: + # If only content in filter: query faiss and attach to metadata + return self._select_with_vector(vector_filter=vector_filter, limit=limit) + + """ + If metadata + content: + Query faiss, use limit = 1000 + Query duckdb with `id in (...)` + If count of results is less than input LIMIT value + Repeat the search with increased limit value + Limit value for step = 1000 * 5^i (1000, 2000, 25000, 125000 …) + """ + + df = pd.DataFrame() + + total_size = self.get_total_size() + + for i in range(10): + batch_size = 1000 * 5**i + + # TODO implement reverse search: + # if batch_size > 25% of db: search metadata first and then in faiss by list of ids + + df = self._select_with_vector(vector_filter=vector_filter, meta_filters=meta_filters, limit=batch_size) + if batch_size >= total_size or len(df) >= limit: + break + + return df[:limit] + + def keyword_select( + self, + conditions: List[FilterCondition] = None, + offset: int = None, + limit: int = None, + keyword_search_args: KeywordSearchArgs = None, + ) -> pd.DataFrame: + if not self.is_kw_index_enabled: + # keyword search is used for first time: create index + self._create_kw_index() + + with self.connection.cursor() as cur: + where_clause = self._translate_filters(conditions) + + score = Function( + namespace="fts_main_meta_data", + op="match_bm25", + args=[ + Identifier("id"), + Constant(keyword_search_args.query), + BinaryOperation(op=":=", args=[Identifier("fields"), Constant(keyword_search_args.column)]), + ], + ) + + no_emtpy_score = BinaryOperation(op="is not", args=[score, NullConstant()]) + if where_clause: + where_clause = BinaryOperation(op="and", args=[where_clause, no_emtpy_score]) + else: + where_clause = no_emtpy_score + + query = Select( + targets=[Star(), BinaryOperation(op="-", args=[Constant(1), score], alias=Identifier("distance"))], + from_table=Identifier("meta_data"), + where=where_clause, + ) + + if limit is not None: + query.limit = Constant(limit) + + if offset is not None: + query.offset = Constant(offset) + + sql = self.handler.renderer.get_string(query, with_failback=True) + cur.execute(sql) + df = cur.fetchdf() + df["metadata"] = df["metadata"].apply(orjson.loads) + return df + + def delete(self, conditions: List[FilterCondition] = None): + """Delete data from both DuckDB and Faiss.""" + with self.connection.cursor() as cur: + where_clause = self._translate_filters(conditions) + + query = Select(targets=[Identifier("faiss_id")], from_table=Identifier("meta_data"), where=where_clause) + cur.execute(self.handler.renderer.get_string(query, with_failback=True)) + df = cur.fetchdf() + ids = list(df["faiss_id"]) + + self.faiss_index.delete_ids(ids) + + query = Delete(table=Identifier("meta_data"), where=where_clause) + cur.execute(self.handler.renderer.get_string(query, with_failback=True)) + + self._sync() + + def get_dimension(self) -> int: + if self.faiss_index and self.faiss_index.index is not None: + return self.faiss_index.dim + + def get_total_size(self): + with self.connection.cursor() as cur: + cur.execute("select count(1) size from meta_data") + df = cur.fetchdf() + return df["size"].iloc[0] + + def _select_with_vector(self, vector_filter: FilterCondition, meta_filters=None, limit=None) -> pd.DataFrame: + embedding = vector_filter.value + if isinstance(embedding, str): + embedding = orjson.loads(embedding) + + distances, faiss_ids = self.faiss_index.search(embedding, limit or 100) + + # Fetch full data from DuckDB + if len(faiss_ids) > 0: + # ids = [str(idx) for idx in faiss_ids] + meta_df = self._select_from_metadata(faiss_ids=faiss_ids, meta_filters=meta_filters) + vector_df = pd.DataFrame({"faiss_id": faiss_ids, "distance": distances}) + return vector_df.merge(meta_df, on="faiss_id").drop("faiss_id", axis=1).sort_values(by="distance") + + return pd.DataFrame([], columns=["id", "content", "metadata", "distance"]) + + def _select_from_metadata(self, faiss_ids=None, meta_filters=None, limit=None): + query = Select( + targets=[Star()], + from_table=Identifier("meta_data"), + ) + + where_clause = self._translate_filters(meta_filters) + + if faiss_ids: + # TODO what if ids list is too long - split search into batches + in_filter = BinaryOperation( + op="IN", args=[Identifier("faiss_id"), AstTuple([Constant(i) for i in faiss_ids])] + ) + # split into chunks + chunk_size = 10000 + if len(faiss_ids) > chunk_size: + dfs = [] + chunk = 0 + total = 0 + while chunk * chunk_size < len(faiss_ids): + # create results with partition + ids = faiss_ids[chunk * chunk_size : (chunk + 1) * chunk_size] + chunk += 1 + df = self._select_from_metadata(faiss_ids=ids, meta_filters=meta_filters, limit=limit) + total += len(df) + if limit is not None and limit <= total: + # cut the extra from the end + df = df[: -(total - limit)] + dfs.append(df) + break + if len(df) > 0: + dfs.append(df) + if len(dfs) == 0: + return pd.DataFrame([], columns=["faiss_id", "id", "content", "metadata"]) + return pd.concat(dfs) + + if where_clause is None: + where_clause = in_filter + else: + where_clause = BinaryOperation(op="AND", args=[where_clause, in_filter]) + + if limit is not None: + query.limit = Constant(limit) + + query.where = where_clause + + with self.connection.cursor() as cur: + sql = self.handler.renderer.get_string(query, with_failback=True) + cur.execute(sql) + df = cur.fetchdf() + df["metadata"] = df["metadata"].apply(orjson.loads) + return df + + def _translate_filters(self, meta_filters): + if not meta_filters: + return None + + where_clause = None + for item in meta_filters: + parts = item.column.split(".") + key = Identifier(parts[0]) + + # converts 'col.el1.el2' to col->'el1'->>'el2' + if len(parts) > 1: + # intermediate elements + for el in parts[1:-1]: + key = BinaryOperation(op="->", args=[key, Constant(el)]) + + # last element + key = BinaryOperation(op="->>", args=[key, Constant(parts[-1])]) + + is_orig_id = item.column == "metadata._original_doc_id" + + type_cast = None + value = item.value + + if isinstance(value, list) and len(value) > 0 and item.op in (FilterOperator.IN, FilterOperator.NOT_IN): + if is_orig_id: + # convert to str + item.value = [str(i) for i in value] + value = item.value[0] + elif is_orig_id: + if not isinstance(value, str): + value = item.value = str(item.value) + + if isinstance(value, int): + type_cast = "int" + elif isinstance(value, float): + type_cast = "float" + + if type_cast is not None: + key = TypeCast(type_cast, key) + + if item.op in (FilterOperator.NOT_IN, FilterOperator.IN): + values = [Constant(i) for i in item.value] + value = AstTuple(values) + else: + value = Constant(item.value) + + condition = BinaryOperation(op=item.op.value, args=[key, value]) + + if where_clause is None: + where_clause = condition + else: + where_clause = BinaryOperation(op="AND", args=[where_clause, condition]) + return where_clause diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py index e596eaf0cf6..3673a982f6f 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py @@ -4,13 +4,21 @@ import psutil from pathlib import Path -import portalocker +try: + import fcntl +except ImportError: + fcntl = None import faiss # faiss or faiss-gpu -from faiss.contrib.ondisk import merge_ondisk + +from mindsdb.utilities import log + from pydantic import BaseModel +logger = log.getLogger(__name__) + + def _normalize_rows(x: np.ndarray) -> np.ndarray: norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12 return x / norms @@ -25,6 +33,54 @@ class FaissParams(BaseModel): hnsw_ef_search: int | None = 64 +def merge_ondisk(trained_index: faiss.Index, shard_fnames: List[str], ivfdata_fname: str, shift_ids=False) -> None: + """ + Modified version of faiss.contrib.ondisk.merge_ondisk. Prevents leaving orphan memory mapped shard files + + Add the contents of the indexes stored in shard_fnames into the index trained_index. + The on-disk data is stored in ivfdata_fname + """ + assert not isinstance(trained_index, faiss.IndexIVFPQR), "IndexIVFPQR is not supported as an on disk index." + # merge the images into an on-disk index + # first load the inverted lists + ivfs = [] + indexes = [] + + for fname in shard_fnames: + # the IO_FLAG_MMAP is to avoid actually loading the data + # thus the total size of the inverted lists can exceed the available RAM + logger.info("read " + fname) + index = faiss.read_index(fname, faiss.IO_FLAG_MMAP) + index_ivf = faiss.extract_index_ivf(index) + ivfs.append(index_ivf.invlists) + + indexes.append(index) + + # construct the output index + index = trained_index + index_ivf = faiss.extract_index_ivf(index) + + assert index.ntotal == 0, "works only on empty index" + + # prepare the output inverted lists. They will be written to merged_index.ivfdata + invlists = faiss.OnDiskInvertedLists(index_ivf.nlist, index_ivf.code_size, ivfdata_fname) + + # merge all the inverted lists + ivf_vector = faiss.InvertedListsPtrVector() + for ivf in ivfs: + ivf_vector.push_back(ivf) + + logger.info("merge %d inverted lists " % ivf_vector.size()) + ntotal = invlists.merge_from_multiple(ivf_vector.data(), ivf_vector.size(), shift_ids) + + # now replace the inverted lists in the output index + index.ntotal = index_ivf.ntotal = ntotal + index_ivf.replace_invlists(invlists, True) + invlists.this.disown() + + del indexes + + class FaissIndex: def __init__(self, path: str, config: dict): self._normalize_vectors = False @@ -52,6 +108,7 @@ def __init__(self, path: str, config: dict): self.index_type = "flat" self.dim = None self.index_fd = None + self.lock_required = True recover_path = Path(self.path).parent / "recover" if recover_path.exists(): @@ -65,11 +122,13 @@ def __init__(self, path: str, config: dict): self._load_index() def _lock_index(self): - if os.name != "nt": + if not self.lock_required: + return + if os.name != "nt" and fcntl: self.index_fd = open(self.path, "rb") try: - portalocker.lock(self.index_fd, portalocker.LOCK_EX | portalocker.LOCK_NB) - except portalocker.exceptions.AlreadyLocked: + fcntl.flock(self.index_fd, fcntl.LOCK_EX | fcntl.LOCK_NB) + except OSError: raise ValueError(f"Index is already used: {self.path}") def _load_index(self): @@ -82,7 +141,12 @@ def _load_index(self): available_ram = psutil.virtual_memory().available if required_ram > _1gb and available_ram < required_ram: to_free_gb = round((required_ram - available_ram) / _1gb, 2) - raise ValueError(f"Unable load FAISS index into RAM, free up al least : {to_free_gb} Gb") + raise ValueError(f"Unable load FAISS index into RAM, free up at least : {to_free_gb} Gb") + + # check ivf_file before loading index and locking it + index_merged = Path(self.path).parent / "faiss_index_merged" + if index_merged.exists(): + self.lock_required = False self._lock_index() @@ -93,13 +157,19 @@ def _load_index(self): if hasattr(index, "index"): index = faiss.downcast_index(index.index) if isinstance(index, faiss.IndexIVFFlat): - self.index_type = "ivf" + if index_merged.exists(): + self.index_type = "ivf_file" + else: + self.index_type = "ivf" def close(self): if self.index_fd is not None: self.index_fd.close() self.index = None + def __del__(self): + self.close() + def _build_flat_index(self): # TODO option to create hnsw @@ -128,6 +198,9 @@ def _check_ram_usage(self, count_vectors, index_type: str = "flat", m=32, nlist= required = (self.dim * 4 + m * 2 * 4) * count_vectors case "ivf": required = (self.dim * 4 + 8) * count_vectors + self.dim * 4 * nlist + case "ivf_file": + # don't restrict for IVF file + required = 0 case _: raise ValueError(f"Unknown index type: {index_type}") @@ -337,29 +410,29 @@ def _get_dump_vector_files(self, dump_path): vec_files.sort() return vec_files - def _create_ivf_index(self, path, train_count, nlist): + def _create_ivf_index(self, dump_path, train_count, nlist): """ Build an in-memory IVF index - :param path: Directory containing memmap files + :param dump_path: Directory containing memmap files :param train_count: Number of vectors to use for training :param nlist: number of clusters for IVF """ # Load ids - ids_path = path / "ids.mmap" + ids_path = dump_path / "ids.mmap" if not os.path.exists(ids_path): raise FileNotFoundError(f"Missing ids memmap: {ids_path}") ids = np.fromfile(ids_path, dtype="int64") - ivf = self._train_ivf(path, nlist=nlist, train_count=train_count) + ivf = self._train_ivf(dump_path, nlist=nlist, train_count=train_count) - vec_files = self._get_dump_vector_files(path) + vec_files = self._get_dump_vector_files(dump_path) # load data start = 0 for fname in vec_files: - fpath = path / fname + fpath = dump_path / fname batch_data = np.fromfile(fpath, dtype="float32") rows = int(batch_data.shape[0] / self.dim) @@ -370,29 +443,33 @@ def _create_ivf_index(self, path, train_count, nlist): ivf.add_with_ids(batch_vectors, ids_batch) start += rows + # remove dumps + for item in dump_path.iterdir(): + item.unlink() + return ivf - def _create_ivf_file_index(self, path, train_count, nlist): + def _create_ivf_file_index(self, dump_path, train_count, nlist): """Build an IVF on disk index""" - index_path = path.parent - trained_index = self._train_ivf(path, train_count=train_count, nlist=nlist) + index_path = dump_path.parent + trained_index = self._train_ivf(dump_path, train_count=train_count, nlist=nlist) # store trained index trained_path = str(index_path / "faiss_index.trained") faiss.write_index(trained_index, trained_path) - ids_path = path / "ids.mmap" + ids_path = dump_path / "ids.mmap" if not os.path.exists(ids_path): raise FileNotFoundError(f"Missing ids memmap: {ids_path}") ids = np.fromfile(ids_path, dtype="int64") - vec_files = self._get_dump_vector_files(path) + vec_files = self._get_dump_vector_files(dump_path) start = 0 block_fnames = [] for num, fname in enumerate(vec_files): index = faiss.read_index(trained_path) - fpath = path / fname + fpath = dump_path / fname batch_data = np.fromfile(fpath, dtype="float32") rows = int(batch_data.shape[0] / self.dim) @@ -406,6 +483,10 @@ def _create_ivf_file_index(self, path, train_count, nlist): faiss.write_index(index, block_fname) start += rows + # remove dumps + for item in dump_path.iterdir(): + item.unlink() + index = faiss.read_index(trained_path) merge_ondisk(index, block_fnames, str(index_path / "faiss_index_merged")) @@ -415,6 +496,36 @@ def _create_ivf_file_index(self, path, train_count, nlist): return index + def get_size(self): + if self.index is None: + return 0 + else: + return self.index.ntotal + + def check_required_disk_space(self, index_type): + available = psutil.disk_usage(self.path).free + + # current size of index + index_size = 0 + base_path = Path(self.path).parent + for item in base_path.iterdir(): + if item.is_dir() or not item.name.startswith("faiss_index"): + continue + index_size += item.stat().st_size + + # k - how more space required than current index size + if index_type == "ivf_file": + # recovery + dump + shard files + k = 3.01 + else: + # recovery + dump + k = 2.01 + + # k-1 because the current index space will be reused + if available < index_size * (k - 1): + to_free_gb = round((index_size * (k - 1)) / 1024**3, 2) + raise ValueError(f"Unable run indexing FAISS not enough disk space, get free at least : {to_free_gb} Gb") + def create_index(self, index_type, nlist=None, train_count=None): """ Create or recreate IVF index @@ -436,10 +547,7 @@ def create_index(self, index_type, nlist=None, train_count=None): if nlist is None: nlist = self.config.nlist - if self.index is None: - ntotal = 0 - else: - ntotal = self.index.ntotal + ntotal = self.get_size() # faiss shows warning if train count is less than 39 * nlist and recommend to use at least this size for train data nlist_k = 39 @@ -453,6 +561,8 @@ def create_index(self, index_type, nlist=None, train_count=None): if train_count > ntotal: raise ValueError(f"Not enough data to create index: {ntotal}, at least {train_count} records are required") + self.check_required_disk_space(index_type) + dump_path.mkdir(exist_ok=True) # remove old items @@ -475,20 +585,20 @@ def create_index(self, index_type, nlist=None, train_count=None): # create ivf index if index_type == "ivf": ivf_index = self._create_ivf_index(dump_path, train_count=train_count, nlist=nlist) + self.lock_required = True elif index_type == "ivf_file": ivf_index = self._create_ivf_file_index(dump_path, train_count=train_count, nlist=nlist) + self.lock_required = False else: raise ValueError(f"Unknown index type: {index_type}") self.index = ivf_index - self.index_type = "ivf" + self.index_type = index_type self.dump() self._lock_index() - # remove unused items - for item in dump_path.iterdir(): - item.unlink() + # remove unused files dump_path.rmdir() for item in recover_path.iterdir(): diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/requirements.txt b/mindsdb/integrations/handlers/duckdb_faiss_handler/requirements.txt index 8a1860f26b2..3dd4dc56e15 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/requirements.txt +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/requirements.txt @@ -1,2 +1 @@ -faiss-cpu>=1.7.4 -portalocker +faiss-cpu==1.13.2 diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py index 01eb44b2cae..a0950730067 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py @@ -3,10 +3,10 @@ import pandas as pd -from tests.unit.executor.test_knowledge_base import TestKB as BaseTestKB, set_litellm_embedding +from tests.unit.executor.test_knowledge_base import TestKB, set_litellm_embedding -class TestFAISS(BaseTestKB): +class TestFAISS(TestKB): "Run unit tests using FAISS handler as storage" def _get_storage_table(self, kb_name): diff --git a/mindsdb/interfaces/knowledge_base/controller.py b/mindsdb/interfaces/knowledge_base/controller.py index c9703bd1ada..d2eabfe5c41 100644 --- a/mindsdb/interfaces/knowledge_base/controller.py +++ b/mindsdb/interfaces/knowledge_base/controller.py @@ -1510,7 +1510,7 @@ def _check_embedding_model(self, project_name, params: dict = None, kb_name="") except Exception as e: raise RuntimeError(f"Problem with embedding model config: {e}") from e - def delete(self, name: str, project_name: int, if_exists: bool = False) -> None: + def delete(self, name: str, project_name: str, if_exists: bool = False) -> None: """ Delete a knowledge base from the database """ @@ -1629,3 +1629,24 @@ def evaluate(self, table_name: str, project_name: str, params: dict = None) -> p scores = EvaluateBase.run(self.session, kb_table, params) return scores + + def release_lock(self, knowledge_base: Identifier, project_name): + # works only for FAISS dbs. + # if FAISS vector db is used in KB: remove this db from handlers cache. + # it will clear internal cache of tables in faiss handler and release locks for faiss files + + if len(knowledge_base.parts) > 1: + project_name, kb_name = knowledge_base.parts[-2:] + else: + kb_name = knowledge_base.parts[-1] + + project_id = self.session.database_controller.get_project(project_name).id + kb = self.get(kb_name, project_id) + if kb is None or kb.vector_database_id is None: + return + database = db.Integration.query.get(kb.vector_database_id) + if database is None: + return + + if database.engine == "duckdb_faiss": + self.session.integration_controller.handlers_cache.delete(database.name) From d561f22920371f73418e783fcc6bbc8254574116 Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Fri, 20 Mar 2026 06:56:58 +0000 Subject: [PATCH 054/125] Fix BigQuery empty SELECT response typing (#12258) --- .../bigquery_handler/bigquery_handler.py | 3 +- tests/unit/handlers/test_bigquery.py | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py b/mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py index 6c08d2aecb8..2746ad487df 100644 --- a/mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py +++ b/mindsdb/integrations/handlers/bigquery_handler/bigquery_handler.py @@ -139,7 +139,8 @@ def native_query(self, query: str) -> Response: ) query = connection.query(query, job_config=job_config) result = query.to_dataframe() - if not result.empty: + has_table_result = isinstance(result, pd.DataFrame) and (not result.empty or len(result.columns) > 0) + if has_table_result: response = Response(RESPONSE_TYPE.TABLE, result) else: response = Response(RESPONSE_TYPE.OK) diff --git a/tests/unit/handlers/test_bigquery.py b/tests/unit/handlers/test_bigquery.py index 026a90c1146..0f67b430ea0 100644 --- a/tests/unit/handlers/test_bigquery.py +++ b/tests/unit/handlers/test_bigquery.py @@ -92,6 +92,34 @@ def test_native_query(self): assert isinstance(data, DataHandlerResponse) self.assertFalse(data.error_code) + def test_native_query_empty_select_returns_table(self): + mock_conn = MagicMock() + self.handler.connect = MagicMock(return_value=mock_conn) + + mock_query = MagicMock() + mock_query.to_dataframe.return_value = pd.DataFrame(columns=["id"]) + mock_conn.query.return_value = mock_query + + with patch("mindsdb.integrations.handlers.bigquery_handler.bigquery_handler.QueryJobConfig"): + response = self.handler.native_query("SELECT id FROM table WHERE 1 = 0") + + self.assertEqual(response.type, RESPONSE_TYPE.TABLE) + self.assertEqual(list(response.data_frame.columns), ["id"]) + self.assertTrue(response.data_frame.empty) + + def test_native_query_empty_dataframe_without_columns_returns_ok(self): + mock_conn = MagicMock() + self.handler.connect = MagicMock(return_value=mock_conn) + + mock_query = MagicMock() + mock_query.to_dataframe.return_value = pd.DataFrame() + mock_conn.query.return_value = mock_query + + with patch("mindsdb.integrations.handlers.bigquery_handler.bigquery_handler.QueryJobConfig"): + response = self.handler.native_query("UPDATE table SET col = 1") + + self.assertEqual(response.type, RESPONSE_TYPE.OK) + def test_get_tables(self): """ Checks if the `get_tables` method correctly constructs the SQL query and if it calls `native_query` with the correct query. From bd2d10a2e982b8bc7babe14b7b45ced56c020562 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 10:52:14 +0300 Subject: [PATCH 055/125] Bump python-multipart from 0.0.20 to 0.0.22 in /requirements (#12153) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 231f8e681c0..1582e0fd0c8 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,7 +3,7 @@ flask == 3.0.3 werkzeug == 3.0.6 flask-restx >= 1.3.0, < 2.0.0 pandas == 2.2.3 -python-multipart == 0.0.20 +python-multipart == 0.0.22 cryptography>=35.0 psycopg[binary] psutil~=7.0 From fb107bd2f3f1e428c7eaaca30e857b0b35624310 Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Fri, 20 Mar 2026 08:45:35 +0000 Subject: [PATCH 056/125] Fix MSSQL check_connection exception handling (#12261) --- .../integrations/handlers/mssql_handler/mssql_handler.py | 2 +- tests/unit/handlers/test_mssql.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/mssql_handler/mssql_handler.py b/mindsdb/integrations/handlers/mssql_handler/mssql_handler.py index bfff010e02b..7b6e42fff34 100644 --- a/mindsdb/integrations/handlers/mssql_handler/mssql_handler.py +++ b/mindsdb/integrations/handlers/mssql_handler/mssql_handler.py @@ -327,7 +327,7 @@ def check_connection(self) -> StatusResponse: # Execute a simple query to test the connection cur.execute("select 1;") response.success = True - except OperationalError as e: + except Exception as e: logger.error(f"Error connecting to Microsoft SQL Server {self.database}, {e}!") response.error_message = str(e) diff --git a/tests/unit/handlers/test_mssql.py b/tests/unit/handlers/test_mssql.py index 37e4d06c8f7..d7024d51359 100644 --- a/tests/unit/handlers/test_mssql.py +++ b/tests/unit/handlers/test_mssql.py @@ -641,6 +641,13 @@ def test_check_connection(self): self.assertFalse(response.success) self.assertEqual(response.error_message, "Connection error") + self.handler.connect.side_effect = ValueError("Invalid connection args") + + response = self.handler.check_connection() + + self.assertFalse(response.success) + self.assertEqual(response.error_message, "Invalid connection args") + def test_types_casting(self): """Test that types are casted correctly""" query_str = "SELECT * FROM test_table" From a01c807e53009a5a3a8a37de1a7f47dfd808ea8c Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Fri, 20 Mar 2026 09:34:08 +0000 Subject: [PATCH 057/125] Fix Snowflake FK metadata parent/child direction (#12255) --- .../handlers/snowflake_handler/snowflake_handler.py | 8 ++++---- tests/unit/handlers/test_snowflake.py | 5 +++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py b/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py index 1853b6f3447..581eb74b053 100644 --- a/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +++ b/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py @@ -698,10 +698,10 @@ def meta_get_foreign_keys(self, table_names: Optional[List[str]] = None) -> Data df = df[["pk_table_name", "pk_column_name", "fk_table_name", "fk_column_name"]] df = df.rename( columns={ - "pk_table_name": "child_table_name", - "pk_column_name": "child_column_name", - "fk_table_name": "parent_table_name", - "fk_column_name": "parent_column_name", + "pk_table_name": "parent_table_name", + "pk_column_name": "parent_column_name", + "fk_table_name": "child_table_name", + "fk_column_name": "child_column_name", } ) diff --git a/tests/unit/handlers/test_snowflake.py b/tests/unit/handlers/test_snowflake.py index 7118f3f1602..07c0c87b040 100644 --- a/tests/unit/handlers/test_snowflake.py +++ b/tests/unit/handlers/test_snowflake.py @@ -891,6 +891,11 @@ def test_meta_get_foreign_keys_filters(self): self.assertEqual(len(result.data_frame), 1) self.assertIn("child_table_name", result.data_frame.columns) + row = result.data_frame.iloc[0] + self.assertEqual(row["parent_table_name"], "ORDERS") + self.assertEqual(row["parent_column_name"], "CUSTOMER_ID") + self.assertEqual(row["child_table_name"], "CUSTOMERS") + self.assertEqual(row["child_column_name"], "ID") def test_meta_get_foreign_keys_handles_exception(self): self.handler.native_query = MagicMock(side_effect=Exception("boom")) From a05f479dbbaf4f724b6a28a3d758ddd09f806d68 Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Fri, 20 Mar 2026 12:07:24 +0000 Subject: [PATCH 058/125] Fix Databricks get_tables non-table response handling (#12259) --- .../databricks_handler/databricks_handler.py | 12 +++++++----- tests/unit/handlers/test_databricks.py | 8 ++++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py b/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py index 2feab0a37d4..e65b8229f4d 100644 --- a/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py +++ b/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py @@ -463,11 +463,13 @@ def get_tables(self, all: bool = False) -> Response: WHERE table_schema != 'information_schema' {all_filter} - """ - result = self.native_query(query) - df = result.data_frame - result.data_frame = df.rename(columns={col: col.upper() for col in df.columns}) - return result + """ + result = self.native_query(query) + if result.resp_type != RESPONSE_TYPE.TABLE or result.data_frame is None: + return result + df = result.data_frame + result.data_frame = df.rename(columns={col: col.upper() for col in df.columns}) + return result def get_columns(self, table_name: str, schema_name: str | None = None) -> Response: """ diff --git a/tests/unit/handlers/test_databricks.py b/tests/unit/handlers/test_databricks.py index 9dc2282e90f..b39fd46c936 100644 --- a/tests/unit/handlers/test_databricks.py +++ b/tests/unit/handlers/test_databricks.py @@ -205,6 +205,14 @@ def test_get_tables(self): """ self.handler.native_query.assert_called_once_with(expected_query) + def test_get_tables_returns_non_table_response_without_transform(self): + expected = Response(RESPONSE_TYPE.ERROR, error_message="boom") + self.handler.native_query = MagicMock(return_value=expected) + + result = self.handler.get_tables() + + self.assertIs(result, expected) + def test_get_columns(self): """ Tests if the `get_columns` method correctly constructs the SQL query and if it calls `native_query` with the correct query. From aea51d2096837547dc063523cbd9a9558e189179 Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Fri, 20 Mar 2026 12:07:35 +0000 Subject: [PATCH 059/125] Fix Databricks empty SELECT response type (#12257) --- .../databricks_handler/databricks_handler.py | 16 ++++++++-------- tests/unit/handlers/test_databricks.py | 9 +++++++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py b/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py index e65b8229f4d..39527555732 100644 --- a/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py +++ b/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py @@ -402,14 +402,14 @@ def native_query(self, query: Text) -> Response: connection = self.connect() with connection.cursor() as cursor: try: - cursor.execute(query) - result = cursor.fetchall() - if result: - response = Response( - RESPONSE_TYPE.TABLE, - data_frame=pd.DataFrame(result, columns=[x[0] for x in cursor.description]), - ) - else: + cursor.execute(query) + result = cursor.fetchall() + if cursor.description: + response = Response( + RESPONSE_TYPE.TABLE, + data_frame=pd.DataFrame(result, columns=[x[0] for x in cursor.description]), + ) + else: response = Response(RESPONSE_TYPE.OK) connection.commit() except ServerOperationError as server_error: diff --git a/tests/unit/handlers/test_databricks.py b/tests/unit/handlers/test_databricks.py index b39fd46c936..74f62a731e7 100644 --- a/tests/unit/handlers/test_databricks.py +++ b/tests/unit/handlers/test_databricks.py @@ -185,6 +185,15 @@ def test_native_query(self): self.assertIsInstance(data, DataHandlerResponse) self.assertNotIsInstance(data, ErrorResponse) + def test_native_query_empty_select_returns_table(self): + self.mock_cursor.set_results([], ["id", "name"]) + + response = self.handler.native_query("SELECT id, name FROM table WHERE 1 = 0") + + self.assertEqual(response.type, RESPONSE_TYPE.TABLE) + self.assertEqual(list(response.data_frame.columns), ["id", "name"]) + self.assertEqual(len(response.data_frame), 0) + def test_get_tables(self): """ Tests if the `get_tables` method to confirm it correctly calls `native_query` with the appropriate SQL commands. From 391347f50f27d8632956144d278497b283191968 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Fri, 20 Mar 2026 15:12:07 +0300 Subject: [PATCH 060/125] ruff databricks handler (#12311) --- .../databricks_handler/databricks_handler.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py b/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py index 39527555732..755308d419b 100644 --- a/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py +++ b/mindsdb/integrations/handlers/databricks_handler/databricks_handler.py @@ -402,14 +402,14 @@ def native_query(self, query: Text) -> Response: connection = self.connect() with connection.cursor() as cursor: try: - cursor.execute(query) - result = cursor.fetchall() - if cursor.description: - response = Response( - RESPONSE_TYPE.TABLE, - data_frame=pd.DataFrame(result, columns=[x[0] for x in cursor.description]), - ) - else: + cursor.execute(query) + result = cursor.fetchall() + if cursor.description: + response = Response( + RESPONSE_TYPE.TABLE, + data_frame=pd.DataFrame(result, columns=[x[0] for x in cursor.description]), + ) + else: response = Response(RESPONSE_TYPE.OK) connection.commit() except ServerOperationError as server_error: @@ -463,13 +463,13 @@ def get_tables(self, all: bool = False) -> Response: WHERE table_schema != 'information_schema' {all_filter} - """ - result = self.native_query(query) - if result.resp_type != RESPONSE_TYPE.TABLE or result.data_frame is None: - return result - df = result.data_frame - result.data_frame = df.rename(columns={col: col.upper() for col in df.columns}) - return result + """ + result = self.native_query(query) + if result.resp_type != RESPONSE_TYPE.TABLE or result.data_frame is None: + return result + df = result.data_frame + result.data_frame = df.rename(columns={col: col.upper() for col in df.columns}) + return result def get_columns(self, table_name: str, schema_name: str | None = None) -> Response: """ From 0e7c3d6f12015bd0b5348b1ef4f1a606a1540807 Mon Sep 17 00:00:00 2001 From: "Farley Farley (yes, really)" Date: Sat, 21 Mar 2026 01:23:13 +1300 Subject: [PATCH 061/125] [Snyk] Security upgrade protobuf from 4.24.4 to 6.33.5 (#12186) Co-authored-by: snyk-bot Co-authored-by: Max Stepanov --- mindsdb/integrations/handlers/mlflow_handler/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mindsdb/integrations/handlers/mlflow_handler/requirements.txt b/mindsdb/integrations/handlers/mlflow_handler/requirements.txt index 40d1d2b4967..3ccfa559cbe 100644 --- a/mindsdb/integrations/handlers/mlflow_handler/requirements.txt +++ b/mindsdb/integrations/handlers/mlflow_handler/requirements.txt @@ -1,2 +1,3 @@ mlflow +protobuf>=6.33.5 # not directly required, pinned by Snyk to avoid a vulnerability sqlparse>=0.5.4 # not directly required, pinned by Snyk to avoid a vulnerability From aba9d3837447db86b2af3f62936907f40110f9d6 Mon Sep 17 00:00:00 2001 From: "Farley Farley (yes, really)" Date: Sat, 21 Mar 2026 01:25:55 +1300 Subject: [PATCH 062/125] [Snyk] Security upgrade protobuf from 4.24.4 to 6.33.5 (#12183) Co-authored-by: snyk-bot --- mindsdb/integrations/handlers/milvus_handler/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mindsdb/integrations/handlers/milvus_handler/requirements.txt b/mindsdb/integrations/handlers/milvus_handler/requirements.txt index 4872c767649..780dc2add48 100644 --- a/mindsdb/integrations/handlers/milvus_handler/requirements.txt +++ b/mindsdb/integrations/handlers/milvus_handler/requirements.txt @@ -1 +1,2 @@ pymilvus==2.3 +protobuf>=6.33.5 # not directly required, pinned by Snyk to avoid a vulnerability From 868af7e7bf64d6725a8cc27099544a2bf368d910 Mon Sep 17 00:00:00 2001 From: "Farley Farley (yes, really)" Date: Sat, 21 Mar 2026 01:28:57 +1300 Subject: [PATCH 063/125] [Snyk] Security upgrade protobuf from 4.24.4 to 6.33.5 (#12181) Co-authored-by: snyk-bot --- mindsdb/integrations/handlers/phoenix_handler/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mindsdb/integrations/handlers/phoenix_handler/requirements.txt b/mindsdb/integrations/handlers/phoenix_handler/requirements.txt index 7d8fd10bbc0..77441982eb4 100644 --- a/mindsdb/integrations/handlers/phoenix_handler/requirements.txt +++ b/mindsdb/integrations/handlers/phoenix_handler/requirements.txt @@ -1,2 +1,3 @@ pyphoenix phoenixdb +protobuf>=6.33.5 # not directly required, pinned by Snyk to avoid a vulnerability From 326284ba2461104bb97a662cce42f5325c33a227 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 15:32:05 +0300 Subject: [PATCH 064/125] Bump orjson from 3.11.3 to 3.11.6 in /requirements (#12288) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 1582e0fd0c8..96af4155038 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -44,7 +44,7 @@ openpyxl # used by pandas to read txt and xlsx files xlrd>=2.0.1 # used by pandas to read legacy .xls files aipdf==0.0.7.0 pyarrow<=19.0.0 # used by pandas to read feather files in Files handler -orjson==3.11.3 +orjson==3.11.6 mind-castle >= 0.4.9 pydantic-ai>=0.0.14 # Required for Pydantic AI agents From 4f7865bd068ace09569d7a8b48b3d2c5f7979514 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 15:34:06 +0300 Subject: [PATCH 065/125] Bump socket.io-parser from 4.2.5 to 4.2.6 in /docs (#12302) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/package-lock.json b/docs/package-lock.json index 035fed79a8a..afd6c9aaa98 100644 --- a/docs/package-lock.json +++ b/docs/package-lock.json @@ -12739,9 +12739,9 @@ } }, "node_modules/socket.io-parser": { - "version": "4.2.5", - "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.5.tgz", - "integrity": "sha512-bPMmpy/5WWKHea5Y/jYAP6k74A+hvmRCQaJuJB6I/ML5JZq/KfNieUVo/3Mh7SAqn7TyFdIo6wqYHInG1MU1bQ==", + "version": "4.2.6", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.6.tgz", + "integrity": "sha512-asJqbVBDsBCJx0pTqw3WfesSY0iRX+2xzWEWzrpcH7L6fLzrhyF8WPI8UaeM4YCuDfpwA/cgsdugMsmtz8EJeg==", "license": "MIT", "dependencies": { "@socket.io/component-emitter": "~3.1.0", From 453c65284e9eaaf93fcbfc37fde510d8f1b8a940 Mon Sep 17 00:00:00 2001 From: Richard Okonicha Date: Thu, 19 Mar 2026 02:30:15 +0100 Subject: [PATCH 066/125] Update Elasticsearch handler: fix array support, add Data Catalog, and add to CI - Fix array field handling (convert to JSON strings) - Add Data Catalog support (get_column_statistics, get_primary_keys, get_foreign_keys) - Standardize documentation format - Add elasticsearch to tests_unit.yml HANDLERS_TO_INSTALL for CI testing Addresses maintainer feedback on PR #11552 --- .github/workflows/tests_unit.yml | 1 + .../handlers/elasticsearch_handler/README.md | 146 +-- .../elasticsearch_handler/__about__.py | 6 +- .../elasticsearch_handler/connection_args.py | 20 + .../elasticsearch_handler.py | 938 ++++++++++++++++-- .../elasticsearch_handler/requirements.txt | 4 +- 6 files changed, 933 insertions(+), 182 deletions(-) diff --git a/.github/workflows/tests_unit.yml b/.github/workflows/tests_unit.yml index 3a9455a7f76..4cd366bd980 100644 --- a/.github/workflows/tests_unit.yml +++ b/.github/workflows/tests_unit.yml @@ -34,6 +34,7 @@ env: statsforecast chromadb confluence + elasticsearch # We measure 80% on this handlers, as they are the verified HANDLERS_TO_VERIFY: | mysql diff --git a/mindsdb/integrations/handlers/elasticsearch_handler/README.md b/mindsdb/integrations/handlers/elasticsearch_handler/README.md index b672ad22b99..da294c761cd 100644 --- a/mindsdb/integrations/handlers/elasticsearch_handler/README.md +++ b/mindsdb/integrations/handlers/elasticsearch_handler/README.md @@ -1,118 +1,130 @@ --- -title: ElasticSearch -sidebarTitle: ElasticSearch +title: Elasticsearch +sidebarTitle: Elasticsearch --- -This documentation describes the integration of MindsDB with [ElasticSearch](https://www.elastic.co/), a distributed, multitenant-capable full-text search engine with an HTTP web interface and schema-free JSON documents.. -The integration allows MindsDB to access data from ElasticSearch and enhance ElasticSearch with AI capabilities. +This documentation describes the integration of MindsDB with [Elasticsearch](https://www.elastic.co/elasticsearch/), a distributed search and analytics engine. +The integration allows MindsDB to access data stored in Elasticsearch indices and enhance Elasticsearch with AI capabilities. + +## Architecture + +This handler uses a **SQL-first architecture** with automatic fallback: + +1. **Primary**: Elasticsearch SQL API for maximum performance and compatibility +2. **Fallback**: Search API for array-containing indexes with automatic array-to-JSON conversion +3. **Security**: SSL/TLS support with certificate validation +4. **Efficiency**: Memory-efficient pagination for large datasets + +The handler automatically detects when SQL queries encounter array fields and seamlessly falls back to the Search API, converting arrays to JSON strings for SQL compatibility. This approach provides the best performance while handling all Elasticsearch data types. ## Prerequisites Before proceeding, ensure the following prerequisites are met: 1. Install MindsDB locally via [Docker](https://docs.mindsdb.com/setup/self-hosted/docker) or [Docker Desktop](https://docs.mindsdb.com/setup/self-hosted/docker-desktop). -2. To connect ElasticSearch to MindsDB, install the required dependencies following [this instruction](/setup/self-hosted/docker#install-dependencies). -3. Install or ensure access to ElasticSearch. +2. To connect Elasticsearch to MindsDB, install the required dependencies following [this instruction](https://docs.mindsdb.com/setup/self-hosted/docker#install-dependencies). +3. **If installing from source**: Python 3.11 or 3.12 is recommended. Install with: `pip install -e '.[elasticsearch]'` ## Connection -Establish a connection to ElasticSearch from MindsDB by executing the following SQL command and providing its [handler name](https://github.com/mindsdb/mindsdb/tree/main/mindsdb/integrations/handlers/elasticsearch_handler) as an engine. +Establish a connection to your Elasticsearch cluster from MindsDB by executing the following SQL command: ```sql -CREATE DATABASE elasticsearch_datasource +CREATE DATABASE elasticsearch_conn WITH ENGINE = 'elasticsearch', -PARAMETERS={ - 'cloud_id': 'xyz', -- optional, if hosts are provided - 'hosts': 'https://xyz.xyz.gcp.cloud.es.io:123', -- optional, if cloud_id is provided - 'api_key': 'xyz', -- optional, if user and password are provided - 'user': 'elastic', -- optional, if api_key is provided - 'password': 'xyz' -- optional, if api_key is provided +PARAMETERS = { + "hosts": "localhost:9200", + "user": "elastic", + "password": "changeme" }; ``` -The connection parameters include the following: - -* `cloud_id`: The Cloud ID provided with the ElasticSearch deployment. Required only when `hosts` is not provided. -* `hosts`: The ElasticSearch endpoint provided with the ElasticSearch deployment. Required only when `cloud_id` is not provided. -* `api_key`: The API key that you generated for the ElasticSearch deployment. Required only when `user` and `password` are not provided. -* `user` and `password`: The user and password used to authenticate. Required only when `api_key` is not provided. +Required connection parameters include the following: - -If you want to connect to the local instance of ElasticSearch, use the below statement: - -```sql -CREATE DATABASE elasticsearch_datasource -WITH ENGINE = 'elasticsearch', -PARAMETERS = { - "hosts": "127.0.0.1:9200", - "user": "user", - "password": "password" -}; -``` +* `hosts`: The Elasticsearch host(s) in format "host:port". For multiple hosts, use comma separation like "host1:port1,host2:port2". -Required connection parameters include the following (at least one of these parameters should be provided): +Optional connection parameters include the following: -* `hosts`: The IP address and port where ElasticSearch is deployed. -* `user`: The user used to autheticate access. -* `password`: The password used to autheticate access. - +* `user`: The username for Elasticsearch authentication. +* `password`: The password for Elasticsearch authentication. +* `api_key`: API key for authentication (alternative to user/password). +* `cloud_id`: Elastic Cloud deployment ID for hosted Elasticsearch. +* `ca_certs`: Path to CA certificate file for SSL verification. +* `client_cert`: Path to client certificate file for SSL authentication. +* `client_key`: Path to client private key file for SSL authentication. +* `verify_certs`: Boolean to enable/disable SSL certificate verification (default: true). +* `timeout`: Request timeout in seconds. ## Usage +The following usage examples utilize the connection to Elasticsearch made via the `CREATE DATABASE` statement and named `elasticsearch_conn`. + Retrieve data from a specified index by providing the integration name and index name: ```sql SELECT * -FROM elasticsearch_datasource.my_index +FROM elasticsearch_conn.products LIMIT 10; ``` - -The above examples utilize `elasticsearch_datasource` as the datasource name, which is defined in the `CREATE DATABASE` command. - +Query with filtering and aggregation: + +```sql +SELECT category, COUNT(*) as product_count, AVG(price) as avg_price +FROM elasticsearch_conn.products +WHERE price > 100 +GROUP BY category +ORDER BY product_count DESC; +``` + +Run queries with array fields (automatically converted to JSON strings): + +```sql +SELECT product_name, tags, categories +FROM elasticsearch_conn.products +WHERE product_id = '12345'; +``` -At the moment, the Elasticsearch SQL API has certain limitations that have an impact on the queries that can be issued via MindsDB. The most notable of these limitations are listed below: -1. Only `SELECT` queries are supported at the moment. -2. Array fields are not supported. -3. Nested fields cannot be queried directly. However, they can be accessed using the `.` operator. +**Array Field Support** -For a detailed guide on the limitations of the Elasticsearch SQL API, refer to the [official documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/sql-limitations.html). +The Elasticsearch handler automatically detects and converts array fields to JSON strings for SQL compatibility. This prevents "Arrays not supported" errors while preserving the original data structure. -## Troubleshooting Guide +## Troubleshooting `Database Connection Error` -* **Symptoms**: Failure to connect MindsDB with the Elasticsearch server. +* **Symptoms**: Failure to connect MindsDB with the Elasticsearch cluster. * **Checklist**: - 1. Make sure the Elasticsearch server is active. - 2. Confirm that server, cloud ID and credentials are correct. + 1. Make sure the Elasticsearch cluster is active and accessible. + 2. Confirm that host, port, user, and password are correct. Try a direct Elasticsearch connection. 3. Ensure a stable network between MindsDB and Elasticsearch. + 4. Check if authentication is required and credentials are valid. -`Transport Error` or `Request Error` +`Arrays Not Supported Error` -* **Symptoms**: Errors related to the issuing of unsupported queries to Elasticsearch. -* **Checklist**: - 1. Ensure the query is a `SELECT` query. - 2. Avoid querying array fields. - 3. Access nested fields using the `.` operator. - 4. Refer to the [official documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/sql-limitations.html) for more information if needed. +* **Symptoms**: SQL queries failing with "Arrays are not supported" message. +* **Solution**: This is automatically handled by the integration. Array fields are converted to JSON strings for SQL compatibility. +* **Note**: If you still encounter this error, the handler will automatically fall back to the Search API. -`SQL statement cannot be parsed by mindsdb_sql` - -* **Symptoms**: SQL queries failing or not recognizing index names containing special characters. -* **Checklist**: - 1. Ensure table names with special characters are enclosed in backticks. - 2. Examples: - * Incorrect: SELECT * FROM integration.travel-data - * Incorrect: SELECT * FROM integration.'travel-data' - * Correct: SELECT * FROM integration.\`travel-data\` +`SHOW TABLES returns empty or fails` + +* **Symptoms**: `SHOW TABLES FROM elasticsearch_conn` returns no results or fails. +* **Solution**: Use the information_schema alternative: + ```sql + SELECT table_name FROM information_schema.tables + WHERE table_schema = 'elasticsearch_conn'; + ``` -This [troubleshooting guide](https://www.elastic.co/guide/en/elasticsearch/reference/current/troubleshooting.html) provided by Elasticsearch might also be helpful. +## Limitations + +* **JOINs**: Not supported due to Elasticsearch architecture limitations. +* **Complex Subqueries**: Limited by Elasticsearch's SQL capabilities. +* **Real-time Data**: Elasticsearch has near-real-time search characteristics due to refresh intervals. \ No newline at end of file diff --git a/mindsdb/integrations/handlers/elasticsearch_handler/__about__.py b/mindsdb/integrations/handlers/elasticsearch_handler/__about__.py index 38a6c79dce6..9fa6bf695bc 100644 --- a/mindsdb/integrations/handlers/elasticsearch_handler/__about__.py +++ b/mindsdb/integrations/handlers/elasticsearch_handler/__about__.py @@ -1,8 +1,8 @@ __title__ = "MindsDB Elasticsearch handler" __package_name__ = "mindsdb_elasticsearch_handler" -__version__ = "0.0.1" -__description__ = "MindsDB handler for Elasticsearch" -__author__ = "Minura Punchihewa" +__version__ = "0.1.0" +__description__ = "MindsDB handler for Elasticsearch with SQL-first query execution" +__author__ = "MindsDB Inc" __github__ = "https://github.com/mindsdb/mindsdb" __pypi__ = "https://pypi.org/project/mindsdb/" __license__ = "MIT" diff --git a/mindsdb/integrations/handlers/elasticsearch_handler/connection_args.py b/mindsdb/integrations/handlers/elasticsearch_handler/connection_args.py index 9857096337b..358051e4fc4 100644 --- a/mindsdb/integrations/handlers/elasticsearch_handler/connection_args.py +++ b/mindsdb/integrations/handlers/elasticsearch_handler/connection_args.py @@ -29,6 +29,26 @@ "description": "The API key for authentication with the Elasticsearch server.", "secret": True, }, + ca_certs={ + "type": ARG_TYPE.STR, + "description": "Path to CA certificate file for SSL verification.", + }, + client_cert={ + "type": ARG_TYPE.STR, + "description": "Path to client certificate file for SSL authentication.", + }, + client_key={ + "type": ARG_TYPE.STR, + "description": "Path to client private key file for SSL authentication.", + }, + verify_certs={ + "type": ARG_TYPE.BOOL, + "description": "Whether to verify SSL certificates. Default: true", + }, + timeout={ + "type": ARG_TYPE.INT, + "description": "Request timeout in seconds. Default: 30", + }, ) connection_args_example = OrderedDict( diff --git a/mindsdb/integrations/handlers/elasticsearch_handler/elasticsearch_handler.py b/mindsdb/integrations/handlers/elasticsearch_handler/elasticsearch_handler.py index 3c7f2be6eb4..84273799b82 100644 --- a/mindsdb/integrations/handlers/elasticsearch_handler/elasticsearch_handler.py +++ b/mindsdb/integrations/handlers/elasticsearch_handler/elasticsearch_handler.py @@ -1,4 +1,5 @@ -from typing import Text, Dict, Optional +from typing import Text, Dict, Optional, List, Any +import json from elasticsearch import Elasticsearch from elasticsearch.exceptions import ( @@ -7,11 +8,19 @@ TransportError, RequestError, ) + +# ApiError is only available in Elasticsearch 8+ +try: + from elasticsearch.exceptions import ApiError +except ImportError: + ApiError = Exception # Fallback for ES 7.x compatibility + +# ESDialect: SQLAlchemy dialect for Elasticsearch, enables SQL query rendering +from es.elastic.sqlalchemy import ESDialect from pandas import DataFrame from mindsdb_sql_parser.ast.base import ASTNode from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender - -from mindsdb.integrations.libs.base import DatabaseHandler +from mindsdb.integrations.libs.base import MetaDatabaseHandler from mindsdb.integrations.libs.response import ( HandlerResponse as Response, HandlerStatusResponse as StatusResponse, @@ -23,28 +32,37 @@ logger = log.getLogger(__name__) -class ElasticsearchHandler(DatabaseHandler): +class ElasticsearchHandler(MetaDatabaseHandler): """ - This handler handles the connection and execution of SQL statements on Elasticsearch. + This handler handles the connection and execution of SQL statements on Elasticsearch + using a SQL-first architecture with automatic fallback capabilities. + + Features: + - SQL-first query execution with automatic Search API fallback + - Intelligent array field detection and JSON conversion + - SSL/TLS security configuration support + - Memory-efficient large dataset handling with pagination + - Comprehensive error handling and recovery mechanisms """ name = "elasticsearch" def __init__(self, name: Text, connection_data: Optional[Dict], **kwargs) -> None: """ - Initializes the handler. + Initializes the Elasticsearch handler with SQL-first query execution. Args: name (Text): The name of the handler instance. - connection_data (Dict): The connection data required to connect to the AWS (S3) account. - kwargs: Arbitrary keyword arguments. + connection_data (Dict): The connection data required to connect to the Elasticsearch cluster. + Should include hosts/cloud_id and authentication parameters. + **kwargs: Arbitrary keyword arguments. """ super().__init__(name) - self.connection_data = connection_data + self.connection_data = connection_data or {} self.kwargs = kwargs - self.connection = None self.is_connected = False + self._array_fields_cache: Dict[str, List[str]] = {} def __del__(self) -> None: """ @@ -55,67 +73,96 @@ def __del__(self) -> None: def connect(self) -> Elasticsearch: """ - Establishes a connection to the Elasticsearch host. + Establishes a connection to the Elasticsearch host with security configuration support. - Raises: - ValueError: If the expected connection parameters are not provided. + This method supports both on-premises and cloud Elasticsearch deployments with + SSL/TLS configuration and authentication options. Returns: elasticsearch.Elasticsearch: A connection object to the Elasticsearch host. + + Raises: + ValueError: If the expected connection parameters are not provided. + ConnectionError: If unable to establish connection to Elasticsearch. + AuthenticationException: If authentication fails. """ - if self.is_connected is True: + if self.is_connected: return self.connection - config = {} - - # Mandatory connection parameters. - if ("hosts" not in self.connection_data) and ("cloud_id" not in self.connection_data): - raise ValueError("Either the hosts or cloud_id parameter should be provided!") - - # Optional/Additional connection parameters. - optional_parameters = ["hosts", "cloud_id", "api_key"] - for parameter in optional_parameters: - if parameter in self.connection_data: - if parameter == "hosts": - config["hosts"] = self.connection_data[parameter].split(",") - else: - config[parameter] = self.connection_data[parameter] + # Validate required parameters + if not self.connection_data.get("hosts") and not self.connection_data.get("cloud_id"): + raise ValueError("Either 'hosts' or 'cloud_id' parameter must be provided") - # Ensure that if either user or password is provided, both are provided. - if ("user" in self.connection_data) != ("password" in self.connection_data): - raise ValueError("Both user and password should be provided if one of them is provided!") + config = {} - if "user" in self.connection_data: - config["basic_auth"] = ( - self.connection_data["user"], - self.connection_data["password"], - ) + # Connection parameters + if "hosts" in self.connection_data: + hosts_str = self.connection_data["hosts"] + hosts = hosts_str.split(",") + + # Validate host:port format + for host in hosts: + host = host.strip() + if ":" not in host: + raise ValueError( + f"Invalid host format '{host}'. Expected format: 'host:port' (e.g., 'localhost:9200')" + ) + # Additional validation: check port is numeric + try: + host_part, port_part = host.rsplit(":", 1) + int(port_part) # Validate port is numeric + except ValueError: + raise ValueError(f"Invalid port in host '{host}'. Port must be numeric") + + config["hosts"] = hosts + if "cloud_id" in self.connection_data: + config["cloud_id"] = self.connection_data["cloud_id"] + + # Authentication - API key takes precedence + if "api_key" in self.connection_data: + config["api_key"] = self.connection_data["api_key"] + # Skip user/password if API key is provided + else: + # Only check user/password if API key is not provided + user = self.connection_data.get("user") + password = self.connection_data.get("password") + if user and password: + config["http_auth"] = (user, password) + elif user or password: + raise ValueError("Both 'user' and 'password' must be provided together") + + # SSL/TLS configuration (secure by default) + config["verify_certs"] = self.connection_data.get("verify_certs", True) + if "ca_certs" in self.connection_data: + config["ca_certs"] = self.connection_data["ca_certs"] + if "client_cert" in self.connection_data: + config["client_cert"] = self.connection_data["client_cert"] + if "client_key" in self.connection_data: + config["client_key"] = self.connection_data["client_key"] + if "timeout" in self.connection_data: + config["timeout"] = self.connection_data["timeout"] try: - self.connection = Elasticsearch( - **config, - ) + self.connection = Elasticsearch(**config) self.is_connected = True return self.connection - except ConnectionError as conn_error: - logger.error(f"Connection error when connecting to Elasticsearch: {conn_error}") - raise - except AuthenticationException as auth_error: - logger.error(f"Authentication error when connecting to Elasticsearch: {auth_error}") + except (ConnectionError, AuthenticationException) as e: + logger.error(f"Connection failed: {e}") raise - except Exception as unknown_error: - logger.error(f"Unknown error when connecting to Elasticsearch: {unknown_error}") + except Exception as e: + logger.error(f"Unexpected connection error: {e}") raise def disconnect(self) -> None: """ Closes the connection to the Elasticsearch host if it's currently open. """ - if self.is_connected is False: + if not self.is_connected: return - - self.connection.close() - self.is_connected = False + try: + self.connection.close() + finally: + self.is_connected = False def check_connection(self) -> StatusResponse: """ @@ -125,81 +172,318 @@ def check_connection(self) -> StatusResponse: StatusResponse: An object containing the success status and an error message if an error occurs. """ response = StatusResponse(False) - need_to_close = self.is_connected is False + need_to_close = not self.is_connected try: connection = self.connect() - - # Execute a simple query to test the connection. + # Simple test query connection.sql.query(body={"query": "SELECT 1"}) response.success = True - # All exceptions are caught here to ensure that the connection is closed if an error occurs. except Exception as error: - logger.error(f"Error connecting to Elasticsearch, {error}!") + logger.error(f"Connection check failed: {error}") response.error_message = str(error) + if self.is_connected: + self.is_connected = False if response.success and need_to_close: self.disconnect() - elif not response.success and self.is_connected: - self.is_connected = False - return response def native_query(self, query: Text) -> Response: """ - Executes a native SQL query on the Elasticsearch host and returns the result. + Executes a native SQL query on the Elasticsearch host using SQL-first approach. + + This method uses a dual-strategy approach: + 1. Primary: Uses Elasticsearch SQL API for performance and compatibility + 2. Fallback: Automatically switches to Search API for array-containing indexes + 3. Handles pagination and large result sets Args: - query (str): The SQL query to be executed. + query (Text): The SQL query to be executed. Returns: Response: A response object containing the result of the query or an error message. """ - need_to_close = self.is_connected is False - + logger.debug(f"Executing query: {query[:100]}...") + need_to_close = not self.is_connected connection = self.connect() + try: + # Primary: Try SQL API first (standard approach) response = connection.sql.query(body={"query": query}) records = response["rows"] columns = response["columns"] - new_records = True - while new_records: + # Handle pagination for large result sets with safety limit + max_pages = 100 # Prevent infinite pagination + for _ in range(max_pages): + if not response.get("cursor"): + break + response = connection.sql.query(body={"query": query, "cursor": response["cursor"]}) + if not response["rows"]: + break + records.extend(response["rows"]) + + column_names = [col["name"] for col in columns] + if not records: + records = [[None] * len(column_names)] + + return Response(RESPONSE_TYPE.TABLE, data_frame=DataFrame(records, columns=column_names)) + + except (TransportError, RequestError, ApiError) as e: + error_msg = str(e).lower() + + # Intelligent fallback: Check if error is array-related + if any(keyword in error_msg for keyword in ["array", "nested", "object"]): + logger.debug(f"SQL API failed with array-related error, using Search API fallback: {e}") try: - if response["cursor"]: - response = connection.sql.query(body={"query": query, "cursor": response["cursor"]}) + return self._search_api_fallback(query) + except Exception as fallback_error: + logger.error(f"Search API fallback also failed: {fallback_error}") + return Response( + RESPONSE_TYPE.ERROR, error_message=f"Both SQL and Search APIs failed: {fallback_error}" + ) + + # Handle other SQL API errors + logger.error(f"SQL API error: {e}") + return Response(RESPONSE_TYPE.ERROR, error_message=str(e)) + + except Exception as e: + logger.error(f"Unexpected query error: {e}") + return Response(RESPONSE_TYPE.ERROR, error_message=str(e)) + + finally: + if need_to_close: + self.disconnect() + + def _search_api_fallback(self, query: str) -> Response: + """ + Search API fallback for array-containing indexes. - new_records = response["rows"] - records = records + new_records - except KeyError: - new_records = False + This method is automatically invoked when SQL API encounters array fields, + providing seamless query execution with proper array handling. - column_names = [column["name"] for column in columns] - if not records: - null_record = [None] * len(column_names) - records = [null_record] + Args: + query (str): Original SQL query that failed with SQL API + + Returns: + Response: Search results converted to tabular format with arrays as JSON strings + """ + # Simple query parsing (only what's needed for Search API) + index_name = self._extract_table_name(query) + if not index_name: + raise ValueError("Could not determine index name from query") + + # Extract LIMIT from query if present + limit = self._extract_limit(query) + if limit is None: + limit = 10000 # Default maximum documents to fetch to prevent memory issues + + # Execute search with pagination + scroll_id = None + try: + batch_size = min(1000, limit) # Use smaller batch size if limit is small + search_body = { + "size": batch_size, + "query": {"match_all": {}}, + } + + response = self.connection.search(index=index_name, body=search_body, scroll="5m") + + records = [] + all_columns = set() + scroll_id = response.get("_scroll_id") + processed_count = 0 + + # Process results in batches with explicit limit + max_batches = (limit // batch_size) + 1 # Calculate max batches needed + for _ in range(max_batches): + hits = response.get("hits", {}).get("hits", []) + if not hits: + break + + for hit in hits: + if processed_count >= limit: + break + + doc = hit.get("_source", {}) + if doc: + converted_doc = self._convert_arrays_to_strings(doc) + flattened_doc = self._flatten_document(converted_doc) + if flattened_doc: + records.append(flattened_doc) + all_columns.update(flattened_doc.keys()) + processed_count += 1 + + # Get next batch if we haven't reached the limit + if not scroll_id or processed_count >= limit: + break + try: + response = self.connection.scroll(scroll_id=scroll_id, scroll="5m") + except Exception: + break + + # Normalize records + columns = sorted(all_columns) if all_columns else ["no_data"] + normalized_records = [] + + for record in records: + normalized_records.append([record.get(col) for col in columns]) - response = Response( - RESPONSE_TYPE.TABLE, - data_frame=DataFrame(records, columns=column_names), + if not normalized_records: + normalized_records = [[None] * len(columns)] + + return Response(RESPONSE_TYPE.TABLE, data_frame=DataFrame(normalized_records, columns=columns)) + + except Exception as e: + raise Exception(f"Search API execution failed: {e}") + finally: + # Clean up scroll - ensures cleanup even if exceptions occur + if scroll_id: + try: + self.connection.clear_scroll(scroll_id=scroll_id) + except Exception: + pass + + def _extract_table_name(self, query: str) -> Optional[str]: + """ + Extracts the table/index name from a SQL query. + + Args: + query (str): SQL query string + + Returns: + Optional[str]: The extracted table name, or None if not found + """ + import re + + match = re.search(r'FROM\s+([`"]?)([^`"\s]+)\1', query, re.IGNORECASE) + return match.group(2) if match else None + + def _extract_limit(self, query: str) -> Optional[int]: + """ + Extracts the LIMIT value from a SQL query. + + Args: + query (str): SQL query string + + Returns: + Optional[int]: The extracted limit value, or None if not found + """ + import re + + match = re.search(r"LIMIT\s+(\d+)", query, re.IGNORECASE) + if match: + try: + return int(match.group(1)) + except ValueError: + return None + return None + + def _detect_array_fields(self, index_name: str) -> List[str]: + """ + Detects array fields in the specified index with caching. + + Args: + index_name (str): The name of the index to analyze + + Returns: + List[str]: List of field paths that contain arrays + """ + if index_name in self._array_fields_cache: + return self._array_fields_cache[index_name] + + array_fields = [] + try: + response = self.connection.search( + index=index_name, body={"size": 5, "query": {"match_all": {}}}, _source=True ) - except (TransportError, RequestError) as transport_or_request_error: - logger.error(f"Error running query: {query} on Elasticsearch, {transport_or_request_error}!") - response = Response(RESPONSE_TYPE.ERROR, error_message=str(transport_or_request_error)) - except Exception as unknown_error: - logger.error(f"Unknown error running query: {query} on Elasticsearch, {unknown_error}!") - response = Response(RESPONSE_TYPE.ERROR, error_message=str(unknown_error)) + for hit in response.get("hits", {}).get("hits", []): + doc = hit.get("_source", {}) + array_fields.extend(self._find_arrays_in_doc(doc)) - if need_to_close is True: - self.disconnect() + array_fields = list(set(array_fields)) - return response + # Only cache non-empty results to prevent false negatives + if array_fields: + self._array_fields_cache[index_name] = array_fields + + except Exception as e: + logger.error(f"Array field detection failed for {index_name}: {e}") + + return array_fields + + def _find_arrays_in_doc(self, doc: Any, prefix: str = "") -> List[str]: + """ + Recursively finds array fields in a document. + + Args: + doc (Any): The document to analyze + prefix (str): Current field path prefix for nested fields + + Returns: + List[str]: List of field paths containing arrays + """ + arrays = [] + if isinstance(doc, dict): + for key, value in doc.items(): + field_path = f"{prefix}.{key}" if prefix else key + if isinstance(value, list): + arrays.append(field_path) + elif isinstance(value, dict): + arrays.extend(self._find_arrays_in_doc(value, field_path)) + return arrays + + def _convert_arrays_to_strings(self, obj: Any) -> Any: + """ + Converts arrays to JSON strings for SQL compatibility. + + Args: + obj (Any): Object that may contain arrays + + Returns: + Any: Object with arrays converted to JSON strings + """ + if isinstance(obj, list): + try: + return json.dumps(obj, ensure_ascii=False, default=str) + except (TypeError, ValueError): + return str(obj) + elif isinstance(obj, dict): + return {k: self._convert_arrays_to_strings(v) for k, v in obj.items()} + return obj + + def _flatten_document(self, doc: Dict, prefix: str = "", max_depth: int = 10, _depth: int = 0) -> Dict: + """ + Flattens nested documents with depth protection to prevent stack overflow. + + Args: + doc (Dict): Document to flatten + prefix (str): Field path prefix for nested fields + max_depth (int): Maximum recursion depth to prevent stack overflow + _depth (int): Current recursion depth (internal use) + + Returns: + Dict: Flattened document with dot-notation field names + """ + if not isinstance(doc, dict) or _depth >= max_depth: + return {prefix or "value": str(doc)} + + flattened = {} + for key, value in doc.items(): + field_path = f"{prefix}.{key}" if prefix else key + if isinstance(value, dict): + flattened.update(self._flatten_document(value, field_path, max_depth, _depth + 1)) + else: + flattened[field_path] = value + + return flattened def query(self, query: ASTNode) -> Response: """ - Executes a SQL query represented by an ASTNode on the Elasticsearch host and retrieves the data. + Executes a SQL query represented by an ASTNode on the Elasticsearch host. Args: query (ASTNode): An ASTNode representing the SQL query to be executed. @@ -207,12 +491,14 @@ def query(self, query: ASTNode) -> Response: Returns: Response: The response from the `native_query` method, containing the result of the SQL query execution. """ - # TODO: Add support for other query types. - # Use postgresql dialect for SQL rendering - Elasticsearch SQL is ANSI-compatible - renderer = SqlalchemyRender("postgresql") - query_str = renderer.get_string(query, with_failback=True) - logger.debug(f"Executing SQL query: {query_str}") - return self.native_query(query_str) + try: + renderer = SqlalchemyRender(ESDialect) + query_str = renderer.get_string(query, with_failback=True) + logger.debug(f"Executing AST query as SQL: {query_str}") + return self.native_query(query_str) + except Exception as e: + logger.error(f"AST query execution failed: {e}") + return Response(RESPONSE_TYPE.ERROR, error_message=str(e)) def get_tables(self) -> Response: """ @@ -220,19 +506,17 @@ def get_tables(self) -> Response: Returns: Response: A response object containing a list of tables (indexes) in the Elasticsearch host. + System indices (starting with '.') are filtered out. """ - query = """ - SHOW TABLES - """ + query = "SHOW TABLES" result = self.native_query(query) - df = result.data_frame - - # Remove indices that are system indices: These are indices that start with a period. - df = df[~df["name"].str.startswith(".")] - - df = df.drop(["catalog", "kind"], axis=1) - result.data_frame = df.rename(columns={"name": "table_name", "type": "table_type"}) + if result.type == RESPONSE_TYPE.TABLE: + df = result.data_frame + # Filter out system indexes (starting with .) + df = df[~df["name"].str.startswith(".")] + df = df.drop(["catalog", "kind"], axis=1, errors="ignore") + result.data_frame = df.rename(columns={"name": "table_name", "type": "table_type"}) return result @@ -241,24 +525,458 @@ def get_columns(self, table_name: Text) -> Response: Retrieves column (field) details for a specified table (index) in the Elasticsearch host. Args: - table_name (str): The name of the table for which to retrieve column information. + table_name (Text): The name of the table for which to retrieve column information. + + Returns: + Response: A response object containing the column details. Raises: ValueError: If the 'table_name' is not a valid string. + """ + if not table_name or not isinstance(table_name, str): + raise ValueError("Table name must be a non-empty string") + + query = f"DESCRIBE {table_name}" + result = self.native_query(query) + + if result.type == RESPONSE_TYPE.TABLE: + df = result.data_frame + df = df.drop("mapping", axis=1, errors="ignore") + result.data_frame = df.rename(columns={"column": "COLUMN_NAME", "type": "DATA_TYPE"}) + + return result + + def meta_get_column_statistics_for_table( + self, table_name: str, column_names: Optional[List[str]] = None + ) -> Response: + """ + Retrieves statistics for columns in the specified Elasticsearch index. + + This method uses Elasticsearch aggregations to efficiently gather statistics in a single query: + - Numeric fields: min, max (via stats aggregation) + - Keyword fields: distinct count (cardinality) + - Text fields: distinct count (cardinality on .keyword multi-field) + - Date fields: min, max (via stats aggregation, as timestamps) + - All fields: null percentage (missing values / total docs) + - Object/nested fields: excluded from aggregations, null percentage only + - Nested/array fields: treated as text (cardinality on JSON string representation) + + Implementation Details: + - Text fields use the .keyword multi-field suffix for aggregations + - Object and nested types are skipped for cardinality (not aggregatable) + - If aggregations fail (e.g., text field without .keyword), returns schema with NULL values + - All statistics gathered in a single Elasticsearch search query for performance + + Args: + table_name (str): The name of the index to analyze. + column_names (Optional[List[str]]): Specific column names. If None, returns statistics for all columns. Returns: - Response: A response object containing the column details. + Response: DataFrame with columns: + - TABLE_NAME: Index name + - COLUMN_NAME: Field name + - DATA_TYPE: Elasticsearch field type + - NULL_PERCENTAGE: Percentage of documents missing this field (0.0-100.0) + - DISTINCT_VALUES_COUNT: Approximate count of unique values (0 if not aggregatable) + - MINIMUM_VALUE: Minimum value (numeric/date fields, None otherwise) + - MAXIMUM_VALUE: Maximum value (numeric/date fields, None otherwise) + + Raises: + ValueError: If table_name is invalid or column_names not found in index. + + Example: + >>> handler.meta_get_column_statistics_for_table('kibana_sample_data_flights') + >>> handler.meta_get_column_statistics_for_table('products', ['price', 'quantity']) """ if not table_name or not isinstance(table_name, str): - raise ValueError("Invalid table name provided.") + raise ValueError("Table name must be a non-empty string") + + logger.debug(f"Getting column statistics for {table_name}, columns: {column_names}") + need_to_close = not self.is_connected + connection = self.connect() + + try: + # Step 1: Get index mapping to determine field types + mapping_response = connection.indices.get_mapping(index=table_name) + + # Extract field mappings (handle both single and multi-index responses) + if table_name in mapping_response: + properties = mapping_response[table_name].get("mappings", {}).get("properties", {}) + else: + # For wildcard or first index in response + first_index = list(mapping_response.keys())[0] + properties = mapping_response[first_index].get("mappings", {}).get("properties", {}) + + if not properties: + logger.warning(f"No properties found for index {table_name}") + return Response( + RESPONSE_TYPE.TABLE, + data_frame=DataFrame( + columns=[ + "TABLE_NAME", + "COLUMN_NAME", + "DATA_TYPE", + "NULL_PERCENTAGE", + "DISTINCT_VALUES_COUNT", + "MINIMUM_VALUE", + "MAXIMUM_VALUE", + ] + ), + ) + + # Step 2: Flatten nested field mappings and filter by column_names if provided + fields_to_analyze = {} + self._extract_fields_from_mapping(properties, fields_to_analyze, prefix="") + + if column_names: + # Filter to only requested columns + filtered_fields = {} + for col_name in column_names: + if col_name not in fields_to_analyze: + raise ValueError(f"Column '{col_name}' not found in index '{table_name}'") + filtered_fields[col_name] = fields_to_analyze[col_name] + fields_to_analyze = filtered_fields + + # Step 3: Build comprehensive aggregation query + aggs = {} + for field_name, field_info in fields_to_analyze.items(): + field_type = field_info.get("type", "object") + safe_field_name = field_name.replace(".", "_") + + # Skip object/nested types - they don't support aggregations + if field_type in ["object", "nested"]: + continue + + # Determine aggregation field (text fields need .keyword suffix) + agg_field = field_name + if field_type == "text": + # Check if .keyword multi-field exists in mapping + multi_fields = field_info.get("fields", {}) + if "keyword" in multi_fields: + # Text field has .keyword multi-field for aggregations + agg_field = f"{field_name}.keyword" + else: + # Text field without .keyword - skip this field + # (fielddata would need to be enabled, which is not recommended for text fields) + logger.debug(f"Text field '{field_name}' has no .keyword multi-field, skipping") + continue + + # Cardinality aggregation for distinct count + aggs[f"{safe_field_name}_cardinality"] = { + "cardinality": { + "field": agg_field, + "precision_threshold": 3000, # Improves performance on large datasets + } + } + + # Missing aggregation for null count + aggs[f"{safe_field_name}_missing"] = {"missing": {"field": field_name}} + + # Stats aggregation for numeric and date fields + if field_type in [ + "long", + "integer", + "short", + "byte", + "double", + "float", + "half_float", + "scaled_float", + "date", + ]: + aggs[f"{safe_field_name}_stats"] = {"stats": {"field": field_name}} + + # Step 4: Execute single aggregation query for all statistics + search_body = { + "size": 0, # We only need aggregations, not documents + "aggs": aggs, + } + + logger.debug(f"Executing aggregation query with {len(aggs)} aggregations") + + # Execute aggregation query with error handling for field-specific failures + try: + agg_response = connection.search(index=table_name, body=search_body) + except Exception as search_error: + # If aggregation fails (e.g., text field without .keyword), log and retry without problematic aggs + error_msg = str(search_error).lower() + if "fielddata" in error_msg or "keyword" in error_msg or "text" in error_msg: + logger.warning(f"Aggregation failed, possibly due to text field without fielddata: {search_error}") + # Return basic statistics without aggregations + stats_data = [] + for field_name, field_info in fields_to_analyze.items(): + stats_data.append( + { + "TABLE_NAME": table_name, + "COLUMN_NAME": field_name, + "DATA_TYPE": field_info.get("type", "object"), + "NULL_PERCENTAGE": None, + "DISTINCT_VALUES_COUNT": None, + "MINIMUM_VALUE": None, + "MAXIMUM_VALUE": None, + } + ) + return Response(RESPONSE_TYPE.TABLE, data_frame=DataFrame(stats_data)) + else: + raise + + # Step 5: Parse aggregation results into statistics + # Get total document count for NULL_PERCENTAGE calculation + total_docs = agg_response.get("hits", {}).get("total", {}) + if isinstance(total_docs, dict): + total_doc_count = total_docs.get("value", 0) + else: + total_doc_count = total_docs # ES 6.x returns int directly + + stats_data = [] + for field_name, field_info in fields_to_analyze.items(): + field_type = field_info.get("type", "object") + safe_field_name = field_name.replace(".", "_") + + aggregations = agg_response.get("aggregations", {}) + + # Extract cardinality (distinct count) + cardinality_key = f"{safe_field_name}_cardinality" + cardinality_result = aggregations.get(cardinality_key, {}) + distinct_count = int(cardinality_result.get("value", 0)) if cardinality_result else 0 + + # Extract missing count and calculate NULL_PERCENTAGE + missing_key = f"{safe_field_name}_missing" + missing_result = aggregations.get(missing_key, {}) + null_count = missing_result.get("doc_count", 0) if missing_result else 0 + null_percentage = (null_count / total_doc_count * 100.0) if total_doc_count > 0 else 0.0 + + # Extract stats for numeric/date fields + stats_key = f"{safe_field_name}_stats" + stats = aggregations.get(stats_key, {}) + + min_val = stats.get("min") if stats else None + max_val = stats.get("max") if stats else None + + stats_data.append( + { + "TABLE_NAME": table_name, + "COLUMN_NAME": field_name, + "DATA_TYPE": field_type, + "NULL_PERCENTAGE": null_percentage, + "DISTINCT_VALUES_COUNT": distinct_count, + "MINIMUM_VALUE": min_val, + "MAXIMUM_VALUE": max_val, + } + ) + + result_df = DataFrame(stats_data) + logger.debug(f"Retrieved statistics for {len(stats_data)} fields") + + return Response(RESPONSE_TYPE.TABLE, data_frame=result_df) + + except ValueError: + # Re-raise ValueError (e.g., invalid column name) as-is + raise + except Exception as e: + logger.error(f"Failed to get column statistics: {e}") + return Response(RESPONSE_TYPE.ERROR, error_message=str(e)) - query = f""" - DESCRIBE {table_name} + finally: + if need_to_close: + self.disconnect() + + def _extract_fields_from_mapping(self, properties: Dict, fields: Dict, prefix: str = "") -> None: """ + Recursively extracts field definitions from Elasticsearch mapping. + + This helper method flattens nested object and nested type fields into dot-notation paths. + + Args: + properties (Dict): Field properties from mapping + fields (Dict): Output dictionary to populate with field definitions + prefix (str): Current field path prefix for nested fields + """ + for field_name, field_def in properties.items(): + full_field_name = f"{prefix}.{field_name}" if prefix else field_name + field_type = field_def.get("type") + + if field_type: + # Regular field with a type + fields[full_field_name] = field_def + elif "properties" in field_def: + # Nested object - recurse into it + self._extract_fields_from_mapping(field_def["properties"], fields, full_field_name) + else: + # Field without type or properties (treat as object) + fields[full_field_name] = {"type": "object"} + + def meta_get_primary_keys(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves the primary keys for the specified Elasticsearch indices. + + In Elasticsearch, the _id field serves as the implicit primary key for each document. + This method always returns _id as the primary key for each table. + + Args: + table_names (Optional[List[str]]): List of index names. If None, returns primary keys for all tables. + + Returns: + Response: DataFrame with columns: + - TABLE_NAME: Name of the index + - CONSTRAINT_NAME: Name of the primary key constraint + - COLUMN_NAME: The column name (_id) + + Example: + >>> handler.meta_get_primary_keys(['products', 'orders']) + # Returns: TABLE_NAME='products', CONSTRAINT_NAME='PRIMARY', COLUMN_NAME='_id' + # TABLE_NAME='orders', CONSTRAINT_NAME='PRIMARY', COLUMN_NAME='_id' + """ + logger.debug(f"Getting primary keys for tables: {table_names}") + + # If no table names specified, get all tables + if not table_names: + tables_response = self.get_tables() + if tables_response.type == RESPONSE_TYPE.ERROR: + return tables_response + table_names = tables_response.data_frame["TABLE_NAME"].tolist() + + # Elasticsearch always uses _id as the document identifier (primary key) + pk_data = [] + for table_name in table_names: + pk_data.append({"TABLE_NAME": table_name, "CONSTRAINT_NAME": "PRIMARY", "COLUMN_NAME": "_id"}) + + return Response(RESPONSE_TYPE.TABLE, data_frame=DataFrame(pk_data)) + + def meta_get_foreign_keys(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves foreign keys for the specified Elasticsearch indices. + + Elasticsearch is a NoSQL document store and does not support foreign key constraints. + This method always returns an empty DataFrame with the proper structure. + + Args: + table_names (Optional[List[str]]): List of index names. If None, applies to all tables. + + Returns: + Response: Empty DataFrame with columns: + - CHILD_TABLE_NAME: The table containing the foreign key + - CHILD_COLUMN_NAME: The column name + - PARENT_TABLE_NAME: The referenced table name + - PARENT_COLUMN_NAME: The referenced column name + - CONSTRAINT_NAME: Foreign key constraint name + + Example: + >>> handler.meta_get_foreign_keys(['products']) + # Returns: Empty DataFrame (NoSQL has no foreign keys) + """ + logger.debug(f"Getting foreign keys for tables: {table_names} (NoSQL - will return empty)") + + # Elasticsearch is NoSQL and doesn't have foreign key constraints + return Response( + RESPONSE_TYPE.TABLE, + data_frame=DataFrame( + columns=[ + "CHILD_TABLE_NAME", + "CHILD_COLUMN_NAME", + "PARENT_TABLE_NAME", + "PARENT_COLUMN_NAME", + "CONSTRAINT_NAME", + ] + ), + ) + + def meta_get_tables(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves metadata for tables (indices) in the Elasticsearch host. + + Args: + table_names (Optional[List[str]]): List of specific table names to retrieve. + If None, returns all non-system tables. + + Returns: + Response: DataFrame with columns: + - TABLE_NAME: Name of the index + - TABLE_TYPE: Type of table (always 'BASE TABLE' for Elasticsearch) + + Example: + >>> handler.meta_get_tables(['products', 'orders']) + >>> handler.meta_get_tables() # Returns all tables + """ + logger.debug(f"Getting table metadata for: {table_names}") + + # Get all tables using SHOW TABLES + query = "SHOW TABLES" result = self.native_query(query) + if result.type != RESPONSE_TYPE.TABLE: + return result + df = result.data_frame - df = df.drop("mapping", axis=1) - result.data_frame = df.rename(columns={"column": "column_name", "type": "data_type"}) + # Filter out system indexes (starting with .) + df = df[~df["name"].str.startswith(".")] + + # Filter by requested table names if provided + if table_names: + df = df[df["name"].isin(table_names)] + + # Drop unnecessary columns and rename to match spec + df = df.drop(["catalog", "kind"], axis=1, errors="ignore") + df = df.rename(columns={"name": "TABLE_NAME", "type": "TABLE_TYPE"}) + + result.data_frame = df return result + + def meta_get_columns(self, table_names: Optional[List[str]] = None) -> Response: + """ + Retrieves column metadata for tables (indices) in the Elasticsearch host. + + Args: + table_names (Optional[List[str]]): List of specific table names to retrieve columns for. + If None, returns columns for all tables. + + Returns: + Response: DataFrame with columns: + - TABLE_NAME: Name of the index + - COLUMN_NAME: Name of the field/column + - DATA_TYPE: Elasticsearch data type + + Example: + >>> handler.meta_get_columns(['products']) + >>> handler.meta_get_columns() # Returns columns for all tables + """ + logger.debug(f"Getting column metadata for tables: {table_names}") + + # If no table names specified, get all tables first + if not table_names: + tables_response = self.meta_get_tables() + if tables_response.type == RESPONSE_TYPE.ERROR: + return tables_response + table_names = tables_response.data_frame["TABLE_NAME"].tolist() + + # Collect columns for each table + all_columns_data = [] + for table_name in table_names: + try: + query = f"DESCRIBE {table_name}" + result = self.native_query(query) + + if result.type == RESPONSE_TYPE.TABLE: + df = result.data_frame + df = df.drop("mapping", axis=1, errors="ignore") + df = df.rename(columns={"column": "COLUMN_NAME", "type": "DATA_TYPE"}) + # Add TABLE_NAME column + df["TABLE_NAME"] = table_name + all_columns_data.append(df) + except Exception as e: + logger.warning(f"Failed to get columns for table {table_name}: {e}") + continue + + # Combine all results + if all_columns_data: + combined_df = DataFrame() + for df in all_columns_data: + combined_df = combined_df._append(df, ignore_index=True) if not combined_df.empty else df + # Reorder columns to match spec + combined_df = combined_df[["TABLE_NAME", "COLUMN_NAME", "DATA_TYPE"]] + return Response(RESPONSE_TYPE.TABLE, data_frame=combined_df) + else: + return Response( + RESPONSE_TYPE.TABLE, data_frame=DataFrame(columns=["TABLE_NAME", "COLUMN_NAME", "DATA_TYPE"]) + ) diff --git a/mindsdb/integrations/handlers/elasticsearch_handler/requirements.txt b/mindsdb/integrations/handlers/elasticsearch_handler/requirements.txt index 5b0adfd5730..35b2d6333ae 100644 --- a/mindsdb/integrations/handlers/elasticsearch_handler/requirements.txt +++ b/mindsdb/integrations/handlers/elasticsearch_handler/requirements.txt @@ -1,2 +1,2 @@ -elasticsearch>=8.0.0,<9.0.0 -urllib3>=2.6.0 # not directly required, pinned by Snyk to avoid a vulnerability +elasticsearch>=7.13.4,<9.0.0 +elasticsearch-dbapi>=0.2.9 \ No newline at end of file From 22b1bd9831ba15a7ca5729cd7e1eca4aa6593e4b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 15:35:52 +0300 Subject: [PATCH 067/125] Bump pyjwt from 2.10.1 to 2.12.0 in /requirements (#12293) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 96af4155038..ca331360bb1 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -35,7 +35,7 @@ a2wsgi ~= 1.10.10 # WSGI wrapper for flask+starlette starlette>=0.49.1 sse-starlette==2.3.3 pydantic_core>=2.33.2 -pyjwt==2.10.1 +pyjwt==2.12.0 # files reading pymupdf==1.25.2 filetype From b10539d88026235388aab59ca92b796d66d22422 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 15:37:57 +0300 Subject: [PATCH 068/125] Bump nltk from 3.9.1 to 3.9.3 in /mindsdb/integrations/handlers/huggingface_handler (#12276) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../integrations/handlers/huggingface_handler/requirements.txt | 2 +- .../handlers/huggingface_handler/requirements_cpu.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mindsdb/integrations/handlers/huggingface_handler/requirements.txt b/mindsdb/integrations/handlers/huggingface_handler/requirements.txt index b70a302214c..f4291850dcf 100644 --- a/mindsdb/integrations/handlers/huggingface_handler/requirements.txt +++ b/mindsdb/integrations/handlers/huggingface_handler/requirements.txt @@ -1,7 +1,7 @@ # NOTE: Any changes made here need to be made to requirements_cpu.txt as well datasets==2.16.1 evaluate==0.4.3 -nltk==3.9.1 +nltk==3.9.3 huggingface-hub==0.29.3 torch==2.8.0 transformers >= 4.42.4 diff --git a/mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt b/mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt index 7a4e0de6084..b60dc5172ae 100644 --- a/mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt +++ b/mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt @@ -1,7 +1,7 @@ # Needs to be installed with `pip install --extra-index-url https://download.pytorch.org/whl/ .[huggingface_cpu]` datasets==2.16.1 evaluate==0.4.3 -nltk==3.9.1 +nltk==3.9.3 huggingface-hub==0.29.3 torch==2.8.0+cpu transformers >= 4.42.4 \ No newline at end of file From 60d5a77499f3f50d4439308141e15b9addafa755 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 15:39:04 +0300 Subject: [PATCH 069/125] Bump basic-ftp from 5.1.0 to 5.2.0 in /docs (#12247) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> From f1f39c24f341579d921a692e823ca08cfe0a5fa2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 15:42:13 +0300 Subject: [PATCH 070/125] Bump protobuf from 4.25.8 to 5.29.6 in /mindsdb/integrations/handlers/lindorm_handler (#12201) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- mindsdb/integrations/handlers/lindorm_handler/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/lindorm_handler/requirements.txt b/mindsdb/integrations/handlers/lindorm_handler/requirements.txt index 526500be75b..0c562cda190 100644 --- a/mindsdb/integrations/handlers/lindorm_handler/requirements.txt +++ b/mindsdb/integrations/handlers/lindorm_handler/requirements.txt @@ -1,3 +1,3 @@ pyphoenix phoenixdb -protobuf==4.25.8 \ No newline at end of file +protobuf==5.29.6 \ No newline at end of file From 9452142548b6b874b59db5df773300c6b0964b63 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 08:11:08 -0500 Subject: [PATCH 071/125] Bump ajv from 8.17.1 to 8.18.0 in /docs (#12240) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> From 8bac3072cf59daeefb8cf3c82af4bf3a72b2a659 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Fri, 20 Mar 2026 16:13:17 +0300 Subject: [PATCH 072/125] Fix BigQuery test (#12313) --- tests/unit/handlers/test_bigquery.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/handlers/test_bigquery.py b/tests/unit/handlers/test_bigquery.py index 0f67b430ea0..37eb80cb75e 100644 --- a/tests/unit/handlers/test_bigquery.py +++ b/tests/unit/handlers/test_bigquery.py @@ -78,7 +78,7 @@ def test_native_query(self): self.handler.connect = MagicMock(return_value=mock_conn) mock_query = MagicMock() - mock_query.to_dataframe.return_value = None + mock_query.to_dataframe.return_value = pd.DataFrame({"col": [1, 2, 3]}) mock_conn.query.return_value = mock_query query_str = "SELECT * FROM table" @@ -89,8 +89,7 @@ def test_native_query(self): mock_query_job_config_instance = mock_query_job_config.return_value data = self.handler.native_query(query_str) mock_conn.query.assert_called_once_with(query_str, job_config=mock_query_job_config_instance) - assert isinstance(data, DataHandlerResponse) - self.assertFalse(data.error_code) + assert isinstance(data, TableResponse) def test_native_query_empty_select_returns_table(self): mock_conn = MagicMock() From 113707593beb672183fa92bc9962063ebcb164e7 Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Fri, 20 Mar 2026 13:15:20 +0000 Subject: [PATCH 073/125] Fix Postgres column stats UnboundLocalError on errors (#12253) --- .../postgres_handler/postgres_handler.py | 52 ++++++++++--------- tests/unit/handlers/test_postgres.py | 9 ++++ 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py b/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py index 249c31a842f..a3456a8e95a 100644 --- a/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py +++ b/mindsdb/integrations/handlers/postgres_handler/postgres_handler.py @@ -758,31 +758,33 @@ def meta_get_column_statistics(self, table_names: Optional[list] = None) -> Resp result = self.native_query(query) - if result.type == RESPONSE_TYPE.TABLE and result.data_frame is not None: - df = result.data_frame - - # Extract min/max from histogram bounds - def extract_min_max(histogram_str): - if histogram_str and str(histogram_str) != "nan": - clean = str(histogram_str).strip("{}") - if clean: - values = clean.split(",") - min_val = values[0].strip(" \"'") if values else None - max_val = values[-1].strip(" \"'") if values else None - return min_val, max_val - return None, None - - min_max_values = df["histogram_bounds"].apply(extract_min_max) - df["MINIMUM_VALUE"] = min_max_values.apply(lambda x: x[0]) - df["MAXIMUM_VALUE"] = min_max_values.apply(lambda x: x[1]) - - # Convert most_common_values and most_common_freqs to arrays. - df["MOST_COMMON_VALUES"] = df["most_common_values"].apply( - lambda x: x.strip("{}").split(",") if isinstance(x, str) else [] - ) - df["MOST_COMMON_FREQUENCIES"] = df["most_common_frequencies"].apply( - lambda x: x.strip("{}").split(",") if isinstance(x, str) else [] - ) + if result.type != RESPONSE_TYPE.TABLE or result.data_frame is None: + return result + + df = result.data_frame + + # Extract min/max from histogram bounds + def extract_min_max(histogram_str): + if histogram_str and str(histogram_str) != "nan": + clean = str(histogram_str).strip("{}") + if clean: + values = clean.split(",") + min_val = values[0].strip(" \"'") if values else None + max_val = values[-1].strip(" \"'") if values else None + return min_val, max_val + return None, None + + min_max_values = df["histogram_bounds"].apply(extract_min_max) + df["MINIMUM_VALUE"] = min_max_values.apply(lambda x: x[0]) + df["MAXIMUM_VALUE"] = min_max_values.apply(lambda x: x[1]) + + # Convert most_common_values and most_common_freqs to arrays. + df["MOST_COMMON_VALUES"] = df["most_common_values"].apply( + lambda x: x.strip("{}").split(",") if isinstance(x, str) else [] + ) + df["MOST_COMMON_FREQUENCIES"] = df["most_common_frequencies"].apply( + lambda x: x.strip("{}").split(",") if isinstance(x, str) else [] + ) result.data_frame = df.drop(columns=["histogram_bounds", "most_common_values", "most_common_frequencies"]) diff --git a/tests/unit/handlers/test_postgres.py b/tests/unit/handlers/test_postgres.py index dc6d8c64569..f8eafb5849a 100644 --- a/tests/unit/handlers/test_postgres.py +++ b/tests/unit/handlers/test_postgres.py @@ -363,6 +363,15 @@ def test_insert_respects_existing_column_case(self): self.assertIn('"Id"', executed_copy) self.assertIn('"Amount"', executed_copy) + def test_meta_get_column_statistics_returns_non_table_response(self): + error_response = Response(RESPONSE_TYPE.ERROR, error_message="boom") + self.handler.native_query = MagicMock(return_value=error_response) + + result = self.handler.meta_get_column_statistics() + + self.assertIs(result, error_response) + self.handler.native_query.assert_called_once() + def test_cast_dtypes(self): """ Tests the _cast_dtypes method to ensure it correctly converts PostgreSQL types to pandas types From d882ad7210ae00564b79f401b17a52f7244f994e Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Fri, 20 Mar 2026 16:33:18 +0300 Subject: [PATCH 074/125] Fix DataBricks handler test (#12315) --- tests/unit/handlers/test_databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/handlers/test_databricks.py b/tests/unit/handlers/test_databricks.py index 74f62a731e7..659a389a1ed 100644 --- a/tests/unit/handlers/test_databricks.py +++ b/tests/unit/handlers/test_databricks.py @@ -215,7 +215,7 @@ def test_get_tables(self): self.handler.native_query.assert_called_once_with(expected_query) def test_get_tables_returns_non_table_response_without_transform(self): - expected = Response(RESPONSE_TYPE.ERROR, error_message="boom") + expected = ErrorResponse(error_message="boom") self.handler.native_query = MagicMock(return_value=expected) result = self.handler.get_tables() From 49be8fd776169f92f70d14d7f704c824ee223b75 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 20 Mar 2026 17:05:00 +0300 Subject: [PATCH 075/125] fix --- tests/unit/handlers/test_postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/handlers/test_postgres.py b/tests/unit/handlers/test_postgres.py index f8eafb5849a..a0e3adc1335 100644 --- a/tests/unit/handlers/test_postgres.py +++ b/tests/unit/handlers/test_postgres.py @@ -364,7 +364,7 @@ def test_insert_respects_existing_column_case(self): self.assertIn('"Amount"', executed_copy) def test_meta_get_column_statistics_returns_non_table_response(self): - error_response = Response(RESPONSE_TYPE.ERROR, error_message="boom") + error_response = ErrorResponse(error_message="boom") self.handler.native_query = MagicMock(return_value=error_response) result = self.handler.meta_get_column_statistics() From 47cc3109bdbe31fd00cfae04fe15c82841a0308f Mon Sep 17 00:00:00 2001 From: RITWICK RAJ MAKHAL Date: Fri, 20 Mar 2026 20:03:51 +0530 Subject: [PATCH 076/125] Update strapi handler to support v5 (#11862) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michael Olayemi Olawepo <154475559+sejubar@users.noreply.github.com> Co-authored-by: andrew Co-authored-by: April I. Murphy <36110273+aimurphy@users.noreply.github.com> Co-authored-by: Minura Punchihewa <49385643+MinuraPunchihewa@users.noreply.github.com> Co-authored-by: Konstantin Sivakov Co-authored-by: martyna-mindsdb <109554435+martyna-mindsdb@users.noreply.github.com> Co-authored-by: Sebastián Tobón Hernández --- docs/integrations/app-integrations/strapi.mdx | 8 +- .../handlers/strapi_handler/README.md | 10 +- .../handlers/strapi_handler/__about__.py | 2 +- .../handlers/strapi_handler/strapi_handler.py | 149 ++++--- .../handlers/strapi_handler/strapi_tables.py | 421 +++++++++++++++--- .../tests/test_strapi_handler.py | 192 +++++++- 6 files changed, 637 insertions(+), 145 deletions(-) diff --git a/docs/integrations/app-integrations/strapi.mdx b/docs/integrations/app-integrations/strapi.mdx index cdb66e063e4..e92d560d632 100644 --- a/docs/integrations/app-integrations/strapi.mdx +++ b/docs/integrations/app-integrations/strapi.mdx @@ -14,7 +14,7 @@ To use the Strapi Handler, initialize it with the following parameters: - `host`: Strapi server host. - `port`: Strapi server port (typically 1337). - `api_token`: Strapi server API token for authentication. -- `plural_api_ids`: List of plural API IDs for the collections. +- `endpoints`: List of collection endpoints. To get started, create a Strapi engine database with the following SQL command: @@ -25,7 +25,7 @@ PARAMETERS = { "host" : "", --- Host (can be an IP address or URL). "port" : "", --- Common port is 1337. "api_token": "", --- API token of the Strapi server. - "plural_api_ids" : [""] --- Plural API IDs of the collections. + "endpoints" : [""] --- Collection endpoints. }; ``` @@ -43,7 +43,7 @@ Filter data based on specific criteria: ```sql SELECT * FROM myshop. -WHERE id = +WHERE documentId = ''; ``` Insert new data into a collection: @@ -64,7 +64,7 @@ Modify existing data in a collection: ```sql UPDATE myshop. SET = , = , ... -WHERE id = ; +WHERE documentId = ''; ``` diff --git a/mindsdb/integrations/handlers/strapi_handler/README.md b/mindsdb/integrations/handlers/strapi_handler/README.md index 5595ce71c99..5a9341bd308 100644 --- a/mindsdb/integrations/handlers/strapi_handler/README.md +++ b/mindsdb/integrations/handlers/strapi_handler/README.md @@ -13,7 +13,7 @@ The Strapi handler is initialized with the following parameters: - `host` - the host of the Strapi server - `port` - the port of the Strapi server - `api_token` - the api token of the Strapi server -- `plural_api_ids` - the list of plural api ids of the collections +- `endpoints` - the list of collection endpoints ## Implemented Features @@ -36,7 +36,7 @@ PARAMETERS = { "host" : "", --- host, it can be an ip or an url. "port" : "", --- common port is 1337. "api_token": "", --- api token of the strapi server. - "plural_api_ids" : [""] --- plural api ids of the collections. + "endpoints" : [""] --- collection endpoints. }; ``` @@ -49,7 +49,7 @@ PARAMETERS = { "host" : "localhost", "port" : "1337", "api_token": "c56c000d867e95848c", - "plural_api_ids" : ["products", "sellers"] + "endpoints" : ["products", "sellers"] }; ``` @@ -84,7 +84,7 @@ Example: ```sql SELECT description, price FROM myshop.products -WHERE id = 1; +WHERE documentId = 'mvaprjyy72ayx7z4v592sdnr'; ``` --- @@ -140,7 +140,7 @@ Example UPDATE myshop.products SET price = 299, avaiablity = false -WHERE id = 1; +WHERE documentId = 'mvaprjyy72ayx7z4v592sdnr'; ``` Note: You only able to update data into the collection which has `update` permission. diff --git a/mindsdb/integrations/handlers/strapi_handler/__about__.py b/mindsdb/integrations/handlers/strapi_handler/__about__.py index 199f17ec162..d86a20889b2 100644 --- a/mindsdb/integrations/handlers/strapi_handler/__about__.py +++ b/mindsdb/integrations/handlers/strapi_handler/__about__.py @@ -1,6 +1,6 @@ __title__ = "MindsDB Strapi handler" __package_name__ = "mindsdb_strapi_handler" -__version__ = "0.0.1" +__version__ = "0.0.2" __description__ = "MindsDB handler for Strapi" __author__ = "Ritwick Raj Makhal" __github__ = "https://github.com/mindsdb/mindsdb" diff --git a/mindsdb/integrations/handlers/strapi_handler/strapi_handler.py b/mindsdb/integrations/handlers/strapi_handler/strapi_handler.py index 2338edc0f0d..555f3ce96e0 100644 --- a/mindsdb/integrations/handlers/strapi_handler/strapi_handler.py +++ b/mindsdb/integrations/handlers/strapi_handler/strapi_handler.py @@ -1,6 +1,6 @@ from mindsdb.integrations.handlers.strapi_handler.strapi_tables import StrapiTable from mindsdb.integrations.libs.api_handler import APIHandler -from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse +from mindsdb.integrations.libs.response import HandlerResponse, RESPONSE_TYPE, HandlerStatusResponse as StatusResponse from mindsdb_sql_parser import parse_sql from mindsdb.utilities import log import requests @@ -20,18 +20,40 @@ def __init__(self, name: str, **kwargs) -> None: """ super().__init__(name) - self.connection = None - self.is_connected = False - args = kwargs.get('connection_data', {}) - if 'host' in args and 'port' in args: - self._base_url = f"http://{args['host']}:{args['port']}" - if 'api_token' in args: - self._api_token = args['api_token'] - if 'plural_api_ids' in args: - self._plural_api_ids = args['plural_api_ids'] - # Registers tables for each collections in strapi - for pluralApiId in self._plural_api_ids: - self._register_table(table_name=pluralApiId, table_class=StrapiTable(handler=self, name=pluralApiId)) + self._connection_cache = {} + self._table_schemas = {} + + args = kwargs.get("connection_data", {}) + # Handle both complete URLs and host+port combinations + if "url" in args and args.get("url"): + # Complete URL provided (e.g., https://my-strapi.herokuapp.com) + self._base_url = args.get("url").rstrip("/") + elif "host" in args and args.get("host"): + # Traditional host + port setup + host = args.get("host", "") + port = args.get("port", "") + + # Determine protocol + protocol = "https" if args.get("ssl", False) else "http" + + if port: + self._base_url = f"{protocol}://{host}:{port}" + else: + self._base_url = f"{protocol}://{host}" + else: + self._base_url = None + self._api_token = args.get("api_token") + self._endpoints = args.get("endpoints", []) + + self._connection_key = f"{self._base_url}_{self._api_token}" + + # Use cached connection status + self.is_connected = self._connection_cache.get(self._connection_key, False) + + # Register tables but defer schema fetching + for endpoint in self._endpoints: + table_instance = StrapiTable(handler=self, name=endpoint, defer_schema_fetch=True) + self._register_table(table_name=endpoint, table_class=table_instance) def check_connection(self) -> StatusResponse: """checking the connection @@ -39,36 +61,50 @@ def check_connection(self) -> StatusResponse: Returns: StatusResponse: whether the connection is still up """ - response = StatusResponse(False) - try: - self.connect() - response.success = True - except Exception as e: - logger.error(f'Error connecting to Strapi API: {e}!') - response.error_message = e - - self.is_connected = response.success - return response + if self._connection_cache.get(self._connection_key, False): + self.is_connected = True + return StatusResponse(True) + return self.connect() def connect(self) -> StatusResponse: - """making the connectino object - """ - if self.is_connected and self.connection: - return self.connection + """making the connectino object""" + if self._connection_cache.get(self._connection_key, False): + self.is_connected = True + return StatusResponse(True) try: headers = {"Authorization": f"Bearer {self._api_token}"} response = requests.get(f"{self._base_url}", headers=headers) if response.status_code == 200: - self.connection = response self.is_connected = True + self._connection_cache[self._connection_key] = True return StatusResponse(True) else: raise Exception(f"Error connecting to Strapi API: {response.status_code} - {response.text}") except Exception as e: - logger.error(f'Error connecting to Strapi API: {e}!') + logger.error(f"Error connecting to Strapi API: {e}!") + self._connection_cache[self._connection_key] = False return StatusResponse(False, error_message=e) + def get_tables(self) -> HandlerResponse: + """ + Return list of available Strapi collections + Returns: + RESPONSE_TYPE.TABLE + """ + result = self._endpoints + + df = pd.DataFrame(result, columns=["table_name"]) + df["table_type"] = "BASE TABLE" + + return HandlerResponse(RESPONSE_TYPE.TABLE, df) + + def get_table(self, table_name: str): + """Create table instance on demand""" + if table_name in self._endpoints: + return StrapiTable(handler=self, name=table_name) + raise ValueError(f"Table {table_name} not found in your Strapi collections.") + def native_query(self, query: str) -> StatusResponse: """Receive and process a raw query. @@ -86,32 +122,39 @@ def native_query(self, query: str) -> StatusResponse: return self.query(ast) def call_strapi_api(self, method: str, endpoint: str, params: dict = {}, json_data: dict = {}) -> pd.DataFrame: - headers = {"Authorization": f"Bearer {self._api_token}"} - url = f"{self._base_url}{endpoint}" + headers = {"Content-Type": "application/json"} + # Add Authorization header only if API token is provided + if self._api_token: + headers["Authorization"] = f"Bearer {self._api_token}" - if method.upper() in ('GET', 'POST', 'PUT', 'DELETE'): - headers['Content-Type'] = 'application/json' + url = f"{self._base_url}{endpoint}" - if method.upper() in ('POST', 'PUT', 'DELETE'): + if method.upper() in ("GET", "POST", "PUT", "DELETE"): + if method.upper() in ("POST", "PUT", "DELETE"): response = requests.request(method, url, headers=headers, params=params, data=json_data) else: response = requests.get(url, headers=headers, params=params) - if response.status_code == 200: - data = response.json() - # Create an empty DataFrame - df = pd.DataFrame() - if isinstance(data.get('data', None), list): - for item in data['data']: - # Add 'id' and 'attributes' to the DataFrame - row_data = {'id': item['id'], **item['attributes']} - df = df._append(row_data, ignore_index=True) - return df - elif isinstance(data.get('data', None), dict): - # Add 'id' and 'attributes' to the DataFrame - row_data = {'id': data['data']['id'], **data['data']['attributes']} - df = df._append(row_data, ignore_index=True) - return df + if response.status_code == 200 or response.status_code == 201: + response_data = response.json() + + # Check if response has 'data' key + if "data" not in response_data: + raise Exception(f"Malformed API response: missing 'data' key in response from {endpoint}") + + data = response_data["data"] + + # Check if data is of expected type (list or dict) + if isinstance(data, list): + df = pd.DataFrame(data) + elif isinstance(data, dict): + df = pd.DataFrame([data]) + else: + raise Exception( + f"Malformed API response: 'data' key contains unexpected type {type(data).__name__}, expected list or dict from {endpoint}" + ) + + return df else: raise Exception(f"Error connecting to Strapi API: {response.status_code} - {response.text}") @@ -137,11 +180,11 @@ def call_strapi_api(self, method: str, endpoint: str, params: dict = {}, json_da "required": True, "label": "Port", }, - plural_api_ids={ + endpoints={ "type": list, - "description": "Plural API id to use for querying.", + "description": "Collection endpoints to use for querying.", "required": True, - "label": "Plural API id", + "label": "Endpoints", }, ) @@ -149,5 +192,5 @@ def call_strapi_api(self, method: str, endpoint: str, params: dict = {}, json_da host="localhost", port=1337, api_token="c56c000d867e95848c", - plural_api_ids=["posts", "portfolios"], + endpoints=["posts", "portfolios"], ) diff --git a/mindsdb/integrations/handlers/strapi_handler/strapi_tables.py b/mindsdb/integrations/handlers/strapi_handler/strapi_tables.py index 48b125ad976..df7a3e59abf 100644 --- a/mindsdb/integrations/handlers/strapi_handler/strapi_tables.py +++ b/mindsdb/integrations/handlers/strapi_handler/strapi_tables.py @@ -1,19 +1,284 @@ -from typing import List +from typing import List, Dict, Any import pandas as pd from mindsdb.integrations.libs.api_handler import APIHandler, APITable from mindsdb_sql_parser import ast from mindsdb.integrations.utilities.sql_utils import extract_comparison_conditions +from mindsdb_sql_parser.ast.select.operation import BetweenOperation from mindsdb_sql_parser.ast.select.constant import Constant +from mindsdb_sql_parser.ast.base import ASTNode import json -class StrapiTable(APITable): +def extract_or_conditions(node: ASTNode) -> list: + """Extract WHERE conditions as DNF (OR of AND groups). + + Args: + node: The AST node representing the WHERE clause + + Returns: + List of conjunction groups where each inner list is ANDed and + outer list is ORed. + + Examples: + - a = 1 AND b = 2 -> [[(a=1), (b=2)]] + - a = 1 OR b = 2 -> [[(a=1)], [(b=2)]] + - a = 1 OR (b = 2 AND c = 4) -> [[(a=1)], [(b=2), (c=4)]] + """ + + def extract_single_condition(node: ASTNode) -> tuple: + if isinstance(node, ast.BinaryOperation): + op = node.op.lower() + arg1, arg2 = node.args + if not isinstance(arg1, ast.Identifier): + raise NotImplementedError(f"Not implemented arg1: {arg1}") + if isinstance(arg2, ast.Constant): + value = arg2.value + return (op, arg1.parts[-1], value) + # Add this new condition for BETWEEN + elif isinstance(node, BetweenOperation): + field = node.args[0] # The field being tested + min_val = node.args[1] # Lower bound + max_val = node.args[2] # Upper bound + + if ( + isinstance(field, ast.Identifier) + and isinstance(min_val, ast.Constant) + and isinstance(max_val, ast.Constant) + ): + return ("between", field.parts[-1], [min_val.value, max_val.value]) + else: + raise NotImplementedError("BETWEEN with non-constant values not supported") + + raise NotImplementedError(f"Unsupported condition type: {type(node)}") + + def extract_conditions_recursive(node: ASTNode) -> list: + if isinstance(node, ast.BinaryOperation): + if node.op.lower() == "or": + left_conditions = extract_conditions_recursive(node.args[0]) + right_conditions = extract_conditions_recursive(node.args[1]) + return left_conditions + right_conditions + + elif node.op.lower() == "and": + left_conditions = extract_conditions_recursive(node.args[0]) + right_conditions = extract_conditions_recursive(node.args[1]) + + combined = [] + for left_group in left_conditions: + for right_group in right_conditions: + combined.append(left_group + right_group) + return combined + + else: + condition = extract_single_condition(node) + return [[condition]] # Single condition in its own group - def __init__(self, handler: APIHandler, name: str): + elif isinstance(node, BetweenOperation): + condition = extract_single_condition(node) + return [[condition]] # Single condition in its own group + + raise NotImplementedError(f"Unsupported node type: {type(node)}") + + try: + conditions = extract_conditions_recursive(node) + return conditions + except Exception: + return [[]] + + +# Mapping SQL operators to Strapi filter operators +OPERATOR_MAP = { + "=": "$eq", + "!=": "$ne", + ">": "$gt", + ">=": "$gte", + "<": "$lt", + "<=": "$lte", + "IN": "$in", + "NOT IN": "$notIn", +} + + +class StrapiTable(APITable): + def __init__(self, handler: APIHandler, name: str, defer_schema_fetch: bool = False): super().__init__(handler) self.name = name - # get all the fields of a collection as columns - self.columns = self.handler.call_strapi_api(method='GET', endpoint=f'/api/{name}').columns + self._schema_fetched = False + + if not defer_schema_fetch: + self._fetch_schema() + else: + # Set basic Strapi columns as placeholder + self.columns = ["id", "documentId", "createdAt", "updatedAt"] + + def _fetch_schema(self): + """Fetch schema from Strapi API""" + if self._schema_fetched: + return + + # Use cached schema if available + schema_key = f"{self.handler._connection_key}_{self.name}" + if schema_key in self.handler._table_schemas: + self.columns = self.handler._table_schemas[schema_key] + self._schema_fetched = True + return + + # Only fetch schema once and cache it + try: + df = self.handler.call_strapi_api( + method="GET", endpoint=f"/api/{self.name}", params={"pagination[limit]": 1} + ) + if len(df.columns) > 0: + self.columns = df.columns.tolist() + self.handler._table_schemas[schema_key] = self.columns + else: + # If no data, set basic Strapi columns + self.columns = ["id", "documentId", "createdAt", "updatedAt"] + self.handler._table_schemas[schema_key] = self.columns + except Exception: + # Set basic Strapi columns as fallback + self.columns = ["id", "documentId", "createdAt", "updatedAt"] + self.handler._table_schemas[schema_key] = self.columns + + self._schema_fetched = True + + def _build_filters(self, conditions: List[List[tuple]]) -> Dict[str, Any]: + """Build Strapi filters from DNF condition groups. + + Args: + conditions: DNF groups where each inner list is ANDed and + groups are ORed. + + Returns: + Dict of Strapi filter parameters + """ + if not conditions: + return {} + + # Keep the fast-path for direct documentId lookup. + if len(conditions) == 1 and len(conditions[0]) == 1: + op, field, value = conditions[0][0] + if field == "documentId" and op == "=": + return {"documentId": value} + + def to_filter_node(condition: tuple) -> Dict[str, Dict[str, Any]]: + op, field, value = condition + return self._build_single_condition(op, field, value) + + # Build nested Strapi filter tree preserving boolean precedence. + if len(conditions) == 1: + and_group = conditions[0] + if len(and_group) == 1: + filter_tree = to_filter_node(and_group[0]) + else: + filter_tree = {"$and": [to_filter_node(condition) for condition in and_group]} + else: + or_nodes = [] + for and_group in conditions: + if len(and_group) == 1: + or_nodes.append(to_filter_node(and_group[0])) + else: + or_nodes.append({"$and": [to_filter_node(condition) for condition in and_group]}) + filter_tree = {"$or": or_nodes} + + filters = {} + + def flatten(node: Any, path: List[str]) -> None: + if isinstance(node, dict): + for key, value in node.items(): + flatten(value, path + [key]) + elif isinstance(node, list): + for index, value in enumerate(node): + flatten(value, path + [str(index)]) + else: + key = "filters" + "".join(f"[{part}]" for part in path) + filters[key] = node + + flatten(filter_tree, []) + return filters + + def _build_single_condition(self, op: str, field: str, value: Any) -> Dict[str, Dict[str, Any]]: + """Build a single condition dictionary for Strapi filters + + Args: + op: SQL operator + field: Field name + value: Field value + + Returns: + Dictionary with field and its filter conditions + """ + condition = {} + + if op.upper() == "BETWEEN": + if isinstance(value, (list, tuple)) and len(value) == 2: + # BETWEEN translates to field >= min AND field <= max + condition[field] = {"$gte": value[0], "$lte": value[1]} + else: + raise ValueError("BETWEEN operator requires exactly 2 values") + + elif op.upper() == "LIKE": + if not isinstance(value, str): + raise ValueError("LIKE operator requires a string value") + + # Remove quotes if present + if (value.startswith("'") and value.endswith("'")) or (value.startswith('"') and value.endswith('"')): + value = value[1:-1] + + # Handle LIKE patterns + if value.startswith("%") and value.endswith("%"): + value = value[1:-1] # Remove % from both ends + condition[field] = {"$contains": value} + elif value.startswith("%"): + value = value[1:] # Remove leading % + condition[field] = {"$endsWith": value} + elif value.endswith("%"): + value = value[:-1] # Remove trailing % + condition[field] = {"$startsWith": value} + else: + condition[field] = {"$eq": value} + + elif op.upper() == "IS": + if value is None: + condition[field] = {"$null": True} + else: + raise ValueError(f"IS operator with non-null value not supported: {value}") + + elif op.upper() == "IS NOT": + if value is None: + condition[field] = {"$notNull": True} + else: + raise ValueError(f"IS NOT operator with non-null value not supported: {value}") + + elif op.upper() in ("IN", "NOT IN"): + if isinstance(value, (list, tuple)): + strapi_op = "$in" if op.upper() == "IN" else "$notIn" + condition[field] = {strapi_op: list(value)} + else: + raise ValueError(f"{op} operator requires a list or tuple value") + + elif op.upper() in OPERATOR_MAP: + condition[field] = {OPERATOR_MAP[op.upper()]: value} + + else: + raise ValueError(f"Unsupported operator {op} in WHERE clause") + + return condition + + def _fetch_by_id(self, document_id: str, selected_columns: list) -> pd.DataFrame: + """Helper method to fetch a record by documentId + + Args: + document_id (str): The documentId to fetch + selected_columns (list): Columns to include in result + + Returns: + pd.DataFrame: The resulting DataFrame + """ + df = self.handler.call_strapi_api(method="GET", endpoint=f"/api/{self.name}/{document_id}") + + if len(df) > 0: + return df[selected_columns] + return pd.DataFrame(columns=selected_columns) def select(self, query: ast.Select) -> pd.DataFrame: """Triggered at the SELECT query @@ -24,19 +289,8 @@ def select(self, query: ast.Select) -> pd.DataFrame: Returns: pd.DataFrame: The queried information """ - # Initialize _id and selected_columns - _id = None - selected_columns = [] - - # Get id from where clause, if available - conditions = extract_comparison_conditions(query.where) - for op, arg1, arg2 in conditions: - if arg1 == 'id' and op == '=': - _id = arg2 - else: - raise ValueError("Unsupported condition in WHERE clause") - # Get selected columns from query + selected_columns = [] for target in query.targets: if isinstance(target, ast.Star): selected_columns = self.get_columns() @@ -46,43 +300,77 @@ def select(self, query: ast.Select) -> pd.DataFrame: else: raise ValueError(f"Unknown query target {type(target)}") - # Initialize the result DataFrame - result_df = None + # Default to all columns if no columns are selected + if not selected_columns: + selected_columns = self.get_columns() - if _id is not None: - # Fetch data using the provided endpoint for the specific id - df = self.handler.call_strapi_api(method='GET', endpoint=f'/api/{self.name}/{_id}') + # Build filters from WHERE clause + filters = {} + if query.where: + try: + # Extract OR conditions - now always returns list of lists + conditions = extract_or_conditions(query.where) + filters = self._build_filters(conditions) + except Exception: + # Fallback to empty filters + filters = {} - if len(df) > 0: - result_df = df[selected_columns] - else: - # Fetch data without specifying an id - page_size = 100 # The page size you want to use for API requests - limit = query.limit.value if query.limit else None - result_df = pd.DataFrame(columns=selected_columns) + # If we got a documentId filter, use the specific endpoint + if "documentId" in filters: + return self._fetch_by_id(filters["documentId"], selected_columns) - if limit: - # Calculate the number of pages required - page_count = (limit + page_size - 1) // page_size - else: - page_count = 1 + # Initialize pagination parameters with optimized page size + # Use Strapi's default maximum page size of 100 for REST API + page_size = 100 + limit = query.limit.value if query.limit else None + result_df = pd.DataFrame(columns=selected_columns) - for page in range(1, page_count + 1): - if limit: - # Calculate the page size for this request - current_page_size = min(page_size, limit) - else: - current_page_size = page_size + # If limit is specified and smaller than page_size, use limit as page_size to minimize API calls + if limit and limit < page_size: + page_size = limit + + # Prepare initial parameters including filters + params = { + "pagination[page]": 1, + "pagination[pageSize]": page_size, + **filters, # Add any WHERE clause filters + } + + page = 1 + total_fetched = 0 + + # Fetch data in optimized pagination loop + while True: + params["pagination[page]"] = page + + df = self.handler.call_strapi_api(method="GET", endpoint=f"/api/{self.name}", params=params) - df = self.handler.call_strapi_api(method='GET', endpoint=f'/api/{self.name}', params={'pagination[page]': page, 'pagination[pageSize]': current_page_size}) + # Break if no data returned + if len(df) == 0: + break - if len(df) == 0: + # Apply limit constraint if specified + rows_to_take = len(df) + if limit: + remaining_needed = limit - total_fetched + if remaining_needed <= 0: break + rows_to_take = min(rows_to_take, remaining_needed) + + # Take only the needed rows and add to result + df_slice = df.head(rows_to_take) if rows_to_take < len(df) else df + result_df = pd.concat([result_df, df_slice[selected_columns]], ignore_index=True) - result_df = pd.concat([result_df, df[selected_columns]], ignore_index=True) + total_fetched += rows_to_take - if limit: - limit -= current_page_size + # Break conditions: + # 1. If we got fewer rows than page_size, we've reached the end + # 2. If we have a limit and we've reached it + # 3. If we took fewer rows than available due to limit constraint + if len(df) < page_size or (limit and total_fetched >= limit) or rows_to_take < len(df): + break + + page += 1 return result_df @@ -91,13 +379,23 @@ def insert(self, query: ast.Insert) -> None: Args: query (ast.Insert): user's entered query """ - data = {'data': {}} - for column, value in zip(query.columns, query.values[0]): - if isinstance(value, Constant): - data['data'][column.name] = value.value - else: - data['data'][column.name] = value - self.handler.call_strapi_api(method='POST', endpoint=f'/api/{self.name}', json_data=json.dumps(data)) + # Loop through all rows in the VALUES clause + for row_values in query.values: + data = {"data": {}} + + for column, value in zip(query.columns, row_values): + # Clean column name (remove backticks if present) + column_name = column.name + if column_name.startswith("`") and column_name.endswith("`"): + column_name = column_name[1:-1] + + if isinstance(value, Constant): + data["data"][column_name] = value.value + else: + data["data"][column_name] = value + + # Make individual API call for each row + self.handler.call_strapi_api(method="POST", endpoint=f"/api/{self.name}", json_data=json.dumps(data)) def update(self, query: ast.Update) -> None: """triggered at the UPDATE query @@ -106,17 +404,19 @@ def update(self, query: ast.Update) -> None: query (ast.Update): user's entered query """ conditions = extract_comparison_conditions(query.where) - # Get id from query + # Get documentId from query for op, arg1, arg2 in conditions: - if arg1 == 'id' and op == '=': - _id = arg2 + if arg1 == "documentId" and op == "=": + _documentId = arg2 else: - raise NotImplementedError - data = {'data': {}} + raise ValueError("`documentId` must be used in WHERE clause for UPDATE") + data = {"data": {}} for key, value in query.update_columns.items(): if isinstance(value, Constant): - data['data'][key] = value.value - self.handler.call_strapi_api(method='PUT', endpoint=f'/api/{self.name}/{_id}', json_data=json.dumps(data)) + data["data"][key] = value.value + self.handler.call_strapi_api( + method="PUT", endpoint=f"/api/{self.name}/{_documentId}", json_data=json.dumps(data) + ) def get_columns(self, ignore: List[str] = []) -> List[str]: """columns @@ -127,5 +427,6 @@ def get_columns(self, ignore: List[str] = []) -> List[str]: Returns: List[str]: available columns with `ignore` items removed from the list. """ - + if not self._schema_fetched: + self._fetch_schema() return [item for item in self.columns if item not in ignore] diff --git a/mindsdb/integrations/handlers/strapi_handler/tests/test_strapi_handler.py b/mindsdb/integrations/handlers/strapi_handler/tests/test_strapi_handler.py index a7dd95481dc..a2f0b197574 100644 --- a/mindsdb/integrations/handlers/strapi_handler/tests/test_strapi_handler.py +++ b/mindsdb/integrations/handlers/strapi_handler/tests/test_strapi_handler.py @@ -1,51 +1,199 @@ import unittest +from unittest.mock import patch, Mock from mindsdb.integrations.handlers.strapi_handler.strapi_handler import StrapiHandler +from mindsdb.integrations.handlers.strapi_handler.strapi_tables import extract_or_conditions, StrapiTable from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE +from mindsdb_sql_parser import parse_sql class StrapiHandlerTest(unittest.TestCase): + def setUp(self): + self.connection_data = { + "host": "localhost", + "port": "1337", + "api_token": "test_token_123", + "endpoints": ["products", "sellers"], + } + self.handler = StrapiHandler(name="myshop", connection_data=self.connection_data) - @classmethod - def setUpClass(cls): - connection_data = { - 'host': 'localhost', - 'port': '1337', - 'api_token': 'c56c000d867e95848c', - 'plural_api_ids': ['products', 'sellers']} - cls.handler = StrapiHandler(name='myshop', connection_data=connection_data) + # Mock data for testing (matching real Strapi API response structure) + self.mock_products_data = [ + { + "id": 45, + "documentId": "mvaprjyy72ayx7z4v592sdnr", + "title": "Mens Casual Premium Slim Fit T-Shirts", + "desc": "Slim-fitting style, contrast raglan long sleeve, lightweight & breathable fabric.", + "price": 22.3, + "createdAt": "2025-09-09T08:57:55.574Z", + "updatedAt": "2025-09-09T09:53:41.392Z", + "publishedAt": "2025-09-09T09:53:41.412Z", + }, + { + "id": 46, + "documentId": "abc123def456ghi789", + "title": "Womens Cotton Jacket", + "desc": "Great outerwear for Spring/Autumn/Winter.", + "price": 55.99, + "createdAt": "2025-09-09T08:58:55.574Z", + "updatedAt": "2025-09-09T09:54:41.392Z", + "publishedAt": "2025-09-09T09:54:41.412Z", + }, + ] + + self.mock_sellers_data = [ + { + "id": 1, + "documentId": "seller123", + "name": "Test Seller", + "email": "seller@test.com", + "sellerid": "seller001", + "createdAt": "2025-09-09T08:57:55.574Z", + "updatedAt": "2025-09-09T09:53:41.392Z", + "publishedAt": "2025-09-09T09:53:41.412Z", + } + ] + + @patch("mindsdb.integrations.handlers.strapi_handler.strapi_handler.requests.get") + def test_0_check_connection(self, mock_get): + # Mock successful connection response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"data": {"name": "test-strapi", "version": "4.0.0"}} + mock_get.return_value = mock_response - def test_0_check_connection(self): # Ensure the connection is successful self.assertTrue(self.handler.check_connection()) def test_1_get_table(self): - assert self.handler.get_tables() is not RESPONSE_TYPE.ERROR + # Mock the endpoints from connection data + result = self.handler.get_tables() + self.assertIsNotNone(result) + assert result is not RESPONSE_TYPE.ERROR + + @patch("mindsdb.integrations.handlers.strapi_handler.strapi_handler.requests.get") + def test_2_get_columns(self, mock_get): + # Mock response for schema fetching (single record with limit=1) + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [self.mock_products_data[0]] # Return first product for schema discovery + } + mock_get.return_value = mock_response + + result = self.handler.get_columns("products") + assert result is not RESPONSE_TYPE.ERROR + + @patch("mindsdb.integrations.handlers.strapi_handler.strapi_handler.requests.get") + def test_3_get_data(self, mock_get): + # Mock responses: first call for schema (limit=1), second call for actual data + schema_response = Mock() + schema_response.status_code = 200 + schema_response.json.return_value = {"data": [self.mock_products_data[0]]} - def test_2_get_columns(self): - assert self.handler.get_columns('products') is not RESPONSE_TYPE.ERROR + data_response = Mock() + data_response.status_code = 200 + data_response.json.return_value = {"data": self.mock_products_data} + + # Return schema response first, then data response + mock_get.side_effect = [schema_response, data_response] - def test_3_get_data(self): # Ensure that you can retrieve data from a table - data = self.handler.native_query('SELECT * FROM products') + data = self.handler.native_query("SELECT * FROM products") assert data.type is not RESPONSE_TYPE.ERROR - def test_4_get_data_with_condition(self): + @patch("mindsdb.integrations.handlers.strapi_handler.strapi_handler.requests.get") + def test_4_get_data_with_condition(self, mock_get): + # Mock responses: first call for schema (limit=1), second call for specific documentId + schema_response = Mock() + schema_response.status_code = 200 + schema_response.json.return_value = {"data": [self.mock_products_data[0]]} + + specific_response = Mock() + specific_response.status_code = 200 + specific_response.json.return_value = { + "data": self.mock_products_data[0] # Return single product (not in array for specific ID) + } + + # Return schema response first, then specific product response + mock_get.side_effect = [schema_response, specific_response] + # Ensure that you can retrieve data with a condition - data = self.handler.native_query('SELECT * FROM products WHERE id = 1') + data = self.handler.native_query("SELECT * FROM products WHERE documentId = 'mvaprjyy72ayx7z4v592sdnr'") assert data.type is not RESPONSE_TYPE.ERROR - def test_5_insert_data(self): + @patch("mindsdb.integrations.handlers.strapi_handler.strapi_handler.requests.request") + def test_5_insert_data(self, mock_request): + # Mock response for successful data insertion + mock_response = Mock() + mock_response.status_code = 201 + mock_response.json.return_value = { + "data": { + "id": 2, + "documentId": "newdocid123", + "name": "Ram", + "email": "ram@gmail.com", + "sellerid": "ramu4", + "createdAt": "2025-09-09T08:57:55.574Z", + "updatedAt": "2025-09-09T09:53:41.392Z", + "publishedAt": "2025-09-09T09:53:41.412Z", + } + } + mock_request.return_value = mock_response + # Ensure that data insertion is successful query = "INSERT INTO myshop.sellers (name, email, sellerid) VALUES ('Ram', 'ram@gmail.com', 'ramu4')" result = self.handler.native_query(query) - self.assertTrue(result) + self.assertIsNotNone(result) + assert result.type is not RESPONSE_TYPE.ERROR + + @patch("mindsdb.integrations.handlers.strapi_handler.strapi_handler.requests.request") + def test_6_update_data(self, mock_request): + # Mock response for successful data update + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "id": 45, + "documentId": "mvaprjyy72ayx7z4v592sdnr", + "title": "Updated Product Title", # Updated title + "desc": "Slim-fitting style, contrast raglan long sleeve, lightweight & breathable fabric.", + "price": 22.3, + "createdAt": "2025-09-09T08:57:55.574Z", + "updatedAt": "2025-09-09T09:53:41.392Z", + "publishedAt": "2025-09-09T09:53:41.412Z", + } + } + mock_request.return_value = mock_response - def test_6_update_data(self): # Ensure that data updating is successful - query = "UPDATE products SET name = 'test2' WHERE id = 1" + query = "UPDATE products SET title = 'Updated Product Title' WHERE documentId = 'mvaprjyy72ayx7z4v592sdnr'" result = self.handler.native_query(query) - self.assertTrue(result) + self.assertIsNotNone(result) + assert result.type is not RESPONSE_TYPE.ERROR + + def test_7_where_precedence_or_and(self): + query = parse_sql("SELECT * FROM products WHERE a = 1 OR (b = 2 AND c = 4)") + table = StrapiTable(handler=self.handler, name="products", defer_schema_fetch=True) + + conditions = extract_or_conditions(query.where) + filters = table._build_filters(conditions) + + self.assertIn("filters[$or][0][a][$eq]", filters) + self.assertIn("filters[$or][1][$and][0][b][$eq]", filters) + self.assertIn("filters[$or][1][$and][1][c][$eq]", filters) + + def test_8_where_precedence_or_or(self): + query = parse_sql("SELECT * FROM products WHERE a = 1 OR (b = 2 OR c = 4)") + table = StrapiTable(handler=self.handler, name="products", defer_schema_fetch=True) + + conditions = extract_or_conditions(query.where) + filters = table._build_filters(conditions) + + self.assertIn("filters[$or][0][a][$eq]", filters) + self.assertIn("filters[$or][1][b][$eq]", filters) + self.assertIn("filters[$or][2][c][$eq]", filters) + self.assertFalse(any("[$and]" in key for key in filters)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From b9802c23c90e492063720d2677bafa2a53450e79 Mon Sep 17 00:00:00 2001 From: RITWICK RAJ MAKHAL Date: Fri, 20 Mar 2026 20:13:12 +0530 Subject: [PATCH 077/125] Rename icons for better naming consistency. (#12316) --- mindsdb/integrations/handlers/bigcommerce_handler/__init__.py | 2 +- .../bigcommerce_handler/{bigcommerce-black.svg => icon.svg} | 0 mindsdb/integrations/handlers/netsuite_handler/__init__.py | 2 +- .../handlers/netsuite_handler/{netsuite.svg => icon.svg} | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename mindsdb/integrations/handlers/bigcommerce_handler/{bigcommerce-black.svg => icon.svg} (100%) rename mindsdb/integrations/handlers/netsuite_handler/{netsuite.svg => icon.svg} (100%) diff --git a/mindsdb/integrations/handlers/bigcommerce_handler/__init__.py b/mindsdb/integrations/handlers/bigcommerce_handler/__init__.py index 7e671a123a0..1117cbb3089 100644 --- a/mindsdb/integrations/handlers/bigcommerce_handler/__init__.py +++ b/mindsdb/integrations/handlers/bigcommerce_handler/__init__.py @@ -14,7 +14,7 @@ title = "BigCommerce" name = "bigcommerce" type = HANDLER_TYPE.DATA -icon_path = "bigcommerce-black.svg" +icon_path = "icon.svg" support_level = HANDLER_SUPPORT_LEVEL.COMMUNITY __all__ = [ diff --git a/mindsdb/integrations/handlers/bigcommerce_handler/bigcommerce-black.svg b/mindsdb/integrations/handlers/bigcommerce_handler/icon.svg similarity index 100% rename from mindsdb/integrations/handlers/bigcommerce_handler/bigcommerce-black.svg rename to mindsdb/integrations/handlers/bigcommerce_handler/icon.svg diff --git a/mindsdb/integrations/handlers/netsuite_handler/__init__.py b/mindsdb/integrations/handlers/netsuite_handler/__init__.py index 673d6ed81ed..c2272c38724 100644 --- a/mindsdb/integrations/handlers/netsuite_handler/__init__.py +++ b/mindsdb/integrations/handlers/netsuite_handler/__init__.py @@ -14,7 +14,7 @@ title = "Oracle NetSuite" name = "netsuite" type = HANDLER_TYPE.DATA -icon_path = "netsuite.svg" +icon_path = "icon.svg" __all__ = [ "Handler", diff --git a/mindsdb/integrations/handlers/netsuite_handler/netsuite.svg b/mindsdb/integrations/handlers/netsuite_handler/icon.svg similarity index 100% rename from mindsdb/integrations/handlers/netsuite_handler/netsuite.svg rename to mindsdb/integrations/handlers/netsuite_handler/icon.svg From 612b53c8f8b96c57116dfd85118baf5980f24365 Mon Sep 17 00:00:00 2001 From: "r.e.e.c.h.e.e" Date: Fri, 20 Mar 2026 18:41:54 +0100 Subject: [PATCH 078/125] chore: remove elasticsearch from HANDLERS_TO_INSTALL in unit test workflow --- .github/workflows/tests_unit.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests_unit.yml b/.github/workflows/tests_unit.yml index 4cd366bd980..3a9455a7f76 100644 --- a/.github/workflows/tests_unit.yml +++ b/.github/workflows/tests_unit.yml @@ -34,7 +34,6 @@ env: statsforecast chromadb confluence - elasticsearch # We measure 80% on this handlers, as they are the verified HANDLERS_TO_VERIFY: | mysql From 8202b38032fa013ad839e91f7ff8a082ed858fd6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 18:13:40 -0500 Subject: [PATCH 079/125] Bump werkzeug from 3.0.6 to 3.1.6 in /requirements (#12244) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index ca331360bb1..737d9e8b18c 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ packaging flask == 3.0.3 -werkzeug == 3.0.6 +werkzeug == 3.1.6 flask-restx >= 1.3.0, < 2.0.0 pandas == 2.2.3 python-multipart == 0.0.22 From f3dbeec350cf4e61058957a1034532d9a53af41d Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Mon, 23 Mar 2026 10:13:59 +0300 Subject: [PATCH 080/125] Del unused callbacks (#12314) --- .../vector_store_loader.py | 36 +- .../interfaces/agents/callback_handlers.py | 177 ---------- .../agents/event_dispatch_callback_handler.py | 50 --- .../agents/langfuse_callback_handler.py | 308 ------------------ .../interfaces/agents/pydantic_ai_agent.py | 4 - requirements/requirements-agents.txt | 5 - requirements/requirements-kb.txt | 1 - 7 files changed, 1 insertion(+), 580 deletions(-) delete mode 100644 mindsdb/interfaces/agents/callback_handlers.py delete mode 100644 mindsdb/interfaces/agents/event_dispatch_callback_handler.py delete mode 100644 mindsdb/interfaces/agents/langfuse_callback_handler.py diff --git a/mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py b/mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py index a094f5e830e..0981b3b2a44 100644 --- a/mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py +++ b/mindsdb/integrations/utilities/rag/loaders/vector_store_loader/vector_store_loader.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -from mindsdb.integrations.utilities.rag.settings import VectorStoreType, VectorStoreConfig +from mindsdb.integrations.utilities.rag.settings import VectorStoreConfig from mindsdb.integrations.utilities.rag.loaders.vector_store_loader.base_vector_store import VectorStore from mindsdb.integrations.utilities.rag.loaders.vector_store_loader.MDBVectorStore import MDBVectorStore from mindsdb.integrations.utilities.rag.loaders.vector_store_loader.pgvector import PGVectorMDB @@ -46,37 +46,3 @@ def load(self) -> VectorStore: vector_size=self.config.vector_size, ) return MDBVectorStore(kb_table=self.config.kb_table) - - -class VectorStoreFactory: - @staticmethod - def create(embedding_model: Any, config: VectorStoreConfig) -> VectorStore: - if config.vector_store_type == VectorStoreType.CHROMA: - return VectorStoreFactory._load_chromadb_store(embedding_model, config) - elif config.vector_store_type == VectorStoreType.PGVECTOR: - return VectorStoreFactory._load_pgvector_store(embedding_model, config) - else: - raise ValueError(f"Invalid vector store type, must be one either {VectorStoreType.__members__.keys()}") - - @staticmethod - def _load_chromadb_store(embedding_model: Any, settings) -> VectorStore: - # Chroma still uses langchain, import only when needed - from langchain_community.vectorstores import Chroma - - return Chroma( - persist_directory=settings.persist_directory, - collection_name=settings.collection_name, - embedding_function=embedding_model, - ) - - @staticmethod - def _load_pgvector_store(embedding_model: Any, settings) -> VectorStore: - from .pgvector import PGVectorMDB - - return PGVectorMDB( - connection_string=settings.connection_string, - collection_name=settings.collection_name, - embedding_function=embedding_model, - is_sparse=settings.is_sparse, - vector_size=settings.vector_size, - ) diff --git a/mindsdb/interfaces/agents/callback_handlers.py b/mindsdb/interfaces/agents/callback_handlers.py deleted file mode 100644 index f4735b737c7..00000000000 --- a/mindsdb/interfaces/agents/callback_handlers.py +++ /dev/null @@ -1,177 +0,0 @@ -import io -import logging -import contextlib -from typing import Any, Dict, List, Union, Callable - -from langchain_core.agents import AgentAction, AgentFinish -from langchain_core.callbacks.base import BaseCallbackHandler -from langchain_core.messages.base import BaseMessage -from langchain_core.outputs import LLMResult -from langchain_core.callbacks import StdOutCallbackHandler - - -class ContextCaptureCallback(BaseCallbackHandler): - def __init__(self): - self.context = None - - def on_retriever_end(self, documents: List[Any], *, run_id: str, parent_run_id: Union[str, None] = None, **kwargs: Any) -> Any: - self.context = [{ - 'page_content': doc.page_content, - 'metadata': doc.metadata - } for doc in documents] - - def get_contexts(self): - return self.context - - -class VerboseLogCallbackHandler(StdOutCallbackHandler): - def __init__(self, logger: logging.Logger, verbose: bool): - self.logger = logger - self.verbose = verbose - super().__init__() - - def __call(self, method: Callable, *args: List[Any], **kwargs: Any) -> Any: - if self.verbose is False: - return - f = io.StringIO() - with contextlib.redirect_stdout(f): - method(*args, **kwargs) - output = f.getvalue() - self.logger.info(output) - - def on_chain_start(self, *args: List[Any], **kwargs: Any) -> None: - self.__call(super().on_chain_start, *args, **kwargs) - - def on_chain_end(self, *args: List[Any], **kwargs: Any) -> None: - self.__call(super().on_chain_end, *args, **kwargs) - - def on_agent_action(self, *args: List[Any], **kwargs: Any) -> None: - self.__call(super().on_agent_action, *args, **kwargs) - - def on_tool_end(self, *args: List[Any], **kwargs: Any) -> None: - self.__call(super().on_tool_end, *args, **kwargs) - - def on_text(self, *args: List[Any], **kwargs: Any) -> None: - self.__call(super().on_text, *args, **kwargs) - - def on_agent_finish(self, *args: List[Any], **kwargs: Any) -> None: - self.__call(super().on_agent_finish, *args, **kwargs) - - -class LogCallbackHandler(BaseCallbackHandler): - '''Langchain callback handler that logs agent and chain executions.''' - - def __init__(self, logger: logging.Logger, verbose: bool = True): - logger.setLevel('DEBUG') - self.logger = logger - self._num_running_chains = 0 - self.generated_sql = None - self.verbose_log_handler = VerboseLogCallbackHandler(logger, verbose) - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> Any: - '''Run when LLM starts running.''' - self.logger.debug('LLM started with prompts:') - for prompt in prompts: - self.logger.debug(prompt[:50]) - self.verbose_log_handler.on_llm_start(serialized, prompts, **kwargs) - - def on_chat_model_start( - self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], **kwargs: Any - ) -> Any: - '''Run when Chat Model starts running.''' - self.logger.debug('Chat model started with messages:') - for message_list in messages: - for message in message_list: - self.logger.debug(message.pretty_repr()) - - def on_llm_new_token(self, token: str, **kwargs: Any) -> Any: - '''Run on new LLM token. Only available when streaming is enabled.''' - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: - '''Run when LLM ends running.''' - self.logger.debug('LLM ended with response:') - self.logger.debug(str(response.llm_output)) - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: - '''Run when LLM errors.''' - self.logger.debug(f'LLM encountered an error: {str(error)}') - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> Any: - '''Run when chain starts running.''' - self._num_running_chains += 1 - self.logger.info('Entering new LLM chain ({} total)'.format( - self._num_running_chains)) - self.logger.debug('Inputs: {}'.format(inputs)) - - self.verbose_log_handler.on_chain_start(serialized=serialized, inputs=inputs, **kwargs) - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: - '''Run when chain ends running.''' - self._num_running_chains -= 1 - self.logger.info('Ended LLM chain ({} total)'.format( - self._num_running_chains)) - self.logger.debug('Outputs: {}'.format(outputs)) - - self.verbose_log_handler.on_chain_end(outputs=outputs, **kwargs) - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: - '''Run when chain errors.''' - self._num_running_chains -= 1 - self.logger.error( - 'LLM chain encountered an error ({} running): {}'.format( - self._num_running_chains, error)) - - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> Any: - '''Run when tool starts running.''' - pass - - def on_tool_end(self, output: str, **kwargs: Any) -> Any: - '''Run when tool ends running.''' - self.verbose_log_handler.on_tool_end(output=output, **kwargs) - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: - '''Run when tool errors.''' - pass - - def on_text(self, text: str, **kwargs: Any) -> Any: - '''Run on arbitrary text.''' - self.verbose_log_handler.on_text(text=text, **kwargs) - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - '''Run on agent action.''' - self.logger.debug(f'Running tool {action.tool} with input:') - self.logger.debug(action.tool_input) - - stop_block = 'Observation: ' - if stop_block in action.tool_input: - action.tool_input = action.tool_input[: action.tool_input.find(stop_block)] - - if action.tool.startswith("sql_db_query"): - # Save the generated SQL query - self.generated_sql = action.tool_input - - # fix for mistral - action.tool = action.tool.replace('\\', '') - - self.verbose_log_handler.on_agent_action(action=action, **kwargs) - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: - '''Run on agent end.''' - self.logger.debug('Agent finished with return values:') - self.logger.debug(str(finish.return_values)) - self.verbose_log_handler.on_agent_finish(finish=finish, **kwargs) diff --git a/mindsdb/interfaces/agents/event_dispatch_callback_handler.py b/mindsdb/interfaces/agents/event_dispatch_callback_handler.py deleted file mode 100644 index 7446ba2adaa..00000000000 --- a/mindsdb/interfaces/agents/event_dispatch_callback_handler.py +++ /dev/null @@ -1,50 +0,0 @@ -import queue -from typing import Any, Dict, List, Optional, Sequence -from uuid import UUID - -from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.documents import Document - - -class EventDispatchCallbackHandler(BaseCallbackHandler): - '''Puts dispatched events onto an event queue to be processed as a streaming chunk''' - def __init__(self, queue: queue.Queue): - self.queue = queue - - def on_custom_event( - self, - name: str, - data: Any, - *, - run_id: UUID, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs - ): - self.queue.put({ - 'type': 'event', - 'name': name, - 'data': data - }) - - def on_retriever_end( - self, - documents: Sequence[Document], - *, - run_id: UUID, - parent_run_id: Optional[UUID] = None, - **kwargs: Any, - ) -> Any: - document_objects = [] - for d in documents: - document_objects.append({ - 'content': d.page_content, - 'metadata': d.metadata - }) - self.queue.put({ - 'type': 'event', - 'name': 'retriever_end', - 'data': { - 'documents': document_objects - } - }) diff --git a/mindsdb/interfaces/agents/langfuse_callback_handler.py b/mindsdb/interfaces/agents/langfuse_callback_handler.py deleted file mode 100644 index 948eadda6b8..00000000000 --- a/mindsdb/interfaces/agents/langfuse_callback_handler.py +++ /dev/null @@ -1,308 +0,0 @@ -from typing import Any, Dict, Union, Optional, List -from uuid import uuid4 -import datetime -import json - -from langchain_core.callbacks.base import BaseCallbackHandler - -from mindsdb.utilities import log -from mindsdb.interfaces.storage import db - -logger = log.getLogger(__name__) -logger.setLevel('DEBUG') - - -class LangfuseCallbackHandler(BaseCallbackHandler): - """Langchain callback handler that traces tool & chain executions using Langfuse.""" - - def __init__(self, langfuse, trace_id: Optional[str] = None, observation_id: Optional[str] = None): - self.langfuse = langfuse - self.chain_uuid_to_span = {} - self.action_uuid_to_span = {} - # if these are not available, we generate some UUIDs - self.trace_id = trace_id or uuid4().hex - self.observation_id = observation_id or uuid4().hex - # Track metrics about tools and chains - self.tool_metrics = {} - self.chain_metrics = {} - self.current_chain = None - - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> Any: - """Run when tool starts running.""" - parent_run_uuid = kwargs.get('parent_run_id', uuid4()).hex - action_span = self.action_uuid_to_span.get(parent_run_uuid) - if action_span is None: - return - - tool_name = serialized.get("name", "tool") - start_time = datetime.datetime.now() - - # Initialize or update tool metrics - if tool_name not in self.tool_metrics: - self.tool_metrics[tool_name] = { - 'count': 0, - 'total_time': 0, - 'errors': 0, - 'last_error': None, - 'inputs': [] - } - - self.tool_metrics[tool_name]['count'] += 1 - self.tool_metrics[tool_name]['inputs'].append(input_str) - - metadata = { - 'tool_name': tool_name, - 'started': start_time.isoformat(), - 'start_timestamp': start_time.timestamp(), - 'input_length': len(input_str) if input_str else 0 - } - action_span.update(metadata=metadata) - - def on_tool_end(self, output: str, **kwargs: Any) -> Any: - """Run when tool ends running.""" - parent_run_uuid = kwargs.get('parent_run_id', uuid4()).hex - action_span = self.action_uuid_to_span.get(parent_run_uuid) - if action_span is None: - return - - end_time = datetime.datetime.now() - tool_name = action_span.metadata.get('tool_name', 'unknown') - start_timestamp = action_span.metadata.get('start_timestamp') - - if start_timestamp: - duration = end_time.timestamp() - start_timestamp - if tool_name in self.tool_metrics: - self.tool_metrics[tool_name]['total_time'] += duration - - metadata = { - 'finished': end_time.isoformat(), - 'duration_seconds': duration if start_timestamp else None, - 'output_length': len(output) if output else 0 - } - - action_span.update( - output=output, # tool output is action output (unless superseded by a global action output) - metadata=metadata - ) - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: - """Run when tool errors.""" - parent_run_uuid = kwargs.get('parent_run_id', uuid4()).hex - action_span = self.action_uuid_to_span.get(parent_run_uuid) - if action_span is None: - return - - try: - error_str = str(error) - except Exception: - error_str = "Couldn't get error string." - - tool_name = action_span.metadata.get('tool_name', 'unknown') - if tool_name in self.tool_metrics: - self.tool_metrics[tool_name]['errors'] += 1 - self.tool_metrics[tool_name]['last_error'] = error_str - - metadata = { - 'error_description': error_str, - 'error_type': error.__class__.__name__, - 'error_time': datetime.datetime.now().isoformat() - } - action_span.update(metadata=metadata) - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> Any: - """Run when chain starts running.""" - if self.langfuse is None: - return - - run_uuid = kwargs.get('run_id', uuid4()).hex - - if serialized is None: - serialized = {} - - chain_name = serialized.get("name", "chain") - start_time = datetime.datetime.now() - - # Initialize or update chain metrics - if chain_name not in self.chain_metrics: - self.chain_metrics[chain_name] = { - 'count': 0, - 'total_time': 0, - 'errors': 0, - 'last_error': None - } - - self.chain_metrics[chain_name]['count'] += 1 - self.current_chain = chain_name - - try: - chain_span = self.langfuse.span( - name=f'{chain_name}-{run_uuid}', - trace_id=self.trace_id, - parent_observation_id=self.observation_id, - input=json.dumps(inputs, indent=2) - ) - - metadata = { - 'chain_name': chain_name, - 'started': start_time.isoformat(), - 'start_timestamp': start_time.timestamp(), - 'input_keys': list(inputs.keys()) if isinstance(inputs, dict) else None, - 'input_size': len(inputs) if isinstance(inputs, dict) else len(str(inputs)) - } - chain_span.update(metadata=metadata) - self.chain_uuid_to_span[run_uuid] = chain_span - except Exception as e: - logger.warning(f"Error creating Langfuse span: {str(e)}") - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: - """Run when chain ends running.""" - if self.langfuse is None: - return - - chain_uuid = kwargs.get('run_id', uuid4()).hex - if chain_uuid not in self.chain_uuid_to_span: - return - chain_span = self.chain_uuid_to_span.pop(chain_uuid) - if chain_span is None: - return - - try: - end_time = datetime.datetime.now() - chain_name = chain_span.metadata.get('chain_name', 'unknown') - start_timestamp = chain_span.metadata.get('start_timestamp') - - if start_timestamp and chain_name in self.chain_metrics: - duration = end_time.timestamp() - start_timestamp - self.chain_metrics[chain_name]['total_time'] += duration - - metadata = { - 'finished': end_time.isoformat(), - 'duration_seconds': duration if start_timestamp else None, - 'output_keys': list(outputs.keys()) if isinstance(outputs, dict) else None, - 'output_size': len(outputs) if isinstance(outputs, dict) else len(str(outputs)) - } - chain_span.update(output=json.dumps(outputs, indent=2), metadata=metadata) - chain_span.end() - except Exception as e: - logger.warning(f"Error updating Langfuse span: {str(e)}") - - def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any: - """Run when chain errors.""" - chain_uuid = kwargs.get('run_id', uuid4()).hex - if chain_uuid not in self.chain_uuid_to_span: - return - chain_span = self.chain_uuid_to_span.get(chain_uuid) - if chain_span is None: - return - - try: - error_str = str(error) - except Exception: - error_str = "Couldn't get error string." - - chain_name = chain_span.metadata.get('chain_name', 'unknown') - if chain_name in self.chain_metrics: - self.chain_metrics[chain_name]['errors'] += 1 - self.chain_metrics[chain_name]['last_error'] = error_str - - metadata = { - 'error_description': error_str, - 'error_type': error.__class__.__name__, - 'error_time': datetime.datetime.now().isoformat() - } - chain_span.update(metadata=metadata) - - def on_agent_action(self, action, **kwargs: Any) -> Any: - """Run on agent action.""" - if self.langfuse is None: - return - - run_uuid = kwargs.get('run_id', uuid4()).hex - try: - action_span = self.langfuse.span( - name=f'{getattr(action, "type", "action")}-{getattr(action, "tool", "")}-{run_uuid}', - trace_id=self.trace_id, - parent_observation_id=self.observation_id, - input=str(action) - ) - self.action_uuid_to_span[run_uuid] = action_span - except Exception as e: - logger.warning(f"Error creating Langfuse span for agent action: {str(e)}") - - def on_agent_finish(self, finish, **kwargs: Any) -> Any: - """Run on agent end.""" - if self.langfuse is None: - return - - run_uuid = kwargs.get('run_id', uuid4()).hex - if run_uuid not in self.action_uuid_to_span: - return - action_span = self.action_uuid_to_span.pop(run_uuid) - if action_span is None: - return - - try: - if finish is not None: - action_span.update(output=finish) # supersedes tool output - action_span.end() - except Exception as e: - logger.warning(f"Error updating Langfuse span: {str(e)}") - - def auth_check(self): - if self.langfuse is not None: - return self.langfuse.auth_check() - return False - - def get_metrics(self) -> Dict[str, Any]: - """Get collected metrics about tools and chains. - - Returns: - Dict containing: - - tool_metrics: Statistics about tool usage, errors, and timing - - chain_metrics: Statistics about chain execution, errors, and timing - For each tool/chain, includes: - - count: Number of times used - - total_time: Total execution time - - errors: Number of errors - - last_error: Most recent error message - - avg_duration: Average execution time - """ - metrics = { - 'tool_metrics': {}, - 'chain_metrics': {} - } - - # Process tool metrics - for tool_name, data in self.tool_metrics.items(): - metrics['tool_metrics'][tool_name] = { - 'count': data['count'], - 'total_time': data['total_time'], - 'avg_duration': data['total_time'] / data['count'] if data['count'] > 0 else 0, - 'errors': data['errors'], - 'last_error': data['last_error'], - 'error_rate': data['errors'] / data['count'] if data['count'] > 0 else 0 - } - - # Process chain metrics - for chain_name, data in self.chain_metrics.items(): - metrics['chain_metrics'][chain_name] = { - 'count': data['count'], - 'total_time': data['total_time'], - 'avg_duration': data['total_time'] / data['count'] if data['count'] > 0 else 0, - 'errors': data['errors'], - 'last_error': data['last_error'], - 'error_rate': data['errors'] / data['count'] if data['count'] > 0 else 0 - } - - return metrics - - -def get_skills(agent: db.Agents) -> List: - """ Retrieve skills from agent `skills` attribute. Specific to agent endpoints. """ - return [rel.skill.type for rel in agent.skills_relationships] diff --git a/mindsdb/interfaces/agents/pydantic_ai_agent.py b/mindsdb/interfaces/agents/pydantic_ai_agent.py index f32fb3d8ed3..2ec7a87a189 100644 --- a/mindsdb/interfaces/agents/pydantic_ai_agent.py +++ b/mindsdb/interfaces/agents/pydantic_ai_agent.py @@ -87,10 +87,6 @@ def __init__( self.llm: Optional[object] = None self.embedding_model: Optional[object] = None - self.log_callback_handler: Optional[object] = None - self.langfuse_callback_handler: Optional[object] = None - self.mdb_langfuse_callback_handler: Optional[object] = None - self.langfuse_client_wrapper = LangfuseClientWrapper() self.agent_mode = self.agent.params.get("mode", "text") diff --git a/requirements/requirements-agents.txt b/requirements/requirements-agents.txt index 83b60d9f496..7676cdd4ff4 100644 --- a/requirements/requirements-agents.txt +++ b/requirements/requirements-agents.txt @@ -1,10 +1,5 @@ openai<3.0.0,>=2.9.0 -langchain-community==0.3.27 -langchain-core==0.3.77 -langchain-experimental==0.3.4 - - # When using agents, some LLMs may require the 'transformers' library (like Ollama): transformers >= 4.42.4 diff --git a/requirements/requirements-kb.txt b/requirements/requirements-kb.txt index eb5adbfaefb..4d9a44112af 100644 --- a/requirements/requirements-kb.txt +++ b/requirements/requirements-kb.txt @@ -1,4 +1,3 @@ lxml==5.3.0 # Is this transitive dependency? pgvector==0.3.6 # Required for knowledge bases -langchain-core==0.3.77 litellm==1.63.14 \ No newline at end of file From b96f77ac25cdbf400b50e8970c92d807c49179f6 Mon Sep 17 00:00:00 2001 From: Andrey Date: Mon, 23 Mar 2026 13:36:56 +0300 Subject: [PATCH 081/125] Fix snowflake error message (#12312) --- .../handlers/snowflake_handler/snowflake_handler.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py b/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py index 581eb74b053..04898c3df63 100644 --- a/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py +++ b/mindsdb/integrations/handlers/snowflake_handler/snowflake_handler.py @@ -363,7 +363,7 @@ def query(self, query: ASTNode) -> DataHandlerResponse: return self.lowercase_columns(result, query) def lowercase_columns(self, result, query): - if not isinstance(query, Select) or result.data_frame is None: + if not isinstance(query, Select) or not isinstance(result, TableResponse): return result quoted_columns = [] @@ -376,12 +376,11 @@ def lowercase_columns(self, result, query): if column.is_quoted[-1]: quoted_columns.append(column.parts[-1]) - rename_columns = {} - for col in result.data_frame.columns: - if col.isupper() and col not in quoted_columns: - rename_columns[col] = col.lower() - if rename_columns: - result.data_frame = result.data_frame.rename(columns=rename_columns) + for col in result.columns: + col_name = col.alias or col.name + if col_name.isupper() and col_name not in quoted_columns: + col.alias = col_name.lower() + return result def get_tables(self) -> DataHandlerResponse: From 8d3731ea5bb764c5ecec2663f2d8a38513db3361 Mon Sep 17 00:00:00 2001 From: Andrey Date: Mon, 23 Mar 2026 18:04:03 +0300 Subject: [PATCH 082/125] Fix upload xlsx files (#12306) --- mindsdb/integrations/utilities/files/file_reader.py | 4 ++++ mindsdb/interfaces/file/file_controller.py | 1 + 2 files changed, 5 insertions(+) diff --git a/mindsdb/integrations/utilities/files/file_reader.py b/mindsdb/integrations/utilities/files/file_reader.py index f3d013bf200..ab88dbfc486 100644 --- a/mindsdb/integrations/utilities/files/file_reader.py +++ b/mindsdb/integrations/utilities/files/file_reader.py @@ -121,6 +121,10 @@ def __init__( self.parameters = {} + def close(self): + if self.file_obj is not None: + self.file_obj.close() + def get_format(self) -> str: if self.format is not None: return self.format diff --git a/mindsdb/interfaces/file/file_controller.py b/mindsdb/interfaces/file/file_controller.py index 5dfa7c05360..cb1308a952f 100644 --- a/mindsdb/interfaces/file/file_controller.py +++ b/mindsdb/interfaces/file/file_controller.py @@ -169,6 +169,7 @@ def get_file_pages(self, source_path: str): """ file_reader = FileReader(path=source_path) tables = file_reader.get_contents() + file_reader.close() pages_files = {} pages_index = {} From f452a372836942eb03eac3ea403058601c491b40 Mon Sep 17 00:00:00 2001 From: Andrey Date: Tue, 24 Mar 2026 16:18:35 +0300 Subject: [PATCH 083/125] Faiss in CI and default vector db (#12265) --- .github/workflows/tests_unit.yml | 4 +- mindsdb/api/executor/sql_query/sql_query.py | 20 +- .../interfaces/knowledge_base/controller.py | 32 +- requirements/requirements-kb.txt | 3 +- scripts/run_unit_tests.sh | 4 - tests/scripts/check_requirements.py | 3 +- tests/unit/api/http/knowledge_bases_test.py | 8 +- tests/unit/executor/test_agent.py | 6 +- tests/unit/executor/test_knowledge_base.py | 55 ++- tests/unit/executor/test_lowercase.py | 6 +- .../planner/test_select_from_predictor.py | 385 ++++++++---------- 11 files changed, 252 insertions(+), 274 deletions(-) diff --git a/.github/workflows/tests_unit.yml b/.github/workflows/tests_unit.yml index 3a9455a7f76..08c78f54f5e 100644 --- a/.github/workflows/tests_unit.yml +++ b/.github/workflows/tests_unit.yml @@ -32,7 +32,7 @@ env: github ms_teams statsforecast - chromadb + duckdb_faiss confluence # We measure 80% on this handlers, as they are the verified HANDLERS_TO_VERIFY: | @@ -163,8 +163,6 @@ jobs: uv pip install ".[agents,kb]" \ -r requirements/requirements-test.txt \ "${HANDLER_EXTRAS[@]}" - # Onuxruntime is required for ChromaDB, once we have default pgvector we can remove it - uv pip install --force-reinstall onnxruntime==1.20.1 git clone --branch v$(uv pip show mindsdb_sql_parser | grep Version | cut -d ' ' -f 2) https://github.com/mindsdb/mindsdb_sql_parser.git parser_tests - name: Run unit tests diff --git a/mindsdb/api/executor/sql_query/sql_query.py b/mindsdb/api/executor/sql_query/sql_query.py index 4e3c2a55097..aae2902f713 100644 --- a/mindsdb/api/executor/sql_query/sql_query.py +++ b/mindsdb/api/executor/sql_query/sql_query.py @@ -279,7 +279,14 @@ def execute_query(self): if self.planner.plan.is_async and ctx.task_id is None: # release KB locks before inserting in background - self.release_kb_lock(steps) + db_released, partition_params = self.release_kb_lock(steps) + if db_released: + # faiss db is used as a table to insert + if partition_params.get("threads", 1) > 1: + raise ValueError( + "It is not possible to use threads for FAISS knowledge base, " + f"please remove `threads={partition_params['threads']}` parameter" + ) # add to task self.run_query.add_to_task() @@ -348,11 +355,18 @@ def execute_step(self, step, steps_data=None): def release_kb_lock(self, steps): # find knowledge bases that are used as tables to insert. # then release locks of vector for these knowledge bases + # return partition step params and databases names that were unlocked + db_released, partition_params = [], {} for step in steps: if isinstance(step, InsertToTable): - self.session.kb_controller.release_lock(step.table, project_name=self.database) + db_name = self.session.kb_controller.release_lock(step.table, project_name=self.database) + if db_name: + db_released.append(db_name) if isinstance(step, FetchDataframeStepPartition): - self.release_kb_lock(step.steps) + dbs, _ = self.release_kb_lock(step.steps) + db_released.extend(dbs) + partition_params.update(step.params) + return db_released, partition_params SQLQuery.register_steps() diff --git a/mindsdb/interfaces/knowledge_base/controller.py b/mindsdb/interfaces/knowledge_base/controller.py index d2eabfe5c41..ff4eca4f75a 100644 --- a/mindsdb/interfaces/knowledge_base/controller.py +++ b/mindsdb/interfaces/knowledge_base/controller.py @@ -1283,11 +1283,21 @@ def add( vector_db_name = self._create_persistent_pgvector(vector_db_params) params["default_vector_storage"] = vector_db_name else: - raise ValueError( - "Vector table is not defined. Set it by `storage=vector_db.vector_table`. " - "One of the options is to use pgvector: " - "https://docs.mindsdb.com/integrations/vector-db-integrations/pgvector" - ) + # try faiss + module = self.session.integration_controller.get_handler_module("duckdb_faiss") + if module is None or module.Handler is None: + raise ValueError( + "Vector table is not defined. Set it by `storage=vector_db.vector_table`. " + "One of the options is to use pgvector: " + "https://docs.mindsdb.com/integrations/vector-db-integrations/pgvector" + ) + + # create faiss db with same name + vector_table_name = "data" + vector_db_name = self._create_persistent_faiss(name) + # memorize to remove it later + params["default_vector_storage"] = vector_db_name + elif len(storage.parts) != 2: raise ValueError("Storage param has to be vector db with table") else: @@ -1465,6 +1475,16 @@ def _create_persistent_pgvector(self, params=None): self.session.integration_controller.add(vector_store_name, "pgvector", params or {}) return vector_store_name + def _create_persistent_faiss(self, kb_name: str): + vector_store_name = f"store_{kb_name}" + + # check if exists + if self.session.integration_controller.get(vector_store_name): + return vector_store_name + + self.session.integration_controller.add(vector_store_name, "duckdb_faiss", {}) + return vector_store_name + def _create_persistent_chroma(self, kb_name, engine="chromadb"): """Create default vector database for knowledge base, if not specified""" @@ -1634,6 +1654,7 @@ def release_lock(self, knowledge_base: Identifier, project_name): # works only for FAISS dbs. # if FAISS vector db is used in KB: remove this db from handlers cache. # it will clear internal cache of tables in faiss handler and release locks for faiss files + # return unloaded database name if len(knowledge_base.parts) > 1: project_name, kb_name = knowledge_base.parts[-2:] @@ -1650,3 +1671,4 @@ def release_lock(self, knowledge_base: Identifier, project_name): if database.engine == "duckdb_faiss": self.session.integration_controller.handlers_cache.delete(database.name) + return database.name diff --git a/requirements/requirements-kb.txt b/requirements/requirements-kb.txt index 4d9a44112af..576ff256c22 100644 --- a/requirements/requirements-kb.txt +++ b/requirements/requirements-kb.txt @@ -1,3 +1,4 @@ lxml==5.3.0 # Is this transitive dependency? pgvector==0.3.6 # Required for knowledge bases -litellm==1.63.14 \ No newline at end of file +litellm==1.63.14 +faiss-cpu==1.13.2 # default vector storage diff --git a/scripts/run_unit_tests.sh b/scripts/run_unit_tests.sh index 28d3dc1c33c..275fc07f517 100755 --- a/scripts/run_unit_tests.sh +++ b/scripts/run_unit_tests.sh @@ -217,10 +217,6 @@ for handler in "${HANDLERS_TO_INSTALL[@]}"; do -r requirements/requirements-test.txt \ "${HANDLER_EXTRAS[@]}" - # Install onnxruntime for ChromaDB - echo "Installing onnxruntime..." - uv pip install --force-reinstall onnxruntime==1.20.1 - # Clone parser tests PARSER_VERSION=$(uv pip show mindsdb_sql_parser | grep Version | cut -d ' ' -f 2) if [[ ! -d "parser_tests" ]]; then diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index be5c2172baf..08766c9b2cd 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -109,6 +109,7 @@ def get_requirements_with_DEP002(path): "litellm", "numba", # required in a few files for the hierarchicalforecast. Otherwise, uv may install an old version. "urllib3", # pinned by Snyk to avoid a vulnerability + "faiss-cpu", ], } @@ -136,7 +137,7 @@ def get_requirements_with_DEP002(path): HUGGINGFACE_DEP002_IGNORE_HANDLER_DEPS = ["torch"] -RAG_DEP002_IGNORE_HANDLER_DEPS = ["sentence-transformers", "faiss-cpu"] +RAG_DEP002_IGNORE_HANDLER_DEPS = ["sentence-transformers"] SOLR_DEP002_IGNORE_HANDLER_DEPS = ["sqlalchemy-solr"] diff --git a/tests/unit/api/http/knowledge_bases_test.py b/tests/unit/api/http/knowledge_bases_test.py index b4bd4f3488d..92b50b6bf95 100644 --- a/tests/unit/api/http/knowledge_bases_test.py +++ b/tests/unit/api/http/knowledge_bases_test.py @@ -3,17 +3,17 @@ from unittest.mock import patch -@patch("mindsdb.integrations.handlers.chromadb_handler.chromadb_handler.ChromaDBHandler") +@patch("mindsdb.integrations.handlers.duckdb_faiss_handler.duckdb_faiss_handler.DuckDBFaissHandler") @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") -def test_update_kb_embeddings(mock_embedding, chroma, client): +def test_update_kb_embeddings(mock_embedding, handler, client): # for test of embeddings mock_embedding().data = [{"embedding": [0.1, 0.2]}] integration_data = { "database": { "name": "kb_vector_db", - "engine": "chromadb", - "parameters": {"persist_directory": "kb_vector_db"}, + "engine": "duckdb_faiss", + "parameters": {}, } } response = client.post("/api/databases", json=integration_data, follow_redirects=True) diff --git a/tests/unit/executor/test_agent.py b/tests/unit/executor/test_agent.py index 88306a748f9..33a68720ff6 100644 --- a/tests/unit/executor/test_agent.py +++ b/tests/unit/executor/test_agent.py @@ -340,11 +340,7 @@ def test_agent_stream(self, mock_openai): def _create_kb_storage(self, kb_name): self.run_sql(f""" create database db_{kb_name} - with - engine='chromadb', - PARAMETERS = {{ - 'persist_directory': '{kb_name}' - }} + with engine='duckdb_faiss' """) return f"db_{kb_name}.default_collection" diff --git a/tests/unit/executor/test_knowledge_base.py b/tests/unit/executor/test_knowledge_base.py index 51fa41efe74..621c78f1d65 100644 --- a/tests/unit/executor/test_knowledge_base.py +++ b/tests/unit/executor/test_knowledge_base.py @@ -8,7 +8,6 @@ import pandas as pd import pytest -import sys from tests.unit.executor_test_base import BaseExecutorDummyML from mindsdb.integrations.utilities.rag.rerankers.base_reranker import ( @@ -133,7 +132,6 @@ def _create_kb( ) def _get_storage_table(self, kb_name): - # default chromadb db_name = f"db_{kb_name}" self._drop_storage_db(db_name) @@ -141,10 +139,7 @@ def _get_storage_table(self, kb_name): self.run_sql(f""" create database {db_name} with - engine='chromadb', - PARAMETERS = {{ - 'persist_directory': '{kb_name}' - }} + engine='duckdb_faiss' """) self.storages.append(db_name) @@ -191,7 +186,7 @@ def test_kb(self, mock_litellm_embedding): ret = self.run_sql("select * from kb_review") assert len(ret) == 1 - # show tables in default chromadb + # show tables in default vectordb ret = self.run_sql("show knowledge bases") db_name = ret.STORAGE[0].split(".")[0] @@ -480,8 +475,6 @@ def test_join_kb_table(self, mock_litellm_embedding): assert set(ret["id"]) == {"9016", "9023"} @pytest.mark.slow - @pytest.mark.skipif(sys.platform == "win32", reason="Causes hard crash on windows.") - @pytest.mark.skipif(sys.platform == "darwin", reason="Causes hard crash on mac.") @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") @patch("mindsdb.integrations.handlers.postgres_handler.Handler") def test_kb_partitions(self, mock_handler, mock_litellm_embedding): @@ -581,13 +574,14 @@ def native_query_with_generator(*args, **kwargs): """ ) - # test threads - check_partition( - """ - insert into kb_part SELECT id, english FROM pg.ral - using batch_size=20, track_column=id, threads = 3 - """ - ) + # switched off for faiss + # # test threads + # check_partition( + # """ + # insert into kb_part SELECT id, english FROM pg.ral + # using batch_size=20, track_column=id, threads = 3 + # """ + # ) # without track column check_partition( @@ -616,13 +610,14 @@ def native_query_without_generator(*args, **kwargs): """ ) - # test threads - check_partition( - """ - insert into kb_part SELECT id, english FROM pg.ral - using batch_size=20, track_column=id, threads = 3 - """ - ) + # switched off for faiss + # # test threads + # check_partition( + # """ + # insert into kb_part SELECT id, english FROM pg.ral + # using batch_size=20, track_column=id, threads = 3 + # """ + # ) @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") def test_kb_algebra(self, mock_litellm_embedding): @@ -1137,16 +1132,16 @@ def test_dimension_mismatch(self, mock_litellm_embedding): temp_dir = tempfile.mkdtemp() self.run_sql(f""" - create database my_chroma - with - engine='chromadb', + create database my_faiss + with + engine='duckdb_faiss', PARAMETERS = {{ 'persist_directory': '{temp_dir}' }} """) set_litellm_embedding(mock_litellm_embedding, dimension=1000) - self._create_kb("kb1", storage="my_chroma.table1") + self._create_kb("kb1", storage="my_faiss.table1") self.run_sql("insert into kb1 (content) values ('review')") @@ -1154,11 +1149,11 @@ def test_dimension_mismatch(self, mock_litellm_embedding): set_litellm_embedding(mock_litellm_embedding, dimension=1500) with pytest.raises(ValueError): - self._create_kb("kb2", storage="my_chroma.table1") + self._create_kb("kb2", storage="my_faiss.table1") self.run_sql("drop knowledge base kb1") - self.run_sql("drop table my_chroma.table1") - self.run_sql("drop database my_chroma") + self.run_sql("drop table my_faiss.table1") + self.run_sql("drop database my_faiss") @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") def test_duplicated_ids(self, mock_litellm_embedding): diff --git a/tests/unit/executor/test_lowercase.py b/tests/unit/executor/test_lowercase.py index 8f8e0a74870..e5ea99bee69 100644 --- a/tests/unit/executor/test_lowercase.py +++ b/tests/unit/executor/test_lowercase.py @@ -211,11 +211,7 @@ def test_knowledgebase_name_lowercase(self, mock_openai, mock_litellm_embedding) self.run_sql(""" create database my_kb_storage - with - engine='chromadb', - PARAMETERS = { - 'persist_directory': 'my_kb_storage' - } + with engine='duckdb_faiss' """) kb_params = """ diff --git a/tests/unit/planner/test_select_from_predictor.py b/tests/unit/planner/test_select_from_predictor.py index 38a1e65f1ff..85ccf4af365 100644 --- a/tests/unit/planner/test_select_from_predictor.py +++ b/tests/unit/planner/test_select_from_predictor.py @@ -1,14 +1,17 @@ import pytest from mindsdb_sql_parser import parse_sql -from mindsdb_sql_parser.ast import (Identifier, Select, Constant, Star, Parameter, BinaryOperation) +from mindsdb_sql_parser.ast import Identifier, Select, Constant, Star, Parameter, BinaryOperation from mindsdb.api.executor.planner.exceptions import PlanningException from mindsdb.api.executor.planner import plan_query from mindsdb.api.executor.planner.query_plan import QueryPlan from mindsdb.api.executor.planner.step_result import Result from mindsdb.api.executor.planner.steps import ( - ProjectStep, ApplyPredictorRowStep, GetPredictorColumns, FetchDataframeStep + ProjectStep, + ApplyPredictorRowStep, + GetPredictorColumns, + FetchDataframeStep, ) @@ -16,347 +19,334 @@ class TestPlanSelectFromPredictor: def test_select_from_predictor_plan(self): query = Select( targets=[Star()], - from_table=Identifier('mindsdb.pred'), + from_table=Identifier("mindsdb.pred"), where=BinaryOperation( - op='and', - args=[BinaryOperation(op='=', args=[Identifier('x1'), Constant(1)]), - BinaryOperation(op='=', args=[Identifier('x2'), Constant('2')])], - ) + op="and", + args=[ + BinaryOperation(op="=", args=[Identifier("x1"), Constant(1)]), + BinaryOperation(op="=", args=[Identifier("x2"), Constant("2")]), + ], + ), ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ - ApplyPredictorRowStep( - namespace='mindsdb', predictor=Identifier('pred'), - row_dict={'x1': 1, 'x2': '2'} - ), + ApplyPredictorRowStep(namespace="mindsdb", predictor=Identifier("pred"), row_dict={"x1": 1, "x2": "2"}), ], - ) - plan = plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan = plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) assert plan.steps == expected_plan.steps def test_select_from_predictor_negative_constant(self): query = parse_sql( - ''' + """ select * from mindsdb.pred where x1 = -1 - ''' + """ ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ - ApplyPredictorRowStep(namespace='mindsdb', predictor=Identifier('pred'), row_dict={'x1': -1, }), + ApplyPredictorRowStep( + namespace="mindsdb", + predictor=Identifier("pred"), + row_dict={ + "x1": -1, + }, + ), ], ) - plan = plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan = plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) assert plan.steps == expected_plan.steps def test_select_from_predictor_plan_other_ml(self): query = parse_sql( - ''' + """ select * from mlflow.pred where x1 = 1 and x2 = '2' - ''' + """ ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ - ApplyPredictorRowStep( - namespace='mlflow', predictor=Identifier('pred'), - row_dict={'x1': 1, 'x2': '2'} - ), + ApplyPredictorRowStep(namespace="mlflow", predictor=Identifier("pred"), row_dict={"x1": 1, "x2": "2"}), ], - ) - plan = plan_query(query, predictor_metadata=[{'name': 'pred', 'integration_name': 'mlflow'}]) + plan = plan_query(query, predictor_metadata=[{"name": "pred", "integration_name": "mlflow"}]) assert plan.steps == expected_plan.steps def test_select_from_predictor_aliases_in_project(self): query = Select( - targets=[Identifier('tb.x1', alias=Identifier('col1')), - Identifier('tb.x2', alias=Identifier('col2')), - Identifier('tb.y', alias=Identifier('predicted'))], - from_table=Identifier('mindsdb.pred', alias=Identifier('tb')), + targets=[ + Identifier("tb.x1", alias=Identifier("col1")), + Identifier("tb.x2", alias=Identifier("col2")), + Identifier("tb.y", alias=Identifier("predicted")), + ], + from_table=Identifier("mindsdb.pred", alias=Identifier("tb")), where=BinaryOperation( - op='and', + op="and", args=[ - BinaryOperation(op='=', args=[Identifier('tb.x1'), Constant(1)]), - BinaryOperation(op='=', args=[Identifier('tb.x2'), Constant('2')]), + BinaryOperation(op="=", args=[Identifier("tb.x1"), Constant(1)]), + BinaryOperation(op="=", args=[Identifier("tb.x2"), Constant("2")]), ], - ) + ), ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ ApplyPredictorRowStep( - namespace='mindsdb', - predictor=Identifier('pred', alias=Identifier('tb')), - row_dict={'x1': 1, 'x2': '2'} + namespace="mindsdb", + predictor=Identifier("pred", alias=Identifier("tb")), + row_dict={"x1": 1, "x2": "2"}, ), ProjectStep( dataframe=Result(0), - columns=[Identifier('tb.x1', alias=Identifier('col1')), - Identifier('tb.x2', alias=Identifier('col2')), - Identifier('tb.y', alias=Identifier('predicted'))] + columns=[ + Identifier("tb.x1", alias=Identifier("col1")), + Identifier("tb.x2", alias=Identifier("col2")), + Identifier("tb.y", alias=Identifier("predicted")), + ], ), ], - ) - plan = plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan = plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) assert plan.steps == expected_plan.steps def test_select_from_predictor_plan_predictor_alias(self): query = Select( targets=[Star()], - from_table=Identifier('mindsdb.pred', alias=Identifier('pred_alias')), + from_table=Identifier("mindsdb.pred", alias=Identifier("pred_alias")), where=BinaryOperation( - op='and', + op="and", args=[ - BinaryOperation(op='=', args=[Identifier('pred_alias.x1'), Constant(1)]), - BinaryOperation( - op='=', - args=[Identifier('pred_alias.x2'), Constant('2')] - ) + BinaryOperation(op="=", args=[Identifier("pred_alias.x1"), Constant(1)]), + BinaryOperation(op="=", args=[Identifier("pred_alias.x2"), Constant("2")]), ], - ) + ), ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ ApplyPredictorRowStep( - namespace='mindsdb', predictor=Identifier('pred', alias=Identifier('pred_alias')), - row_dict={'x1': 1, 'x2': '2'} + namespace="mindsdb", + predictor=Identifier("pred", alias=Identifier("pred_alias")), + row_dict={"x1": 1, "x2": "2"}, ), ], ) - plan = plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan = plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) assert plan.steps == expected_plan.steps def test_select_from_predictor_plan_verbose_col_names(self): query = Select( targets=[Star()], - from_table=Identifier('mindsdb.pred'), + from_table=Identifier("mindsdb.pred"), where=BinaryOperation( - op='and', - args=[BinaryOperation(op='=', args=[Identifier('pred.x1'), Constant(1)]), - BinaryOperation(op='=', args=[Identifier('pred.x2'), Constant('2')])], - ) + op="and", + args=[ + BinaryOperation(op="=", args=[Identifier("pred.x1"), Constant(1)]), + BinaryOperation(op="=", args=[Identifier("pred.x2"), Constant("2")]), + ], + ), ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ - ApplyPredictorRowStep( - namespace='mindsdb', predictor=Identifier('pred'), - row_dict={'x1': 1, 'x2': '2'} - ), + ApplyPredictorRowStep(namespace="mindsdb", predictor=Identifier("pred"), row_dict={"x1": 1, "x2": "2"}), ProjectStep(dataframe=Result(0), columns=[Star()]), ], ) - plan = plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan = plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) for i in range(len(plan.steps)): assert plan.steps[i] == expected_plan.steps[i] def test_select_from_predictor_plan_group_by_error(self): query = Select( - targets=[Identifier('x1'), Identifier('x2'), Identifier('pred.y')], - from_table=Identifier('mindsdb.pred'), - group_by=[Identifier('x1')] + targets=[Identifier("x1"), Identifier("x2"), Identifier("pred.y")], + from_table=Identifier("mindsdb.pred"), + group_by=[Identifier("x1")], ) with pytest.raises(PlanningException): - plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) def test_select_from_predictor_wrong_where_op_error(self): query = Select( targets=[Star()], - from_table=Identifier('mindsdb.pred'), + from_table=Identifier("mindsdb.pred"), where=BinaryOperation( - op='and', - args=[BinaryOperation(op='>', args=[Identifier('x1'), Constant(1)]), - BinaryOperation(op='=', args=[Identifier('x2'), Constant('2')])], - ) + op="and", + args=[ + BinaryOperation(op=">", args=[Identifier("x1"), Constant(1)]), + BinaryOperation(op="=", args=[Identifier("x2"), Constant("2")]), + ], + ), ) with pytest.raises(PlanningException): - plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) def test_select_from_predictor_multiple_values_error(self): query = Select( targets=[Star()], - from_table=Identifier('mindsdb.pred'), + from_table=Identifier("mindsdb.pred"), where=BinaryOperation( - op='and', - args=[BinaryOperation(op='=', args=[Identifier('x1'), Constant(1)]), - BinaryOperation(op='=', args=[Identifier('x1'), Constant('2')])], - ) + op="and", + args=[ + BinaryOperation(op="=", args=[Identifier("x1"), Constant(1)]), + BinaryOperation(op="=", args=[Identifier("x1"), Constant("2")]), + ], + ), ) with pytest.raises(PlanningException): - plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) def test_select_from_predictor_no_where_error(self): - query = Select( - targets=[Star()], - from_table=Identifier('mindsdb.pred') - ) + query = Select(targets=[Star()], from_table=Identifier("mindsdb.pred")) with pytest.raises(PlanningException): - plan_query(query, predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) + plan_query(query, predictor_namespace="mindsdb", predictor_metadata={"pred": {}}) def test_select_from_predictor_default_namespace(self): query = Select( targets=[Star()], - from_table=Identifier('pred'), + from_table=Identifier("pred"), where=BinaryOperation( - op='and', - args=[BinaryOperation(op='=', args=[Identifier('x1'), Constant(1)]), - BinaryOperation(op='=', args=[Identifier('x2'), Constant('2')])], - ) + op="and", + args=[ + BinaryOperation(op="=", args=[Identifier("x1"), Constant(1)]), + BinaryOperation(op="=", args=[Identifier("x2"), Constant("2")]), + ], + ), ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', - default_namespace='mindsdb', + predictor_namespace="mindsdb", + default_namespace="mindsdb", steps=[ - ApplyPredictorRowStep( - namespace='mindsdb', predictor=Identifier('pred'), - row_dict={'x1': 1, 'x2': '2'} - ), + ApplyPredictorRowStep(namespace="mindsdb", predictor=Identifier("pred"), row_dict={"x1": 1, "x2": "2"}), ], ) plan = plan_query( - query, predictor_namespace='mindsdb', default_namespace='mindsdb', predictor_metadata={'pred': {}} + query, predictor_namespace="mindsdb", default_namespace="mindsdb", predictor_metadata={"pred": {}} ) assert plan.steps == expected_plan.steps def test_select_from_predictor_get_columns(self): - sql = 'SELECT GDP_per_capita_USD FROM hdi_predictor_external WHERE 1 = 0' + sql = "SELECT GDP_per_capita_USD FROM hdi_predictor_external WHERE 1 = 0" query = parse_sql(sql) expected_query = Select( - targets=[Identifier('GDP_per_capita_USD')], - from_table=Identifier('hdi_predictor_external'), - where=BinaryOperation( - op="=", - args=[Constant(1), Constant(0)] - ) + targets=[Identifier("GDP_per_capita_USD")], + from_table=Identifier("hdi_predictor_external"), + where=BinaryOperation(op="=", args=[Constant(1), Constant(0)]), ) assert query.to_tree() == expected_query.to_tree() expected_plan = QueryPlan( - predictor_namespace='mindsdb', - default_namespace='mindsdb', + predictor_namespace="mindsdb", + default_namespace="mindsdb", steps=[ - GetPredictorColumns( - namespace='mindsdb', - predictor=Identifier('hdi_predictor_external') - ), - ProjectStep(dataframe=Result(0), columns=[Identifier('GDP_per_capita_USD')]), + GetPredictorColumns(namespace="mindsdb", predictor=Identifier("hdi_predictor_external")), + ProjectStep(dataframe=Result(0), columns=[Identifier("GDP_per_capita_USD")]), ], ) plan = plan_query( - query, predictor_namespace='mindsdb', default_namespace='mindsdb', - predictor_metadata={'hdi_predictor_external': {}} + query, + predictor_namespace="mindsdb", + default_namespace="mindsdb", + predictor_metadata={"hdi_predictor_external": {}}, ) assert plan.steps == expected_plan.steps def test_using_predictor_version(self): query = parse_sql( - ''' + """ select * from mindsdb.pred.21 where x1 = 1 - ''' + """ ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ ApplyPredictorRowStep( - namespace='mindsdb', predictor=Identifier(parts=['pred', '21']), - row_dict={'x1': 1} + namespace="mindsdb", predictor=Identifier(parts=["pred", "21"]), row_dict={"x1": 1} ) ], ) - plan = plan_query(query, predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}]) + plan = plan_query(query, predictor_metadata=[{"name": "pred", "integration_name": "mindsdb"}]) assert plan.steps == expected_plan.steps def test_select_from_predictor_subselect(self): query = parse_sql( - ''' + """ select * from mindsdb.pred.21 where x1 = (select id from int1.t1) - ''' + """ ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ FetchDataframeStep( - integration='int1', - query=parse_sql('select id as id from t1'), + integration="int1", + query=parse_sql("select id as id from t1"), ), ApplyPredictorRowStep( - namespace='mindsdb', - predictor=Identifier(parts=['pred', '21']), - row_dict={'x1': Parameter(Result(0))} - ) + namespace="mindsdb", + predictor=Identifier(parts=["pred", "21"]), + row_dict={"x1": Parameter(Result(0))}, + ), ], ) plan = plan_query( - query, - integrations=['int1'], - predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}] + query, integrations=["int1"], predictor_metadata=[{"name": "pred", "integration_name": "mindsdb"}] ) assert plan.steps == expected_plan.steps def test_select_from_view_subselect(self): query = parse_sql( - ''' + """ select * from v1 where x1 in (select id from int1.tab1) - ''' + """ ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ FetchDataframeStep( - integration='int1', - query=parse_sql('select id as id from tab1'), + integration="int1", + query=parse_sql("select id as id from tab1"), ), FetchDataframeStep( - integration='mindsdb', + integration="mindsdb", query=Select( targets=[Star()], - from_table=Identifier('v1'), - where=BinaryOperation( - op='in', - args=[ - Identifier(parts=['x1']), - Parameter(Result(0)) - ] - ) + from_table=Identifier("v1"), + where=BinaryOperation(op="in", args=[Identifier(parts=["x1"]), Parameter(Result(0))]), ), ), ], @@ -364,81 +354,66 @@ def test_select_from_view_subselect(self): plan = plan_query( query, - integrations=['int1'], - default_namespace='mindsdb', - predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}] + integrations=["int1"], + default_namespace="mindsdb", + predictor_metadata=[{"name": "pred", "integration_name": "mindsdb"}], ) assert plan.steps == expected_plan.steps def test_select_from_view_subselect_view(self): query = parse_sql( - ''' + """ select * from v1 where x1 in (select v2.id from v2) - ''' + """ ) expected_plan = QueryPlan( - predictor_namespace='mindsdb', + predictor_namespace="mindsdb", steps=[ FetchDataframeStep( - integration='mindsdb', - query=parse_sql('select v2.id as id from v2'), + integration="mindsdb", + query=parse_sql("select v2.id as id from v2"), ), FetchDataframeStep( - integration='mindsdb', + integration="mindsdb", query=Select( targets=[Star()], - from_table=Identifier('v1'), - where=BinaryOperation( - op='in', - args=[ - Identifier(parts=['x1']), - Parameter(Result(0)) - ] - ) + from_table=Identifier("v1"), + where=BinaryOperation(op="in", args=[Identifier(parts=["x1"]), Parameter(Result(0))]), ), ), ], ) - plan = plan_query( - query, - integrations=[], - default_namespace='mindsdb', - predictor_metadata=[] - ) + plan = plan_query(query, integrations=[], default_namespace="mindsdb", predictor_metadata=[]) assert plan.steps == expected_plan.steps class TestMLSelect: - def test_select_from_predictor_plan_other_ml(self): # sends to integrations - query = parse_sql(''' select * from mlflow.predictors ''') + query = parse_sql(""" select * from mlflow.predictors """) expected_plan = QueryPlan( - steps=[ - FetchDataframeStep(step_num=0, integration='mlflow', query=parse_sql('SELECT * FROM predictors')) - ], + steps=[FetchDataframeStep(step_num=0, integration="mlflow", query=parse_sql("SELECT * FROM predictors"))], ) - plan = plan_query(query, predictor_metadata=[], integrations=['mlflow']) + plan = plan_query(query, predictor_metadata=[], integrations=["mlflow"]) assert plan.steps == expected_plan.steps class TestNestedSelect: - def test_using_predictor_in_subselect(self): """ Use predictor in subselect when selecting from integration """ sql = """ SELECT * - FROM chromadb.test_tabl + FROM vectordb.test_tabl WHERE search_vector = ( SELECT emebddings @@ -450,37 +425,25 @@ def test_using_predictor_in_subselect(self): ast_tree = parse_sql(sql) plan = plan_query( ast_tree, - integrations=['chromadb'], - predictor_metadata=[ - {'name': 'embedding_model', 'integration_name': 'mindsdb'} - ] + integrations=["vectordb"], + predictor_metadata=[{"name": "embedding_model", "integration_name": "mindsdb"}], ) expected_plan = [ ApplyPredictorRowStep( step_num=0, - namespace='mindsdb', - predictor=Identifier(parts=['embedding_model']), - row_dict={'content': 'some text'} - ), - ProjectStep( - step_num=1, - dataframe=Result(0), - columns=[Identifier(parts=['emebddings'])] + namespace="mindsdb", + predictor=Identifier(parts=["embedding_model"]), + row_dict={"content": "some text"}, ), + ProjectStep(step_num=1, dataframe=Result(0), columns=[Identifier(parts=["emebddings"])]), FetchDataframeStep( step_num=2, - integration='chromadb', + integration="vectordb", query=Select( targets=[Star()], - from_table=Identifier(parts=['test_tabl']), - where=BinaryOperation( - op='=', - args=[ - Identifier(parts=['search_vector']), - Parameter(Result(1)) - ] - ) + from_table=Identifier(parts=["test_tabl"]), + where=BinaryOperation(op="=", args=[Identifier(parts=["search_vector"]), Parameter(Result(1))]), ), ), ] @@ -498,31 +461,27 @@ def test_using_integration_in_subselect(self): WHERE content = ( SELECT content - FROM chromadb.test_tabl + FROM vectordb.test_tabl LIMIT 1 ) """ ast_tree = parse_sql(sql) plan = plan_query( ast_tree, - integrations=['chromadb'], - predictor_metadata=[ - {'name': 'embedding_model', 'integration_name': 'mindsdb'} - ] + integrations=["vectordb"], + predictor_metadata=[{"name": "embedding_model", "integration_name": "mindsdb"}], ) expected_plan = [ FetchDataframeStep( - step_num=0, - integration='chromadb', - query=parse_sql('SELECT content AS content FROM test_tabl LIMIT 1') + step_num=0, integration="vectordb", query=parse_sql("SELECT content AS content FROM test_tabl LIMIT 1") ), ApplyPredictorRowStep( step_num=1, - namespace='mindsdb', - predictor=Identifier(parts=['embedding_model']), - row_dict={'content': Parameter(Result(0))} - ) + namespace="mindsdb", + predictor=Identifier(parts=["embedding_model"]), + row_dict={"content": Parameter(Result(0))}, + ), ] assert plan.steps == expected_plan From 562e8bbccec2b91446bb82b647261a7a645379e5 Mon Sep 17 00:00:00 2001 From: jnMetaCode <1394485448@qq.com> Date: Tue, 24 Mar 2026 21:28:36 +0800 Subject: [PATCH 084/125] fix: correct isinstance check for redirect_url in database endpoints (#12279) Signed-off-by: JiangNan <1394485448@qq.com> --- mindsdb/api/http/namespaces/databases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindsdb/api/http/namespaces/databases.py b/mindsdb/api/http/namespaces/databases.py index 895bc59656f..f5a75f6bc73 100644 --- a/mindsdb/api/http/namespaces/databases.py +++ b/mindsdb/api/http/namespaces/databases.py @@ -69,7 +69,7 @@ def post(self): status = HandlerStatusResponse(success=False, error_message=str(import_error)) if status.success is not True: - if hasattr(status, "redirect_url") and isinstance(status, str): + if hasattr(status, "redirect_url") and isinstance(status.redirect_url, str): return { "status": "redirect_required", "redirect_url": status.redirect_url, @@ -136,7 +136,7 @@ def post(self): shutil.rmtree(temp_dir) if not status.success: - if hasattr(status, "redirect_url") and isinstance(status, str): + if hasattr(status, "redirect_url") and isinstance(status.redirect_url, str): return { "status": "redirect_required", "redirect_url": status.redirect_url, From 729f59c9dd6442fdb2654dd2327f4e390b454445 Mon Sep 17 00:00:00 2001 From: Andrey Date: Wed, 25 Mar 2026 13:53:39 +0300 Subject: [PATCH 085/125] Remove litellm from CI (#12325) --- .github/workflows/tests_unit.yml | 1 + .../test_faiss_handler.py | 8 +- .../interfaces/knowledge_base/controller.py | 15 --- requirements/requirements-agents.txt | 1 - requirements/requirements-kb.txt | 1 - tests/unit/api/http/knowledge_bases_test.py | 6 +- tests/unit/executor/test_agent.py | 26 ++--- tests/unit/executor/test_knowledge_base.py | 108 +++++++++--------- tests/unit/executor/test_lowercase.py | 8 +- tests/unit/handlers/test_bigquery.py | 1 - 10 files changed, 78 insertions(+), 97 deletions(-) diff --git a/.github/workflows/tests_unit.yml b/.github/workflows/tests_unit.yml index 08c78f54f5e..eb8905a520e 100644 --- a/.github/workflows/tests_unit.yml +++ b/.github/workflows/tests_unit.yml @@ -34,6 +34,7 @@ env: statsforecast duckdb_faiss confluence + openai # We measure 80% on this handlers, as they are the verified HANDLERS_TO_VERIFY: | mysql diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py index a0950730067..6a2711cfbcb 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/test_faiss_handler.py @@ -3,7 +3,7 @@ import pandas as pd -from tests.unit.executor.test_knowledge_base import TestKB, set_litellm_embedding +from tests.unit.executor.test_knowledge_base import TestKB, set_embedding class TestFAISS(TestKB): @@ -32,15 +32,15 @@ def _get_storage_table(self, kb_name): return f"faiss_{kb_name}.kb_faiss" @pytest.mark.parametrize("index_type", ["ivf", "ivf_file"]) - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_ivf_index(self, mock_litellm_embedding, index_type): + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_ivf_index(self, mock_embedding, index_type): """ Run test two times: - make ivf index and then reindex to ivf_file - make ivf_file index and then reindex to ivf """ - set_litellm_embedding(mock_litellm_embedding) + set_embedding(mock_embedding) df = self._get_ral_table() diff --git a/mindsdb/interfaces/knowledge_base/controller.py b/mindsdb/interfaces/knowledge_base/controller.py index ff4eca4f75a..a85571edbd3 100644 --- a/mindsdb/interfaces/knowledge_base/controller.py +++ b/mindsdb/interfaces/knowledge_base/controller.py @@ -1023,21 +1023,6 @@ def _content_to_embeddings(self, content: str) -> List[float]: res = self._df_to_embeddings(df) return res[TableField.EMBEDDINGS.value][0] - @staticmethod - def call_litellm_embedding(session, model_params, messages): - args = copy.deepcopy(model_params) - - if "model_name" not in args: - raise ValueError("'model_name' must be provided for embedding model") - - llm_model = args.pop("model_name") - engine = args.pop("provider") - - module = session.integration_controller.get_handler_module("litellm") - if module is None or module.Handler is None: - raise ValueError(f'Unable to use "{engine}" provider. Litellm handler is not installed') - return module.Handler.embeddings(engine, llm_model, messages, args) - def build_rag_pipeline(self, retrieval_config: dict): """ Builds a RAG pipeline with returned sources diff --git a/requirements/requirements-agents.txt b/requirements/requirements-agents.txt index 7676cdd4ff4..233cdc35198 100644 --- a/requirements/requirements-agents.txt +++ b/requirements/requirements-agents.txt @@ -5,7 +5,6 @@ transformers >= 4.42.4 # Required for KB mindsdb-evaluator == 0.0.21 -litellm==1.63.14 mcp~=1.10.1 # Required for MCP server # A2A requirements diff --git a/requirements/requirements-kb.txt b/requirements/requirements-kb.txt index 576ff256c22..2479eb00040 100644 --- a/requirements/requirements-kb.txt +++ b/requirements/requirements-kb.txt @@ -1,4 +1,3 @@ lxml==5.3.0 # Is this transitive dependency? pgvector==0.3.6 # Required for knowledge bases -litellm==1.63.14 faiss-cpu==1.13.2 # default vector storage diff --git a/tests/unit/api/http/knowledge_bases_test.py b/tests/unit/api/http/knowledge_bases_test.py index 92b50b6bf95..4ccfccfe7a7 100644 --- a/tests/unit/api/http/knowledge_bases_test.py +++ b/tests/unit/api/http/knowledge_bases_test.py @@ -4,10 +4,10 @@ @patch("mindsdb.integrations.handlers.duckdb_faiss_handler.duckdb_faiss_handler.DuckDBFaissHandler") -@patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") +@patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") def test_update_kb_embeddings(mock_embedding, handler, client): # for test of embeddings - mock_embedding().data = [{"embedding": [0.1, 0.2]}] + mock_embedding().embeddings.return_value = [{"embedding": [0.1, 0.2]}] integration_data = { "database": { @@ -54,5 +54,5 @@ def test_update_kb_embeddings(mock_embedding, handler, client): ) assert update_response.status_code == HTTPStatus.OK - kwargs = mock_embedding.call_args_list[0][1] + kwargs = mock_embedding.call_args_list[0][0][0] assert kwargs["api_key"] == "embed-key-2" diff --git a/tests/unit/executor/test_agent.py b/tests/unit/executor/test_agent.py index 33a68720ff6..c1551b1d3fe 100644 --- a/tests/unit/executor/test_agent.py +++ b/tests/unit/executor/test_agent.py @@ -9,7 +9,7 @@ import sys from openai.types.chat import ChatCompletion from tests.unit.executor_test_base import BaseExecutorDummyML -from tests.unit.executor.test_knowledge_base import set_litellm_embedding +from tests.unit.executor.test_knowledge_base import set_embedding def action_response(type="final_query", sql="", text=""): @@ -351,10 +351,10 @@ def _drop_kb_storage(self, vector_table_name): self.run_sql(f"drop database {db_name}") - @patch("litellm.embedding") + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") @patch("pydantic_ai.providers.openai.AsyncOpenAI") - def test_agent_retrieval(self, mock_openai, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + def test_agent_retrieval(self, mock_openai, mock_embedding): + set_embedding(mock_embedding) vector_table_name = self._create_kb_storage("kb_review") self.run_sql(f""" @@ -464,9 +464,9 @@ def test_agent_default_prompt_template(self, mock_openai): assert agent_response in ret.answer[0] @patch("pydantic_ai.providers.openai.AsyncOpenAI") - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_agent_permissions(self, mock_litellm_embedding, mock_openai): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_agent_permissions(self, mock_embedding, mock_openai): + set_embedding(mock_embedding) vector_table_name = self._create_kb_storage("kb_show") @@ -581,9 +581,9 @@ def test_agent_permissions(self, mock_litellm_embedding, mock_openai): self._drop_kb_storage(vector_table_name) @patch("pydantic_ai.providers.openai.AsyncOpenAI") - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_agent_new_syntax(self, mock_litellm_embedding, mock_openai): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_agent_new_syntax(self, mock_embedding, mock_openai): + set_embedding(mock_embedding) vector_table_name = self._create_kb_storage("kb") df = get_dataset_planets() # create 2 files and KBs @@ -709,9 +709,9 @@ def test_agent_new_syntax(self, mock_litellm_embedding, mock_openai): self._drop_kb_storage(vector_table_name) @patch("pydantic_ai.providers.openai.AsyncOpenAI") - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_agent_accept_wrong_quoting(self, mock_litellm_embedding, mock_openai): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_agent_accept_wrong_quoting(self, mock_embedding, mock_openai): + set_embedding(mock_embedding) vector_table_name = self._create_kb_storage("kb1") self.run_sql(f""" create knowledge base kb1 diff --git a/tests/unit/executor/test_knowledge_base.py b/tests/unit/executor/test_knowledge_base.py index 621c78f1d65..1155bd1d1e0 100644 --- a/tests/unit/executor/test_knowledge_base.py +++ b/tests/unit/executor/test_knowledge_base.py @@ -59,13 +59,11 @@ def dummy_embeddings(string, dimension=None): return embeds -def set_litellm_embedding(mock_litellm_embedding, dimension=None): +def set_embedding(mock_embedding, dimension=None): def resp_f(input, *args, **kwargs): - mock_response = MagicMock() - mock_response.data = [{"embedding": dummy_embeddings(s, dimension)} for s in input] - return mock_response + return [dummy_embeddings(s, dimension) for s in input] - mock_litellm_embedding.side_effect = resp_f + mock_embedding().embeddings.side_effect = resp_f class BaseTestKB(BaseExecutorDummyML): @@ -93,7 +91,7 @@ def _create_kb( if embedding_model is None: embedding_model = { - "provider": "bedrock", + "provider": "openai", "model_name": "dummy_model", "api_key": "dummy_key", } @@ -174,9 +172,9 @@ def setup_method(self): config["knowledge_bases"]["disable_autobatch"] = True - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_kb(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_kb(self, mock_embedding): + set_embedding(mock_embedding) self._create_kb("kb_review") @@ -194,9 +192,9 @@ def test_kb(self, mock_litellm_embedding): # only one default collection there assert len(ret) == 1 - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_kb_metadata(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_kb_metadata(self, mock_embedding): + set_embedding(mock_embedding) record = { "review": "all is good, haven't used yet", @@ -407,9 +405,9 @@ async def _fake_call_llm(messages): # Fallback pattern should be descending assert scores[0] > scores[1] > scores[2] - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_join_kb_table(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_join_kb_table(self, mock_embedding): + set_embedding(mock_embedding) df = self._get_ral_table() self.save_file("ral", df) @@ -475,10 +473,10 @@ def test_join_kb_table(self, mock_litellm_embedding): assert set(ret["id"]) == {"9016", "9023"} @pytest.mark.slow - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") @patch("mindsdb.integrations.handlers.postgres_handler.Handler") - def test_kb_partitions(self, mock_handler, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + def test_kb_partitions(self, mock_handler, mock_embedding): + set_embedding(mock_embedding) df = self._get_ral_table() self.save_file("ral", df) @@ -619,9 +617,9 @@ def native_query_without_generator(*args, **kwargs): # """ # ) - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_kb_algebra(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_kb_algebra(self, mock_embedding): + set_embedding(mock_embedding) lines, i = [], 0 for color in ("white", "red", "green"): @@ -738,9 +736,9 @@ def test_kb_algebra(self, mock_litellm_embedding): else: assert "small" in content - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_select_allowed_columns(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_select_allowed_columns(self, mock_embedding): + set_embedding(mock_embedding) # -- no metadata are specified, generated from inserts -- self._create_kb("kb1") @@ -783,9 +781,9 @@ def test_select_allowed_columns(self, mock_litellm_embedding): @patch("mindsdb.interfaces.knowledge_base.llm_client.OpenAI") @patch("mindsdb.integrations.utilities.rag.rerankers.base_reranker.BaseLLMReranker.get_scores") - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_evaluate(self, mock_litellm_embedding, mock_get_scores, mock_openai): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_evaluate(self, mock_embedding, mock_get_scores, mock_openai): + set_embedding(mock_embedding) question, answer = "2+2", "4" agent_response = f""" @@ -903,13 +901,13 @@ def test_evaluate(self, mock_litellm_embedding, mock_get_scores, mock_openai): assert len(df) > 0 @patch("mindsdb.utilities.config.Config.get") - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") @patch("mindsdb.integrations.utilities.rag.rerankers.base_reranker.BaseLLMReranker.get_scores") - def test_save_default_params(self, mock_get_scores, mock_litellm_embedding, mock_config_get): + def test_save_default_params(self, mock_get_scores, mock_embedding, mock_config_get): # reranking result mock_get_scores.side_effect = lambda query, docs: [0.8 for _ in docs] - set_litellm_embedding(mock_litellm_embedding) + set_embedding(mock_embedding) def config_get_side_effect(key, default=None): if key == "default_embedding_model": @@ -943,10 +941,10 @@ def config_get_side_effect(key, default=None): assert "openai_model" not in ret["RERANKING_MODEL"][0] - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_relevance_filtering_gt_operator(self, mock_litellm_embedding): + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_relevance_filtering_gt_operator(self, mock_embedding): """Test relevance filtering with GREATER_THAN operator""" - set_litellm_embedding(mock_litellm_embedding) + set_embedding(mock_embedding) test_data = [ {"id": "1", "content": "This is about machine learning and AI"}, @@ -977,9 +975,9 @@ def test_relevance_filtering_gt_operator(self, mock_litellm_embedding): assert isinstance(ret, pd.DataFrame) @patch("mindsdb.integrations.utilities.rag.rerankers.base_reranker.BaseLLMReranker.get_scores") - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_alter_kb(self, mock_litellm_embedding, mock_get_scores): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_alter_kb(self, mock_embedding, mock_get_scores): + set_embedding(mock_embedding) self._create_kb( "kb1", @@ -1053,9 +1051,9 @@ def test_ollama(self, mock_openai, mock_get_scores): assert "api_key" not in ret["EMBEDDING_MODEL"][0] assert "api_key" not in ret["RERANKING_MODEL"][0] - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_kb_uppercase_source_columns(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_kb_uppercase_source_columns(self, mock_embedding): + set_embedding(mock_embedding) df = pd.DataFrame( [ @@ -1127,8 +1125,8 @@ def test_kb_uppercase_source_columns(self, mock_litellm_embedding): assert len(ret) == 2 assert ret["category"][0] == "Home" - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_dimension_mismatch(self, mock_litellm_embedding): + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_dimension_mismatch(self, mock_embedding): temp_dir = tempfile.mkdtemp() self.run_sql(f""" @@ -1140,13 +1138,13 @@ def test_dimension_mismatch(self, mock_litellm_embedding): }} """) - set_litellm_embedding(mock_litellm_embedding, dimension=1000) + set_embedding(mock_embedding, dimension=1000) self._create_kb("kb1", storage="my_faiss.table1") self.run_sql("insert into kb1 (content) values ('review')") # change dimension - set_litellm_embedding(mock_litellm_embedding, dimension=1500) + set_embedding(mock_embedding, dimension=1500) with pytest.raises(ValueError): self._create_kb("kb2", storage="my_faiss.table1") @@ -1155,9 +1153,9 @@ def test_dimension_mismatch(self, mock_litellm_embedding): self.run_sql("drop table my_faiss.table1") self.run_sql("drop database my_faiss") - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_duplicated_ids(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_duplicated_ids(self, mock_embedding): + set_embedding(mock_embedding) self._create_kb("kb1") @@ -1187,9 +1185,9 @@ def test_duplicated_ids(self, mock_litellm_embedding): ret = self.run_sql("select * from kb1 where id = 2") assert len(ret) == 1 - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_update(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_update(self, mock_embedding): + set_embedding(mock_embedding) self._create_kb("kb1") @@ -1208,9 +1206,9 @@ def test_update(self, mock_litellm_embedding): class TestKBAutoBatch(BaseTestKB): - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_no_autobatch(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_no_autobatch(self, mock_embedding): + set_embedding(mock_embedding) df = self._get_ral_table() self.save_file("ral", df) @@ -1230,9 +1228,9 @@ def test_no_autobatch(self, mock_litellm_embedding): ret = self.run_sql("select * from kb_ral limit 1") assert len(ret) == 1 - @patch("mindsdb.integrations.handlers.litellm_handler.litellm_handler.embedding") - def test_autobatch(self, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_autobatch(self, mock_embedding): + set_embedding(mock_embedding) df = self._get_ral_table() self.save_file("ral", df) diff --git a/tests/unit/executor/test_lowercase.py b/tests/unit/executor/test_lowercase.py index e5ea99bee69..f5fd71efe30 100644 --- a/tests/unit/executor/test_lowercase.py +++ b/tests/unit/executor/test_lowercase.py @@ -4,7 +4,7 @@ import pandas as pd from tests.unit.executor_test_base import BaseExecutorDummyML -from tests.unit.executor.test_agent import set_litellm_embedding +from tests.unit.executor.test_agent import set_embedding class TestLowercase(BaseExecutorDummyML): @@ -204,10 +204,10 @@ def test_agent_name_lowercase(self): self.run_sql(f"drop agent `{another_agent_name}`") self.run_sql(f"drop agent {another_agent_name}") - @patch("litellm.embedding") + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") @patch("openai.OpenAI") - def test_knowledgebase_name_lowercase(self, mock_openai, mock_litellm_embedding): - set_litellm_embedding(mock_litellm_embedding) + def test_knowledgebase_name_lowercase(self, mock_openai, mock_embedding): + set_embedding(mock_embedding) self.run_sql(""" create database my_kb_storage diff --git a/tests/unit/handlers/test_bigquery.py b/tests/unit/handlers/test_bigquery.py index 37eb80cb75e..448af57d609 100644 --- a/tests/unit/handlers/test_bigquery.py +++ b/tests/unit/handlers/test_bigquery.py @@ -6,7 +6,6 @@ from google.api_core.exceptions import BadRequest from mindsdb.integrations.libs.response import ( - DataHandlerResponse, HandlerStatusResponse as StatusResponse, RESPONSE_TYPE, TableResponse, From 9dd046c5e8e4db3446f877c54b6ad56b91e1ab3a Mon Sep 17 00:00:00 2001 From: Andrey Date: Wed, 25 Mar 2026 17:04:10 +0300 Subject: [PATCH 086/125] Check agent parameters (#12291) Co-authored-by: martyna-mindsdb <109554435+martyna-mindsdb@users.noreply.github.com> --- docs/mindsdb_sql/agents/agent_syntax.mdx | 9 +- docs/sdks/python/agents.mdx | 6 + mindsdb/api/executor/command_executor.py | 6 +- .../datahub/datanodes/mindsdb_tables.py | 21 +- mindsdb/api/http/namespaces/agents.py | 35 +-- .../interfaces/agents/agents_controller.py | 217 ++++++++---------- mindsdb/interfaces/agents/modes/base.py | 4 + .../interfaces/agents/pydantic_ai_agent.py | 8 +- .../agents/utils/pydantic_ai_model_factory.py | 2 +- .../interfaces/agents/utils/sql_toolkit.py | 8 +- .../interfaces/knowledge_base/controller.py | 27 +-- mindsdb/utilities/utils.py | 23 ++ tests/unit/api/http/agents_test.py | 54 +++-- tests/unit/executor/test_agent.py | 212 +++++++---------- tests/unit/executor/test_lowercase.py | 15 +- tests/unit/executor/test_schema.py | 9 +- .../interfaces/agents/test_generic_api_key.py | 97 +------- .../agents/test_api_key_handling.py | 73 +++--- 18 files changed, 350 insertions(+), 476 deletions(-) diff --git a/docs/mindsdb_sql/agents/agent_syntax.mdx b/docs/mindsdb_sql/agents/agent_syntax.mdx index 9c1d3c01563..42597689170 100644 --- a/docs/mindsdb_sql/agents/agent_syntax.mdx +++ b/docs/mindsdb_sql/agents/agent_syntax.mdx @@ -24,7 +24,8 @@ USING "tables": ["datasource_conn_name.table_name", ...] }, prompt_template='describe data', - timeout=10; + timeout=10, + mode='text'; ``` It creates an agent that uses the defined model and has access to the connected data. @@ -315,6 +316,12 @@ This parameter defines the time the agent can take to come back with an answer. For example, when the `timeout` parameter is set to 10, the agent has 10 seconds to return an answer. If the agent takes longer than 10 seconds, it aborts the process and comes back with an answer indicating its failure to return an answer within the defined time interval. +### `mode` + +This parameter defines the agent's response style, allowing users to partially control the output format. Supported values include `text` and `sql`. + +When set, the agent will tailor its responses to match the specified format. Note that the agent may still adapt its output when necessary to ensure clarity or correctness. + ## `SELECT FROM AGENT` Syntax Query an agent to generate responses to questions. diff --git a/docs/sdks/python/agents.mdx b/docs/sdks/python/agents.mdx index f5cf58e11ce..b6c170b90fc 100644 --- a/docs/sdks/python/agents.mdx +++ b/docs/sdks/python/agents.mdx @@ -307,6 +307,12 @@ This parameter defines the time the agent can take to come back with an answer. For example, when the `timeout` parameter is set to 10, the agent has 10 seconds to return an answer. If the agent takes longer than 10 seconds, it aborts the process and comes back with an answer indicating its failure to return an answer within the defined time interval. +### `mode` + +This parameter defines the agent's response style, allowing users to partially control the output format. Supported values include `text` and `sql`. + +When set, the agent will tailor its responses to match the specified format. Note that the agent may still adapt its output when necessary to ensure clarity or correctness. + ## Get Agents You can get an existing agent with the `get()` method. diff --git a/mindsdb/api/executor/command_executor.py b/mindsdb/api/executor/command_executor.py index deacf21c0cf..25d8858f458 100644 --- a/mindsdb/api/executor/command_executor.py +++ b/mindsdb/api/executor/command_executor.py @@ -1484,13 +1484,11 @@ def answer_drop_kb(self, statement: DropKnowledgeBase, database_name: str) -> Ex def answer_create_agent(self, statement, database_name): project_name, name = match_two_part_name(statement.name, default_db_name=database_name) - provider = statement.params.pop("provider", None) try: _ = self.session.agents_controller.add_agent( name=name, project_name=project_name, - model_name=statement.model, - provider=provider, + model=statement.model, params=variables_controller.fill_parameters(statement.params), ) except EntityExistsError as e: @@ -1521,7 +1519,7 @@ def answer_update_agent(self, statement: UpdateAgent, database_name: str): _ = self.session.agents_controller.update_agent( name, project_name=project_name, - model_name=model, + model=model, params=variables_controller.fill_parameters(statement.params), ) except (EntityExistsError, EntityNotExistsError, ValueError) as e: diff --git a/mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py b/mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py index 8c274873465..b7fd38e3b3a 100644 --- a/mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py +++ b/mindsdb/api/executor/datahub/datanodes/mindsdb_tables.py @@ -381,7 +381,7 @@ def get_data(cls, query: ASTNode = None, inf_schema=None, **kwargs): class AgentsTable(MdbTable): name = "AGENTS" - columns = ["NAME", "PROJECT", "MODEL_NAME", "PARAMS"] + columns = ["NAME", "PROJECT", "MODEL", "PARAMS"] @classmethod def get_data(cls, query: ASTNode = None, inf_schema=None, **kwargs): @@ -394,15 +394,18 @@ def get_data(cls, query: ASTNode = None, inf_schema=None, **kwargs): project_names = {i.id: i.name for i in project_controller.get_list()} # NAME, PROJECT, MODEL, PARAMS (skills removed) - data = [ - ( - a.name, - project_names[a.project_id], - a.model_name, - to_json(a.params), + data = [] + for a in all_agents: + params = a.params or {} + model = params.pop("model", {}) + data.append( + [ + a.name, + project_names[a.project_id], + to_json(model), + to_json(params), + ] ) - for a in all_agents - ] return pd.DataFrame(data, columns=cls.columns) diff --git a/mindsdb/api/http/namespaces/agents.py b/mindsdb/api/http/namespaces/agents.py index a57c4c09dc2..3a7d6612499 100644 --- a/mindsdb/api/http/namespaces/agents.py +++ b/mindsdb/api/http/namespaces/agents.py @@ -28,14 +28,16 @@ def create_agent(project_name, name, agent): if name is None: return http_error(HTTPStatus.BAD_REQUEST, "Missing field", 'Missing "name" field for agent') - model_name = agent.get("model_name") - provider = agent.get("provider") - params = agent.get("params", {}) + if agent.get("model"): + model = agent["model"] + elif "model_name" in agent: + model = {"model_name": agent.get("model_name"), "provider": agent.get("provider")} + else: + model = None + if agent.get("data"): params["data"] = agent["data"] - if agent.get("model"): - params["model"] = agent["model"] if agent.get("prompt_template"): params["prompt_template"] = agent["prompt_template"] @@ -54,23 +56,21 @@ def create_agent(project_name, name, agent): ) try: - created_agent = agents_controller.add_agent( - name=name, project_name=project_name, model_name=model_name, provider=provider, params=params - ) + created_agent = agents_controller.add_agent(name=name, project_name=project_name, model=model, params=params) return created_agent.as_dict(), HTTPStatus.CREATED except (ValueError, EntityExistsError): # Model doesn't exist. return http_error( HTTPStatus.NOT_FOUND, "Resource not found", - f'The model "{model_name}" does not exist. Please ensure that the name is correct and try again.', + f'The model "{model}" does not exist. Please ensure that the name is correct and try again.', ) except NotImplementedError: # Free users trying to create agent. return http_error( HTTPStatus.UNAUTHORIZED, "Unavailable to free users", - f'The model "{model_name}" does not exist. Please ensure that the name is correct and try again.', + f'The model "{model}" does not exist. Please ensure that the name is correct and try again.', ) @@ -174,13 +174,17 @@ def put(self, project_name, agent_name): # Update try: - model_name = agent.get("model_name", None) - provider = agent.get("provider") params = agent.get("params", {}) + + if agent.get("model"): + model = agent["model"] + elif "model_name" in agent: + model = {"model_name": agent.get("model_name"), "provider": agent.get("provider")} + else: + model = None + if agent.get("data"): params["data"] = agent["data"] - if agent.get("model"): - params["model"] = agent["model"] if agent.get("prompt_template"): params["prompt_template"] = agent["prompt_template"] @@ -188,8 +192,7 @@ def put(self, project_name, agent_name): agent_name, project_name=project_name, name=name, - model_name=model_name, - provider=provider, + model=model, params=params, ) diff --git a/mindsdb/interfaces/agents/agents_controller.py b/mindsdb/interfaces/agents/agents_controller.py index 90e809ab568..504c2891af6 100644 --- a/mindsdb/interfaces/agents/agents_controller.py +++ b/mindsdb/interfaces/agents/agents_controller.py @@ -1,7 +1,9 @@ import datetime -from typing import Dict, Iterator, List, Union, Tuple, Optional, Any +from typing import Dict, Iterator, List, Union, Tuple, Optional, Any, Text import copy +from enum import Enum +from pydantic import BaseModel from sqlalchemy.orm.attributes import flag_modified from sqlalchemy import null import pandas as pd @@ -13,19 +15,58 @@ from mindsdb.interfaces.model.functions import PredictorRecordNotFound from mindsdb.interfaces.model.model_controller import ModelController from mindsdb.utilities.config import config +from mindsdb.utilities.utils import validate_pydantic_params from mindsdb.utilities import log +from mindsdb.interfaces.agents.utils.sql_toolkit import MindsDBQuery from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError from .utils.constants import ASSISTANT_COLUMN, SUPPORTED_PROVIDERS, PROVIDER_TO_MODELS from .utils.pydantic_ai_model_factory import get_llm_provider - +from .pydantic_ai_agent import check_agent_llm logger = log.getLogger(__name__) default_project = config.get("default_project") +def check_agent_data(data): + tables = data.get("tables", []) + knowledge_bases = data.get("knowledge_bases", []) + if tables or knowledge_bases: + sql_toolkit = MindsDBQuery(tables=tables, knowledge_bases=knowledge_bases) + + if tables and len(sql_toolkit.get_usable_table_names(lazy=False)) == 0: + raise ValueError(f"No tables found: {tables}") + + if knowledge_bases and len(sql_toolkit.get_usable_knowledge_base_names(lazy=False)) == 0: + raise ValueError(f"No knowledge bases found: {knowledge_bases}") + + +class AgentParamsData(BaseModel): + knowledge_bases: List[str] | None = None + tables: List[str] | None = None + + class Config: + extra = "forbid" + + +class AgentMode(Enum): + TEXT = "text" + SQL = "sql" + + +class AgentParams(BaseModel): + prompt_template: str | None = None + model: Dict[Text, Any] | None = None + data: AgentParamsData | None = None + timeout: int | None = None + mode: AgentMode = AgentMode.TEXT + + class Config: + extra = "forbid" + + class AgentsController: """Handles CRUD operations at the database level for Agents""" @@ -149,8 +190,7 @@ def add_agent( self, name: str, project_name: str = None, - model_name: Union[str, dict] = None, - provider: str = None, + model: dict = None, params: Dict[str, Any] = None, ) -> db.Agents: """ @@ -159,25 +199,16 @@ def add_agent( Parameters: name (str): The name of the new agent project_name (str): The containing project - model_name (str | dict): The name of the existing ML model the agent will use - provider (str): The provider of the model + model: Dict, parameters for the model to use + - provider: The provider of the model (e.g., 'openai', 'google') + - Other model-specific parameters like 'api_key', 'model_name', etc. + params (Dict[str, str]): Parameters to use when running the agent data: Dict, data sources for an agent, keys: - knowledge_bases: List of KBs to use - tables: list of tables to use - model: Dict, parameters for the model to use - - provider: The provider of the model (e.g., 'openai', 'google') - - Other model-specific parameters like 'api_key', 'model_name', etc. _api_key: API key for the provider (e.g., openai_api_key) - # Deprecated parameters: - database: The database to use (default is 'mindsdb') - knowledge_base_database: The database to use for knowledge base queries (default is 'mindsdb') - include_tables: List of tables to include - ignore_tables: List of tables to ignore - include_knowledge_bases: List of knowledge bases to include - ignore_knowledge_bases: List of knowledge bases to ignore - Returns: agent (db.Agents): The created agent @@ -195,61 +226,19 @@ def add_agent( # No need to copy params since we're not preserving the original reference params = params or {} + params["model"] = model - if isinstance(model_name, dict): - # move into params - params["model"] = model_name - model_name = None + # check agent params + validate_pydantic_params(params, AgentParams, "agent") - if model_name is not None: - _, provider = self.check_model_provider(model_name, provider) + # check llm works + llm_params = self.get_agent_llm_params(model) + check_agent_llm(llm_params) - if model_name is None: - logger.warning("'model_name' param is not provided. Using default global llm model at runtime.") - - # If model_name is not provided, we use default global llm model at runtime - # Default parameters will be applied at runtime via get_agent_llm_params - # This allows global default updates to apply to all agents immediately - - # Extract API key if provided in the format _api_key - if provider is not None: - provider_api_key_param = f"{provider.lower()}_api_key" - if provider_api_key_param in params: - # Keep the API key in params for the agent to use - # It will be picked up by get_api_key() in handler_utils.py - pass - - # Handle generic api_key parameter if provided - if "api_key" in params: - # Keep the generic API key in params for the agent to use - # It will be picked up by get_api_key() in handler_utils.py - pass - - depreciated_params = [ - "database", - "knowledge_base_database", - "include_tables", - "ignore_tables", - "include_knowledge_bases", - "ignore_knowledge_bases", - ] - if any(param in params for param in depreciated_params): - raise ValueError( - f"Parameters {', '.join(depreciated_params)} are deprecated. " - "Use 'data' parameter with 'tables' and 'knowledge_bases' keys instead." - ) - - include_tables = None - include_knowledge_bases = None - if "data" in params: - include_knowledge_bases = params["data"].get("knowledge_bases") - include_tables = params["data"].get("tables") - - # Convert string parameters to lists if needed - if isinstance(include_tables, str): - include_tables = [t.strip() for t in include_tables.split(",")] - if isinstance(include_knowledge_bases, str): - include_knowledge_bases = [kb.strip() for kb in include_knowledge_bases.split(",")] + # check data + data = params.get("data", {}) + if data: + check_agent_data(data) agent = db.Agents( name=name, @@ -257,8 +246,6 @@ def add_agent( company_id=ctx.company_id, user_id=ctx.user_id, user_class=ctx.user_class, - model_name=model_name, - provider=provider, params=params, ) @@ -272,9 +259,8 @@ def update_agent( agent_name: str, project_name: str = default_project, name: str = None, - model_name: Union[str, dict] = None, - provider: str = None, - params: Dict[str, str] = None, + model: dict = None, + params: Dict[str, Any] = None, ): """ Updates an agent in the database. @@ -283,8 +269,7 @@ def update_agent( agent_name (str): The name of the new agent, or existing agent to update project_name (str): The containing project name (str): The updated name of the agent - model_name (str | dict): The name of the existing ML model the agent will use - provider (str): The provider of the model + model dict: model parameters params: (Dict[str, str]): Parameters to use when running the agent Returns: @@ -301,12 +286,7 @@ def update_agent( existing_params = existing_agent.params or {} is_demo = (existing_agent.params or {}).get("is_demo", False) - if is_demo and ( - (name is not None and name != agent_name) - or (model_name is not None and existing_agent.model_name != model_name) - or (provider is not None and existing_agent.provider != provider) - or (isinstance(params, dict) and len(params) > 0 and "prompt_template" not in params) - ): + if is_demo: raise ValueError("It is forbidden to change properties of the demo object") if name is not None and name != agent_name: @@ -316,27 +296,34 @@ def update_agent( raise EntityExistsError(f"Agent with updated name already exists: {name}") existing_agent.name = name - if model_name or provider: - if isinstance(model_name, dict): - # move into params - existing_params["model"] = model_name - model_name = None - - # check model and provider - _, provider = self.check_model_provider(model_name, provider) - # Update model and provider - existing_agent.model_name = model_name - existing_agent.provider = provider - - if params is not None: - # Merge params on update - existing_params.update(params) - # Remove None values entirely. - params = {k: v for k, v in existing_params.items() if v is not None} - existing_agent.params = params - # Some versions of SQL Alchemy won't handle JSON updates correctly without this. - # See: https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.attributes.flag_modified - flag_modified(existing_agent, "params") + params = params or {} + + if model: + params["model"] = model + + if params: + validate_pydantic_params(params, AgentParams, "agent") + else: + # do nothing + return existing_agent + + if model: + # check llm works + llm_params = self.get_agent_llm_params(model) + check_agent_llm(llm_params) + + data = params.get("data", {}) + if data: + check_agent_data(data) + + # Merge params on update + existing_params.update(params) + # Remove None values entirely. + params = {k: v for k, v in existing_params.items() if v is not None} + existing_agent.params = params + # Some versions of SQL Alchemy won't handle JSON updates correctly without this. + # See: https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.attributes.flag_modified + flag_modified(existing_agent, "params") db.session.commit() return existing_agent @@ -362,32 +349,12 @@ def delete_agent(self, agent_name: str, project_name: str = default_project): agent.deleted_at = datetime.datetime.now() db.session.commit() - def get_agent_llm_params(self, agent): + def get_agent_llm_params(self, model_params): """ Get agent LLM parameters by combining default config with user provided parameters. Uses the same pattern as knowledge bases get_model_params function. """ - agent_params = agent.params - - # Get model params from agent params (same structure as knowledge bases) - if "model" in agent_params: - model_params = agent_params.get("model", {}) - if not isinstance(model_params, dict): - raise ValueError("Model parameters must be passed as a JSON object") - else: - # params for LLM can be arbitrary (backward compatibility) - model_params = copy.deepcopy(agent_params) - model_params.pop("mode", None) - model_params.pop("prompt_template", None) - - _, provider = self.check_model_provider(agent.model_name, agent.provider) - - if agent.model_name is not None: - model_params["model_name"] = agent.model_name - if provider is not None: - model_params["provider"] = provider - combined_model_params = copy.deepcopy(config.get("default_llm", {})) if model_params: @@ -433,7 +400,7 @@ def get_completion( from .pydantic_ai_agent import PydanticAIAgent # Get agent parameters and combine with default LLM parameters at runtime - llm_params = self.get_agent_llm_params(agent) + llm_params = self.get_agent_llm_params(agent.params.get("model")) pydantic_agent = PydanticAIAgent(agent, llm_params=llm_params) diff --git a/mindsdb/interfaces/agents/modes/base.py b/mindsdb/interfaces/agents/modes/base.py index 97376b2a2af..1ec13d9242c 100644 --- a/mindsdb/interfaces/agents/modes/base.py +++ b/mindsdb/interfaces/agents/modes/base.py @@ -8,6 +8,10 @@ class PlanResponse(BaseModel): estimated_steps: int = Field(..., description="Estimated number of steps needed to solve the question") +class TestResponse(BaseModel): + text: str = Field(..., description="Text response to the user") + + class ResponseType: FINAL_QUERY = "final_query" # this is the final query EXPLORATORY = "exploratory_query" # this is a query to explore and collect info to solve the challenge (e.g., distinct values of a categorical column, schema inference, etc.) diff --git a/mindsdb/interfaces/agents/pydantic_ai_agent.py b/mindsdb/interfaces/agents/pydantic_ai_agent.py index 2ec7a87a189..7a38c57b165 100644 --- a/mindsdb/interfaces/agents/pydantic_ai_agent.py +++ b/mindsdb/interfaces/agents/pydantic_ai_agent.py @@ -25,7 +25,7 @@ from mindsdb.utilities.context import context as ctx from mindsdb.utilities.langfuse import LangfuseClientWrapper from mindsdb.interfaces.agents.modes import sql as sql_mode, text_sql as text_sql_mode -from mindsdb.interfaces.agents.modes.base import ResponseType, PlanResponse +from mindsdb.interfaces.agents.modes.base import ResponseType, PlanResponse, TestResponse logger = log.getLogger(__name__) DEBUG_LOGGER = logger.debug @@ -65,6 +65,12 @@ def wrapper(self, messages, *args, **kwargs): return decorator +def check_agent_llm(llm_params): + model = get_model_instance_from_kwargs(llm_params) + agent = Agent(model, output_type=TestResponse) + agent.run_sync("Say 'hi'") + + class PydanticAIAgent: """Pydantic AI-based agent to replace LangchainAgent""" diff --git a/mindsdb/interfaces/agents/utils/pydantic_ai_model_factory.py b/mindsdb/interfaces/agents/utils/pydantic_ai_model_factory.py index 8189542312e..aa72d5e2f99 100644 --- a/mindsdb/interfaces/agents/utils/pydantic_ai_model_factory.py +++ b/mindsdb/interfaces/agents/utils/pydantic_ai_model_factory.py @@ -52,7 +52,7 @@ def get_llm_provider(args: Dict) -> str: return "writer" # For vLLM, require explicit provider specification - raise ValueError("Invalid model name. Please define a supported llm provider") + raise ValueError(f"Invalid model name: {model_name}. Please define a supported llm provider") def get_embedding_model_provider(args: Dict) -> str: diff --git a/mindsdb/interfaces/agents/utils/sql_toolkit.py b/mindsdb/interfaces/agents/utils/sql_toolkit.py index 1468bc2fb75..502a08f5be8 100644 --- a/mindsdb/interfaces/agents/utils/sql_toolkit.py +++ b/mindsdb/interfaces/agents/utils/sql_toolkit.py @@ -281,11 +281,11 @@ def _check_f(node, is_table=None, **kwargs): query_traversal(ast_query, _check_f) - def get_usable_table_names(self): + def get_usable_table_names(self, lazy=True): if not self.tables: # no tables allowed return [] - if not self.tables.has_wildcard: + if not self.tables.has_wildcard and lazy: return self.tables.items result_tables = [] @@ -330,11 +330,11 @@ def get_usable_table_names(self): return result_tables - def get_usable_knowledge_base_names(self): + def get_usable_knowledge_base_names(self, lazy=True): if not self.knowledge_bases: # no tables allowed return [] - if not self.knowledge_bases.has_wildcard: + if not self.knowledge_bases.has_wildcard and lazy: return self.knowledge_bases.items try: diff --git a/mindsdb/interfaces/knowledge_base/controller.py b/mindsdb/interfaces/knowledge_base/controller.py index a85571edbd3..aa32a46419e 100644 --- a/mindsdb/interfaces/knowledge_base/controller.py +++ b/mindsdb/interfaces/knowledge_base/controller.py @@ -6,7 +6,7 @@ import pandas as pd import numpy as np -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from sqlalchemy.orm.attributes import flag_modified from mindsdb_sql_parser.ast import BinaryOperation, Constant, Identifier, Select, Update, Delete, Star @@ -37,6 +37,7 @@ from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs from mindsdb.utilities.config import config from mindsdb.utilities.context import context as ctx +from mindsdb.utilities.utils import validate_pydantic_params from mindsdb.interfaces.agents.utils.pydantic_ai_model_factory import get_llm_provider from mindsdb.interfaces.knowledge_base.llm_wrapper import create_chat_model @@ -1164,26 +1165,6 @@ class KnowledgeBaseController: def __init__(self, session) -> None: self.session = session - def _check_kb_input_params(self, params): - # check names and types KB params - try: - KnowledgeBaseInputParams.model_validate(params) - except ValidationError as e: - problems = [] - for error in e.errors(): - parameter = ".".join([str(i) for i in error["loc"]]) - param_type = error["type"] - if param_type == "extra_forbidden": - msg = f"Parameter '{parameter}' is not allowed" - else: - msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}" - problems.append(msg) - - msg = "\n".join(problems) - if len(problems) > 1: - msg = "\n" + msg - raise ValueError(f"Problem with knowledge base parameters: {msg}") from e - def add( self, name: str, @@ -1208,7 +1189,7 @@ def add( params = params or {} params["preprocessing"] = preprocessing_config - self._check_kb_input_params(params) + validate_pydantic_params(params, KnowledgeBaseInputParams, "knowledge base") # Check if vector_size is provided when using sparse vectors is_sparse = params.get("is_sparse") @@ -1371,7 +1352,7 @@ def update( params = params or {} params["preprocessing"] = preprocessing_config - self._check_kb_input_params(params) + validate_pydantic_params(params, KnowledgeBaseInputParams, "knowledge base") # get project id project = self.session.database_controller.get_project(project_name) diff --git a/mindsdb/utilities/utils.py b/mindsdb/utilities/utils.py index 3c9bd09162c..160b03fe79c 100644 --- a/mindsdb/utilities/utils.py +++ b/mindsdb/utilities/utils.py @@ -2,6 +2,8 @@ import re import typing +from pydantic import BaseModel, ValidationError + def parse_csv_attributes(csv_attributes: typing.Optional[str] = "") -> typing.Dict[str, str]: """ @@ -32,3 +34,24 @@ def parse_csv_attributes(csv_attributes: typing.Optional[str] = "") -> typing.Di raise ValueError(f"Failed to parse csv_attributes='{csv_attributes}': {e}") from e return attributes + + +def validate_pydantic_params(params: dict, schema: type[BaseModel], subject: str): + # check names and types + try: + schema.model_validate(params) + except ValidationError as e: + problems = [] + for error in e.errors(): + parameter = ".".join([str(i) for i in error["loc"]]) + param_type = error["type"] + if param_type == "extra_forbidden": + msg = f"Parameter '{parameter}' is not allowed" + else: + msg = f"Error in '{parameter}' (type: {param_type}): {error['msg']}. Input: {repr(error['input'])}" + problems.append(msg) + + msg = "\n".join(problems) + if len(problems) > 1: + msg = "\n" + msg + raise ValueError(f"Problem with {subject} parameters: {msg}") from e diff --git a/tests/unit/api/http/agents_test.py b/tests/unit/api/http/agents_test.py index a6253132384..bd2532bcd78 100644 --- a/tests/unit/api/http/agents_test.py +++ b/tests/unit/api/http/agents_test.py @@ -27,14 +27,13 @@ def test_prepare(client): @pytest.mark.deprecated( "MindsDB models are no longer used with agents. However, Minds still uses models, so this test is kept for now" ) -def test_post_agent_depreciated(client): +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") +def test_post_agent_depreciated(check_agent_llm, client): create_request = { "agent": { "name": "test_post_agent_depreciated", - "model_name": "test_model", - "params": {"k1": "v1"}, - "provider": "mindsdb", - "skills": ["test_skill"], + "model": {"provider": "openai", "model_name": "test_model"}, + "params": {"timeout": 10}, } } @@ -45,9 +44,8 @@ def test_post_agent_depreciated(client): expected_agent = { "name": "test_post_agent_depreciated", - "model_name": "test_model", - "provider": "mindsdb", - "params": {"k1": "v1"}, + "model": {"provider": "openai", "model_name": "test_model"}, + "params": {"timeout": 10}, "id": created_agent["id"], "project_id": created_agent["project_id"], "created_at": created_agent["created_at"], @@ -57,7 +55,9 @@ def test_post_agent_depreciated(client): assert created_agent == expected_agent -def test_post_agent(client): +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_data") +def test_post_agent(check_agent_data, check_agent_llm, client): create_request = { "agent": { "name": "TEST_post_agent", @@ -161,7 +161,9 @@ def test_get_agents_project_not_found(client): assert get_response.status_code == HTTPStatus.NOT_FOUND -def test_get_agent(client): +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_data") +def test_get_agent(check_agent_data, check_agent_llm, client): create_request = { "agent": { "name": "test_get_agent", @@ -236,13 +238,13 @@ def test_get_agent_project_not_found(client): @pytest.mark.deprecated( "MindsDB models are no longer used with agents. However, Minds still uses models, so this test is kept for now" ) -def test_put_agent_update_depreciated(client): +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") +def test_put_agent_update_depreciated(check_agent_llm, client): create_request = { "agent": { "name": "test_put_agent_update_depreciated", - "model_name": "test_model", - "params": {"k1": "v1", "k2": "v2"}, - "provider": "mindsdb", + "model": {"provider": "openai", "model_name": "test_model"}, + "params": {"timeout": 10}, } } @@ -251,7 +253,7 @@ def test_put_agent_update_depreciated(client): update_request = { "agent": { - "params": {"k1": "v1.1", "k2": None, "k3": "v3"}, + "params": {"timeout": 20}, } } @@ -262,9 +264,8 @@ def test_put_agent_update_depreciated(client): expected_agent = { "name": "test_put_agent_update_depreciated", - "model_name": "test_model", - "params": {"k1": "v1.1", "k3": "v3"}, - "provider": "mindsdb", + "model": {"provider": "openai", "model_name": "test_model"}, + "params": {"timeout": 20}, "id": updated_agent["id"], "project_id": updated_agent["project_id"], "created_at": updated_agent["created_at"], @@ -277,7 +278,9 @@ def test_put_agent_update_depreciated(client): @pytest.mark.deprecated( "MindsDB models are no longer used with agents. However, Minds still uses models, so this test is kept for now" ) -def test_put_agent_update(client): +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_data") +def test_put_agent_update(check_agent_data, check_agent_llm, client): create_request = { "agent": { "name": "test_put_agent_update", @@ -292,7 +295,7 @@ def test_put_agent_update(client): update_request = { "agent": { - "params": {"k1": "v1.1", "k2": None, "k3": "v3"}, + "params": {"timeout": 5}, "data": { "tables": ["example_db.customers", "example_db.orders"], "knowledge_bases": ["example_kb"], @@ -307,7 +310,7 @@ def test_put_agent_update(client): expected_agent = { "name": "test_put_agent_update", - "params": {"k1": "v1.1", "k3": "v3"}, + "params": {"timeout": 5}, "id": updated_agent["id"], "project_id": updated_agent["project_id"], "created_at": updated_agent["created_at"], @@ -356,7 +359,9 @@ def test_put_agent_no_agent(client): # assert '404' in response.status -def test_delete_agent(client): +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_data") +def test_delete_agent(check_agent_data, check_agent_llm, client): create_request = { "agent": { "name": "test_delete_agent", @@ -385,13 +390,14 @@ def test_delete_agent_not_found(client): assert delete_response.status_code == HTTPStatus.NOT_FOUND -def test_agent_completions(client): +@patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") +def test_agent_completions(check_agent_llm, client): create_request = { "agent": { "name": "test_agent", "model_name": "test_model", "provider": "mindsdb", - "params": {"prompt_template": "Test message!", "user_column": "content"}, + "params": {"prompt_template": "Test message!"}, } } diff --git a/tests/unit/executor/test_agent.py b/tests/unit/executor/test_agent.py index c1551b1d3fe..a41d36f0a6c 100644 --- a/tests/unit/executor/test_agent.py +++ b/tests/unit/executor/test_agent.py @@ -1,9 +1,9 @@ -import time import os import json from unittest.mock import patch, AsyncMock +from sqlalchemy.orm.attributes import flag_modified import pandas as pd import pytest import sys @@ -18,16 +18,19 @@ def action_response(type="final_query", sql="", text=""): return json.dumps({"sql_query": sql, "type": type, "text": text, "short_description": "a tool"}) -def set_openai_completion(mock_openai, llm_response): +def set_openai_completion(mock_openai, llm_response, add_planning=True): if isinstance(llm_response, str): llm_responses = [ action_response(sql=f"select '{llm_response}' as answer"), ] + elif not isinstance(llm_response, list): + llm_responses = [llm_response] else: llm_responses = llm_response - # always add plan response - llm_responses.insert(0, '{"plan":"my plan is ...", "estimated_steps":3}') + if add_planning: + # add plan response + llm_responses.insert(0, '{"plan":"my plan is ...", "estimated_steps":3}') mock_openai.agent_calls = [] calls = [] @@ -104,84 +107,10 @@ def setup_method(self): config["knowledge_bases"]["disable_autobatch"] = True - @pytest.mark.slow - def unused_test_mindsdb_provider(self): - # pydantic agent doesn't support using mindsdb model - from mindsdb.api.executor.exceptions import ExecutorException - - agent_response = "how can I help you" - # model - self.run_sql( - f""" - CREATE model base_model - PREDICT output - using - column='question', - output='{agent_response}', - engine='dummy_ml', - join_learn_process=true - """ - ) - - self.run_sql("CREATE ML_ENGINE langchain FROM langchain") - - agent_params = """ - USING - provider='mindsdb', - model = "base_model", -- < - prompt_template="Answer the user input in a helpful way" - """ - self.run_sql(f""" - CREATE AGENT my_agent {agent_params} - """) - with pytest.raises(ExecutorException): - self.run_sql(f""" - CREATE AGENT my_agent {agent_params} - """) - self.run_sql(f""" - CREATE AGENT IF NOT EXISTS my_agent {agent_params} - """) - - ret = self.run_sql("select * from my_agent where question = 'hi'") - - assert agent_response in ret.answer[0] - - @pytest.mark.skipif( - sys.platform in ["darwin", "win32"], reason="Mocking doesn't work on Windows or macOS for some reason" - ) - @patch("openai.OpenAI") - def unused_test_openai_provider_with_model(self, mock_openai): - # pydantic agent doesn't support using mindsdb model - - agent_response = "how can I assist you today?" - set_openai_completion(mock_openai, agent_response) - - self.run_sql("CREATE ML_ENGINE langchain FROM langchain") - - self.run_sql(""" - CREATE MODEL lang_model - PREDICT answer USING - engine = "langchain", - model = "gpt-3.5-turbo", - openai_api_key='--', - prompt_template="Answer the user input in a helpful way"; - """) - - time.sleep(5) - - self.run_sql(""" - CREATE AGENT my_agent - USING - model='lang_model' - """) - ret = self.run_sql("select * from my_agent where question = 'hi'") - - assert agent_response in ret.answer[0] - @patch("pydantic_ai.providers.openai.AsyncOpenAI") def test_openai_provider(self, mock_openai): - agent_response = "how can I assist you today?" - set_openai_completion(mock_openai, agent_response) + # test response + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT my_agent @@ -193,6 +122,10 @@ def test_openai_provider(self, mock_openai): }, prompt_template="Answer the user input in a helpful way" """) + + agent_response = "how can I assist you today?" + set_openai_completion(mock_openai, agent_response) + ret = self.run_sql("select * from my_agent where question = 'hi'") # check model params @@ -252,10 +185,8 @@ def config_get_side_effect(key, default=None): mock_config_get.side_effect = config_get_side_effect - agent_response = "how can I assist you today?" - set_openai_completion(mock_openai, agent_response) - # Create an agent with only provider specified - should use default LLM params + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT default_params_agent USING @@ -266,6 +197,8 @@ def config_get_side_effect(key, default=None): }, prompt_template="Answer the user input in a helpful way" """) + agent_response = "how can I assist you today?" + set_openai_completion(mock_openai, agent_response) # Check that the agent was created with the default parameters agent_info = self.run_sql("SELECT * FROM information_schema.agents WHERE name = 'default_params_agent'") @@ -273,7 +206,7 @@ def config_get_side_effect(key, default=None): # Verify the agent has the user-specified parameters but not default parameters agent_params = json.loads(agent_info["PARAMS"].iloc[0]) assert agent_params.get("prompt_template") == "Answer the user input in a helpful way" - assert agent_params["model"]["model_name"] == "gpt-3" + assert "gpt-3" in agent_info["MODEL"][0] # Default parameters should NOT be stored in the database # They will be applied at runtime via get_agent_llm_params @@ -291,19 +224,18 @@ def config_get_side_effect(key, default=None): # --- Test that agent creation works with minimal syntax using default_llm config --- - mock_openai.reset_mock() - agent_response = "how can I assist you today?" - set_openai_completion(mock_openai, agent_response) - # Create an agent with minimal syntax - should use all default LLM params + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT minimal_syntax_agent USING - data = { - "tables": ['test.table1', 'test.table2'] - } + data = { } """) + mock_openai.reset_mock() + agent_response = "how can I assist you today?" + set_openai_completion(mock_openai, agent_response) + ret = self.run_sql("select * from minimal_syntax_agent where question = 'hi'") assert agent_response in ret.answer[0] @@ -314,18 +246,21 @@ def config_get_side_effect(key, default=None): @pytest.mark.skipif(sys.platform == "darwin", reason="Fails on macOS") @patch("pydantic_ai.providers.openai.AsyncOpenAI") def test_agent_stream(self, mock_openai): - agent_response = "how can I assist you today?" - set_openai_completion(mock_openai, agent_response) - + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT my_agent USING - provider='openai', - model = "gpt-3.5-turbo", - openai_api_key='--', + model={ + "model_name": "gpt-3.5-turbo", + "provider": "openai", + "api_key": "--" + }, prompt_template="Answer the user input in a helpful way" """) + agent_response = "how can I assist you today?" + set_openai_completion(mock_openai, agent_response) + agents_controller = self.command_executor.session.agents_controller agent = agents_controller.get_agent("my_agent") @@ -370,16 +305,18 @@ def test_agent_retrieval(self, mock_openai, mock_embedding): os.environ["OPENAI_API_KEY"] = "--" + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" create agent retrieve_agent using - model='gpt-3.5-turbo', - provider='openai', + model={ + "model_name": "gpt-3.5-turbo", + "provider": "openai" + }, prompt_template='Answer the user input in a helpful way using tools', data = { "knowledge_bases": ["kb_review"] - }, - mode='retrieval' + } """) agent_response = "the answer is yes" @@ -413,10 +350,12 @@ def test_agent_retrieval(self, mock_openai, mock_embedding): self._drop_kb_storage(vector_table_name) # should not be possible to drop demo agent - def test_drop_demo_agent(self): + @patch("pydantic_ai.providers.openai.AsyncOpenAI") + def test_drop_demo_agent(self, mock_openai): """should not be possible to drop demo agent""" from mindsdb.api.executor.exceptions import ExecutorException + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT my_demo_agent USING @@ -425,37 +364,49 @@ def test_drop_demo_agent(self): 'model_name': "gpt-3.5-turbo", 'api_key': '-key-' }, - prompt_template="--", - is_demo=true; + prompt_template="--" """) + + # mark as demo in db + agent = self.db.Agents.query.filter_by(name="my_demo_agent").first() + agent.params["is_demo"] = True + flag_modified(agent, "params") + self.db.session.commit() with pytest.raises(ExecutorException): - self.run_sql("drop agent my_agent") + self.run_sql("drop agent my_demo_agent") @patch("pydantic_ai.providers.openai.AsyncOpenAI") def test_agent_default_prompt_template(self, mock_openai): """Test that agents work correctly with default prompt templates in different modes""" - agent_response = "default prompt template response" - set_openai_completion(mock_openai, agent_response) # Test non-retrieval mode with no prompt_template (should use default) + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT default_prompt_agent USING - provider='openai', - model = "gpt-3.5-turbo", - openai_api_key='--' + model={ + "model_name": "gpt-3.5-turbo", + "provider": "openai", + "api_key": "--" + } """) + + agent_response = "default prompt template response" + set_openai_completion(mock_openai, agent_response) + ret = self.run_sql("select * from default_prompt_agent where question = 'test question'") assert agent_response in ret.answer[0] # Test retrieval mode with no prompt_template (should use default retrieval template) + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT default_retrieval_agent USING - provider='openai', - model = "gpt-3.5-turbo", - openai_api_key='--', - mode='retrieval' + model={ + "model_name": "gpt-3.5-turbo", + "provider": "openai", + "api_key": "--" + } """) mock_openai.reset_mock() @@ -491,11 +442,14 @@ def test_agent_permissions(self, mock_embedding, mock_openai): select id, planet_name content from files.show1 """) + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT my_agent USING - model = "gpt-3.5-turbo", - openai_api_key='--', + model={ + "model_name": "gpt-3.5-turbo", + "api_key": '--' + }, data = { "knowledge_bases": ["kb_show*"], "tables": ["files.show*"] @@ -601,6 +555,7 @@ def test_agent_new_syntax(self, mock_embedding, mock_openai): select id, planet_name content from files.file{i} where id != 1000 """) + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT my_agent USING @@ -653,6 +608,7 @@ def test_agent_new_syntax(self, mock_embedding, mock_openai): assert "important user instruction №42" in mock_openai.agent_calls[0] # --- ALTER AGENT --- + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" ALTER AGENT my_agent USING @@ -723,11 +679,14 @@ def test_agent_accept_wrong_quoting(self, mock_embedding, mock_openai): self.save_file("file1", df) + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT my_agent USING - model = "gpt-3.5-turbo", - openai_api_key='--', + model={ + "model_name": "gpt-3.5-turbo", + "api_key": '--' + }, data = { "knowledge_bases": ["kb1"], "tables": ["files.file1", "files.file2.*"] @@ -761,11 +720,14 @@ def test_3_part_table(self, mock_pg, mock_openai): df = get_dataset_planets() self.set_handler(mock_pg, name="pg", tables={"planets": df}, schema="public") + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql(""" CREATE AGENT my_agent USING - model = "gpt-3.5-turbo", - openai_api_key='--', + model={ + "model_name": "gpt-3.5-turbo", + "api_key": '--' + }, data = { "tables": ["pg.public.*"] } @@ -788,19 +750,23 @@ def test_3_part_table(self, mock_pg, mock_openai): assert "Moon" in mock_openai.agent_calls[3] assert "Moon" in mock_openai.agent_calls[4] + @patch("pydantic_ai.providers.openai.AsyncOpenAI") @patch("mindsdb.interfaces.agents.pydantic_ai_agent.PydanticAIAgent._get_completion_stream") - def test_agent_query_param_override(self, mock_get_completion): + def test_agent_query_param_override(self, mock_get_completion, mock_openai): """ Test that agent parameters can be overridden per-query using the USING clause in SELECT. """ mock_get_completion.return_value = [{"type": "data", "content": "-"}] + set_openai_completion(mock_openai, action_response(text="hi"), add_planning=False) self.run_sql( """ CREATE AGENT override_agent USING - model = 'gpt-4o', - openai_api_key = 'sk-override', + model={ + "model_name": "gpt-4o", + "api_key": 'sk-override' + }, prompt_template = 'Answer questions', timeout = 60; """ diff --git a/tests/unit/executor/test_lowercase.py b/tests/unit/executor/test_lowercase.py index f5fd71efe30..d7e9d2a32b0 100644 --- a/tests/unit/executor/test_lowercase.py +++ b/tests/unit/executor/test_lowercase.py @@ -166,13 +166,15 @@ def test_model_name_lowercase(self): self.run_sql(f"DROP MODEL `{another_name}`") self.run_sql(f"DROP MODEL {another_name}") - def test_agent_name_lowercase(self): + @patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") + def test_agent_name_lowercase(self, check_agent_llm): agent_params = """ - model='gpt-3.5-turbo', - provider='openai', + model={ + "model_name": "gpt-3.5-turbo", + "provider": "openai" + }, prompt_template='Answer the user input in a helpful way using tools', - max_iterations=5, - mode='retrieval' + mode='text' """ # mixed case: agent @@ -274,7 +276,8 @@ def test_job_name_lowercase(self): self.run_sql(f"DROP JOB {another_name}") - def test_chatbot_lowercase(self): + @patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") + def test_chatbot_lowercase(self, check_agent_llm): self.run_sql("create agent my_agent using model={'provider': 'openai', 'model_name': 'gpt-3.5'}") self.run_sql("create database my_db using engine='dummy_data'") diff --git a/tests/unit/executor/test_schema.py b/tests/unit/executor/test_schema.py index 8c80177c006..540467a9c78 100644 --- a/tests/unit/executor/test_schema.py +++ b/tests/unit/executor/test_schema.py @@ -12,7 +12,8 @@ def test_show(self): self.run_sql(f"show {item}") @pytest.mark.slow - def test_schema(self): + @patch("mindsdb.interfaces.agents.agents_controller.check_agent_llm") + def test_schema(self, check_agent): # --- create objects + describe --- # todo: create knowledge base (requires chromadb) @@ -91,15 +92,15 @@ def test_schema(self): # agent self.run_sql(""" CREATE AGENT agent1 - USING model = 'pred1' + USING model = {'model_name': "pred1", "provider": "openai"} """) self.run_sql(""" CREATE AGENT proj2.agent2 - USING model = 'pred2' -- it looks up in agent's project + USING model = {'model_name': "pred2", "provider": "openai"} -- it looks up in agent's project """) df = self.run_sql("describe agent agent1") - assert df.NAME[0] == "agent1" and df.MODEL_NAME[0] == "pred1" + assert df.NAME[0] == "agent1" and "pred1" in df.MODEL[0] # chatbot self.run_sql(""" diff --git a/tests/unit/interfaces/agents/test_generic_api_key.py b/tests/unit/interfaces/agents/test_generic_api_key.py index 8198b763a08..3473aa05c70 100644 --- a/tests/unit/interfaces/agents/test_generic_api_key.py +++ b/tests/unit/interfaces/agents/test_generic_api_key.py @@ -1,9 +1,8 @@ import os import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch from mindsdb.integrations.utilities.handler_utils import get_api_key -from mindsdb.interfaces.agents.agents_controller import AgentsController class TestGenericApiKeyHandling(unittest.TestCase): @@ -71,100 +70,6 @@ def test_get_generic_api_key_for_google_provider(self): ) self.assertEqual(api_key, "test-specific-google-api-key") - @patch("mindsdb.interfaces.agents.agents_controller.AgentsController.check_model_provider") - @patch("mindsdb.interfaces.agents.agents_controller.AgentsController.get_agent") - @patch("mindsdb.interfaces.agents.agents_controller.ProjectController") - @patch("mindsdb.interfaces.storage.db.session") - def test_add_agent_with_generic_api_key( - self, mock_session, mock_project_controller, mock_get_agent, mock_check_model_provider - ): - """Test adding an agent with a generic API key in params.""" - # Mock project controller - mock_project = MagicMock() - mock_project_controller.return_value.get.return_value = mock_project - - # Mock get_agent to return None (agent doesn't exist yet) - mock_get_agent.return_value = None - - # Mock check_model_provider to return a provider - mock_check_model_provider.return_value = (None, "openai") - - # Create an instance of AgentsController - agent_controller = AgentsController() - - # Test adding an agent with a generic API key in params - params = {"api_key": "test-generic-agent-api-key", "other_param": "value"} - - # Create a mock agent with proper params - mock_agent = MagicMock() - mock_agent.params = params.copy() # Set params directly - - # Mock db.Agents to return our prepared mock agent - with patch("mindsdb.interfaces.storage.db.Agents", return_value=mock_agent): - # Add the agent - agent = agent_controller.add_agent( - name="test_agent", - project_name="mindsdb", - model_name="gpt-4", - provider="openai", - params=params, - ) - - # Verify that the generic API key was preserved in the params - self.assertEqual(agent.params["api_key"], "test-generic-agent-api-key") - - @patch("mindsdb.interfaces.agents.agents_controller.AgentsController.check_model_provider") - @patch("mindsdb.interfaces.agents.agents_controller.AgentsController.get_agent") - @patch("mindsdb.interfaces.agents.agents_controller.ProjectController") - @patch("mindsdb.interfaces.storage.db.session") - def test_add_agent_with_both_api_keys( - self, mock_session, mock_project_controller, mock_get_agent, mock_check_model_provider - ): - """Test adding an agent with both generic and provider-specific API keys.""" - # Mock project controller - mock_project = MagicMock() - mock_project_controller.return_value.get.return_value = mock_project - - # Mock get_agent to return None (agent doesn't exist yet) - mock_get_agent.return_value = None - - # Mock check_model_provider to return a provider - mock_check_model_provider.return_value = (None, "openai") - - # Create an instance of AgentsController - agent_controller = AgentsController() - - # Test adding an agent with both generic and provider-specific API keys - params = { - "api_key": "test-generic-agent-api-key", - "openai_api_key": "test-specific-agent-api-key", - "other_param": "value", - } - - # Create a mock agent with proper params - mock_agent = MagicMock() - mock_agent.params = params.copy() # Set params directly - - # Mock db.Agents to return our prepared mock agent - with patch("mindsdb.interfaces.storage.db.Agents", return_value=mock_agent): - # Add the agent - agent = agent_controller.add_agent( - name="test_agent", - project_name="mindsdb", - model_name="gpt-4", - provider="openai", - params=params, - ) - - # Verify that both API keys were preserved in the params - self.assertEqual(agent.params["api_key"], "test-generic-agent-api-key") - self.assertEqual(agent.params["openai_api_key"], "test-specific-agent-api-key") - - # Test that get_api_key returns the provider-specific key when both are present - api_key = get_api_key("openai", {"params": params}) - - self.assertEqual(api_key, "test-specific-agent-api-key") - if __name__ == "__main__": unittest.main() diff --git a/tests/unused/unit/interfaces/agents/test_api_key_handling.py b/tests/unused/unit/interfaces/agents/test_api_key_handling.py index 484ba775577..8a45b931d6c 100644 --- a/tests/unused/unit/interfaces/agents/test_api_key_handling.py +++ b/tests/unused/unit/interfaces/agents/test_api_key_handling.py @@ -12,10 +12,9 @@ class TestAgentApiKeyHandling(unittest.TestCase): def setUp(self): """Set up test environment.""" # Mock environment variables - self.env_patcher = patch.dict(os.environ, { - 'OPENAI_API_KEY': 'test-env-api-key', - 'ANTHROPIC_API_KEY': 'test-env-anthropic-key' - }) + self.env_patcher = patch.dict( + os.environ, {"OPENAI_API_KEY": "test-env-api-key", "ANTHROPIC_API_KEY": "test-env-anthropic-key"} + ) self.env_patcher.start() def tearDown(self): @@ -25,43 +24,44 @@ def tearDown(self): def test_get_api_key_from_env(self): """Test retrieving API key from environment variables.""" # Test getting API key from environment variable - api_key = get_api_key('openai', {}) - self.assertEqual(api_key, 'test-env-api-key') + api_key = get_api_key("openai", {}) + self.assertEqual(api_key, "test-env-api-key") def test_get_api_key_from_args(self): """Test retrieving API key from create_args.""" # Test getting API key from create_args - api_key = get_api_key('openai', {'openai_api_key': 'test-args-api-key'}) - self.assertEqual(api_key, 'test-args-api-key') + api_key = get_api_key("openai", {"openai_api_key": "test-args-api-key"}) + self.assertEqual(api_key, "test-args-api-key") def test_get_api_key_from_params(self): """Test retrieving API key from params dictionary.""" # Test getting API key from params dictionary - api_key = get_api_key('openai', {'params': {'openai_api_key': 'test-params-api-key'}}) - self.assertEqual(api_key, 'test-params-api-key') + api_key = get_api_key("openai", {"params": {"openai_api_key": "test-params-api-key"}}) + self.assertEqual(api_key, "test-params-api-key") def test_get_api_key_priority(self): """Test API key retrieval priority.""" # Test that create_args takes priority over environment variables - api_key = get_api_key('openai', {'openai_api_key': 'test-args-api-key'}) - self.assertEqual(api_key, 'test-args-api-key') + api_key = get_api_key("openai", {"openai_api_key": "test-args-api-key"}) + self.assertEqual(api_key, "test-args-api-key") # Test that params takes priority over environment variables - api_key = get_api_key('openai', {'params': {'openai_api_key': 'test-params-api-key'}}) - self.assertEqual(api_key, 'test-params-api-key') + api_key = get_api_key("openai", {"params": {"openai_api_key": "test-params-api-key"}}) + self.assertEqual(api_key, "test-params-api-key") # Test that create_args takes priority over params - api_key = get_api_key('openai', { - 'openai_api_key': 'test-args-api-key', - 'params': {'openai_api_key': 'test-params-api-key'} - }) - self.assertEqual(api_key, 'test-args-api-key') - - @patch('mindsdb.interfaces.agents.agents_controller.AgentsController.check_model_provider') - @patch('mindsdb.interfaces.agents.agents_controller.AgentsController.get_agent') - @patch('mindsdb.interfaces.agents.agents_controller.ProjectController') - @patch('mindsdb.interfaces.storage.db.session') - def test_add_agent_with_api_key(self, mock_session, mock_project_controller, mock_get_agent, mock_check_model_provider): + api_key = get_api_key( + "openai", {"openai_api_key": "test-args-api-key", "params": {"openai_api_key": "test-params-api-key"}} + ) + self.assertEqual(api_key, "test-args-api-key") + + @patch("mindsdb.interfaces.agents.agents_controller.AgentsController.check_model_provider") + @patch("mindsdb.interfaces.agents.agents_controller.AgentsController.get_agent") + @patch("mindsdb.interfaces.agents.agents_controller.ProjectController") + @patch("mindsdb.interfaces.storage.db.session") + def test_add_agent_with_api_key( + self, mock_session, mock_project_controller, mock_get_agent, mock_check_model_provider + ): """Test adding an agent with an API key in params.""" # Mock project controller mock_project = MagicMock() @@ -71,36 +71,31 @@ def test_add_agent_with_api_key(self, mock_session, mock_project_controller, moc mock_get_agent.return_value = None # Mock check_model_provider to return a provider - mock_check_model_provider.return_value = (None, 'openai') + mock_check_model_provider.return_value = (None, "openai") # Create an instance of AgentsController agent_controller = AgentsController() # Test adding an agent with an API key in params - params = { - 'openai_api_key': 'test-agent-api-key', - 'other_param': 'value' - } + params = {"openai_api_key": "test-agent-api-key", "other_param": "value"} # Create a mock agent with proper params mock_agent = MagicMock() mock_agent.params = params.copy() # Set params directly # Mock db.Agents to return our prepared mock agent - with patch('mindsdb.interfaces.storage.db.Agents', return_value=mock_agent): + with patch("mindsdb.interfaces.storage.db.Agents", return_value=mock_agent): # Add the agent agent = agent_controller.add_agent( - name='test_agent', - project_name='mindsdb', - model_name='gpt-4', - skills=[], - provider='openai', - params=params + name="test_agent", + project_name="mindsdb", + model={"model_name": "gpt-4", "provider": "openai"}, + params=params, ) # Verify that the API key was preserved in the params - self.assertEqual(agent.params.get('openai_api_key'), 'test-agent-api-key') + self.assertEqual(agent.params.get("openai_api_key"), "test-agent-api-key") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From 3c4aa12c38bcd442d34b3edd5c50b26de027c68f Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Wed, 25 Mar 2026 17:33:19 +0300 Subject: [PATCH 087/125] MCP server update (#12274) --- docs/model-context-protocol/anthropic.mdx | 2 +- docs/model-context-protocol/openai.mdx | 2 +- docs/model-context-protocol/usage.mdx | 56 ++++-- docs/setup/custom-config.mdx | 32 +++ mindsdb/__main__.py | 15 ++ mindsdb/api/common/middleware.py | 111 +++++++++-- mindsdb/api/http/namespaces/default.py | 5 +- mindsdb/api/http/start.py | 61 ++++-- mindsdb/api/mcp/__init__.py | 183 +----------------- mindsdb/api/mcp/app.py | 94 +++++++++ mindsdb/api/mcp/completions.py | 35 ++++ mindsdb/api/mcp/mcp_instance.py | 36 ++++ mindsdb/api/mcp/oauth.py | 159 +++++++++++++++ mindsdb/api/mcp/prompts/__init__.py | 1 + mindsdb/api/mcp/prompts/sample_table.py | 21 ++ mindsdb/api/mcp/resources/__init__.py | 1 + mindsdb/api/mcp/resources/schema.py | 136 +++++++++++++ mindsdb/api/mcp/tools/__init__.py | 1 + mindsdb/api/mcp/tools/query.py | 60 ++++++ mindsdb/api/mcp/types.py | 25 +++ .../handlers/file_handler/file_handler.py | 21 +- mindsdb/integrations/libs/response.py | 39 ++-- mindsdb/utilities/config.py | 129 +++++++++++- mindsdb/utilities/log.py | 1 + requirements/requirements-agents.txt | 2 +- tests/unit/api/mcp/__init__.py | 0 tests/unit/api/mcp/test_completions.py | 135 +++++++++++++ tests/unit/api/mcp/test_prompts.py | 45 +++++ tests/unit/api/mcp/test_query_tool.py | 129 ++++++++++++ tests/unit/api/mcp/test_resources.py | 177 +++++++++++++++++ tests/unit/handlers/test_file.py | 31 ++- 31 files changed, 1486 insertions(+), 259 deletions(-) create mode 100644 mindsdb/api/mcp/app.py create mode 100644 mindsdb/api/mcp/completions.py create mode 100644 mindsdb/api/mcp/mcp_instance.py create mode 100644 mindsdb/api/mcp/oauth.py create mode 100644 mindsdb/api/mcp/prompts/__init__.py create mode 100644 mindsdb/api/mcp/prompts/sample_table.py create mode 100644 mindsdb/api/mcp/resources/__init__.py create mode 100644 mindsdb/api/mcp/resources/schema.py create mode 100644 mindsdb/api/mcp/tools/__init__.py create mode 100644 mindsdb/api/mcp/tools/query.py create mode 100644 mindsdb/api/mcp/types.py create mode 100644 tests/unit/api/mcp/__init__.py create mode 100644 tests/unit/api/mcp/test_completions.py create mode 100644 tests/unit/api/mcp/test_prompts.py create mode 100644 tests/unit/api/mcp/test_query_tool.py create mode 100644 tests/unit/api/mcp/test_resources.py diff --git a/docs/model-context-protocol/anthropic.mdx b/docs/model-context-protocol/anthropic.mdx index ba8609f5b23..0b594db756a 100644 --- a/docs/model-context-protocol/anthropic.mdx +++ b/docs/model-context-protocol/anthropic.mdx @@ -35,7 +35,7 @@ response = client.beta.messages.create( mcp_servers = [ { "type": "url", - "url": "https://5a52-88-203-84-191.ngrok-free.app/mcp/sse", + "url": "https:///mcp/sse", "name": "mindsdb-mcp", "authorization_token": "" } diff --git a/docs/model-context-protocol/openai.mdx b/docs/model-context-protocol/openai.mdx index c3d8ea2df54..3d9736dde9e 100644 --- a/docs/model-context-protocol/openai.mdx +++ b/docs/model-context-protocol/openai.mdx @@ -32,7 +32,7 @@ response = client.responses.create( { "type": "mcp", "server_label": "mdb", - "server_url": "https://5a52-88-203-84-191.ngrok-free.app/mcp/sse", + "server_url": "https:///mcp/sse", "headers": { "Authorization": "Bearer " }, "require_approval": "never", } diff --git a/docs/model-context-protocol/usage.mdx b/docs/model-context-protocol/usage.mdx index 5f18ac91937..43ed653b8d8 100644 --- a/docs/model-context-protocol/usage.mdx +++ b/docs/model-context-protocol/usage.mdx @@ -29,37 +29,67 @@ Follow the steps below to use MindsDB as an MCP server. ``` -3. Start MindsDB MCP server, either with or without authentication. +3. Start MindsDB MCP server. - * Start MindsDB MCP server without authentication to connect it to [Cursor](/mcp/cursor_usage). + * **Without authentication** (suitable for local tools): ```bash - docker run --name mindsdb_container -p 47334:47334 -p 47335:47335 mindsdb/mindsdb + docker run --name mindsdb_container -p 47334:47334 mindsdb/mindsdb ``` - * Start MindsDB MCP server with authentication to connect it to [OpenAI](/mcp/openai) or [Anthropic](/mcp/anthropic). + * **With PAT authentication** (suitable for remote): ```bash - docker run --name mindsdb_container -p 47334:47334 -p 47335:47335 -e MINDSDB_USERNAME=admin -e MINDSDB_PASSWORD=password123 mindsdb/mindsdb + docker run --name mindsdb_container -p 47334:47334 -e MINDSDB_USERNAME=admin -e MINDSDB_PASSWORD=password123 mindsdb/mindsdb ``` - Then get an auth token from MindsDB: + Get a Bearer token: ```bash curl -X POST -d '{"username":"admin","password":"password123"}' -H "Content-Type: application/json" http://localhost:47334/api/login ``` - This will return a token that you can use in your MCP client. + Use this token as `Authorization: Bearer ` in your MCP client. + + * **With OAuth 2.0** (for enterprise deployments): configure `MINDSDB_MCP_OAUTH_ENABLED=true` along with `MINDSDB_MCP_OAUTH_ISSUER_URL`, `MINDSDB_MCP_OAUTH_CLIENT_ID`, and `MINDSDB_MCP_OAUTH_CLIENT_SECRET`. 4. To confirm the MindsDB MCP server is running use `http://127.0.0.1:47334/mcp/status`. A successful response means your MCP environment is ready. -## MCP Tools +## MCP Capabilities + +### Tools + +**`query`** — Executes SQL queries against MindsDB using MySQL syntax. + +Parameters: +- `query` (required): SQL query string +- `context` (optional): Dict with default database, e.g. `{"db": "my_postgres"}` + +Returns one of: +- `{"type": "table", "column_names": [...], "data": [...]}` — for SELECT results +- `{"type": "ok", "affected_rows": N}` — for INSERT/UPDATE/DELETE +- `{"type": "error", "error_code": N, "error_message": "..."}` — on failure + +### Resources + +MCP resources expose schema information for discovery: + +| Resource URI | Description | +|---|---| +| `schema://databases` | Lists all connected data sources | +| `schema://databases/{db}/tables` | Lists tables in a database | +| `schema://databases/{db}/tables/{table}/columns` | Lists columns with types | +| `schema://knowledge_bases` | Lists knowledge bases | + +### Prompts -MindsDB MCP API exposes a set of tools that enable users to interact with their data and extract valuable insights. +**`sample_table`** — Generates instructions to fetch 5 sample rows and describe a table's structure. -**1. List Databases** +## Transport Modes -The `list_databases` tool lists all data sources connected to MindsDB. +- **HTTP (SSE)**: `http://127.0.0.1:47334/mcp/sse` +- **HTTP (Streamable)**: `http://127.0.0.1:47334/mcp/streamable` +- **Stdio**: run with `--mcp-stdio` flag for local stdio-based transport -**2. Query** +## Configuration -The `query` tool executes queries on the federated data to extract data relevant to answering a given question. +CORS, rate limiting, DNS rebinding protection, and OAuth settings for the MCP server are configured via the `api.mcp` section of `config.json` or the corresponding environment variables. See [Extend the Default MindsDB Configuration](/setup/custom-config#mcp-api) for the full parameter reference. diff --git a/docs/setup/custom-config.mdx b/docs/setup/custom-config.mdx index 89772412af0..0a76a038edc 100644 --- a/docs/setup/custom-config.mdx +++ b/docs/setup/custom-config.mdx @@ -191,6 +191,38 @@ Connection parameters for the MySQL API include: + + +The `mcp` section configures the [MCP server](/model-context-protocol/usage). + +```json +"api": { + "mcp": { + "cors": { + "enabled": true, + "allow_origins": [], + "allow_origin_regex": "https?://(localhost|127\\.0\\.0\\.1)(:\\d+)?", + "allow_headers": ["*"] + }, + "rate_limit": { + "enabled": false, + "requests_per_minute": 60 + }, + "dns_rebinding_protection": false + } +} +``` + +* `cors.enabled`: Enables CORS headers on MCP endpoints. Can also be set via `MINDSDB_MCP_CORS_ENABLED`. +* `cors.allow_origins`: List of allowed origins. Can also be set via `MINDSDB_MCP_ALLOW_ORIGINS` (comma-separated). +* `cors.allow_origin_regex`: Regex pattern for allowed origins. Can also be set via `MINDSDB_MCP_ALLOW_ORIGIN_REGEXP`. +* `cors.allow_headers`: List of allowed request headers. Can also be set via `MINDSDB_MCP_ALLOW_HEADERS` (comma-separated). +* `rate_limit.enabled`: Enables per-IP rate limiting. Can also be set via `MINDSDB_MCP_RATE_LIMIT_ENABLED`. +* `rate_limit.requests_per_minute`: Maximum number of requests per minute per IP. Can also be set via `MINDSDB_MCP_RATE_LIMIT_RPM`. +* `dns_rebinding_protection`: When `true`, the MCP transport validates the `Host` header against a list of known-safe hosts to prevent DNS rebinding attacks. Disabled by default (`false`). Enable it when running MindsDB locally and you want to restrict MCP access to `localhost` only. Can also be set via `MINDSDB_MCP_DNS_REBINDING_PROTECTION`. + + + #### `cache` diff --git a/mindsdb/__main__.py b/mindsdb/__main__.py index f5d2a672e2d..a09de7cc739 100644 --- a/mindsdb/__main__.py +++ b/mindsdb/__main__.py @@ -360,6 +360,15 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: sys.exit(0) + if config.cmd_args.mcp_stdio: + # StreamHandler writes to stderr by default, which MCP treats as notification messages. + # Raise the log level to ERROR to suppress notification spam, and explicitly set the + # stream to stderr in case the user has overridden it in their config. + os.environ["MINDSDB_CONSOLE_LOG_LEVEL"] = "ERROR" + config["logging"]["handlers"]["console"]["level"] = "ERROR" + config["logging"]["handlers"]["console"]["stream"] = "ext://sys.stderr" + log.configure_logging() + config.raise_warnings(logger=logger) os.environ["MINDSDB_RUNTIME"] = "1" @@ -430,6 +439,12 @@ def start_process(trunc_process_data: TrunkProcessData) -> None: clean_process_marks() + if config.cmd_args.mcp_stdio: + from mindsdb.api.mcp.mcp_instance import mcp + + mcp.run() + sys.exit(0) + # Get config values for APIs http_api_config = config.get("api", {}).get("http", {}) mysql_api_config = config.get("api", {}).get("mysql", {}) diff --git a/mindsdb/api/common/middleware.py b/mindsdb/api/common/middleware.py index 7730b178ad4..6fb93380191 100644 --- a/mindsdb/api/common/middleware.py +++ b/mindsdb/api/common/middleware.py @@ -1,13 +1,15 @@ import os +import time import hmac import secrets import hashlib +from collections import deque from http import HTTPStatus from typing import Optional -from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse from starlette.requests import Request +from starlette.types import ASGIApp, Receive, Scope, Send from mindsdb.utilities import log from mindsdb.utilities.config import config @@ -24,6 +26,10 @@ def get_pat_fingerprint(token: str) -> str: return hmac.new(SECRET_KEY.encode(), token.encode(), hashlib.sha256).hexdigest() +if config["auth"]["token"]: + TOKENS.append(get_pat_fingerprint(config["auth"]["token"])) + + def generate_pat() -> str: logger.debug("Generating new auth token") token = "pat_" + secrets.token_urlsafe(32) @@ -56,23 +62,106 @@ def revoke_pat(raw_token: str) -> bool: return False -class PATAuthMiddleware(BaseHTTPMiddleware): - def _extract_bearer(self, request: Request) -> Optional[str]: - h = request.headers.get("Authorization") +class PATAuthMiddleware: + """Pure ASGI middleware (compatible with SSE / streaming responses). + The class is not inherited from starlette.middleware.base.BaseHTTPMiddleware + bacause it collect responses to buffer, which is not good for streaming + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + @staticmethod + def _extract_bearer(headers: dict) -> Optional[str]: + h = headers.get("authorization") if not h or not h.startswith("Bearer "): return None return h.split(" ", 1)[1].strip() or None - async def dispatch(self, request: Request, call_next): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + if config.get("auth", {}).get("http_auth_enabled", False) is False: - return await call_next(request) + await self.app(scope, receive, send) + return - token = self._extract_bearer(request) - if not token or not verify_pat(token): - return JSONResponse({"detail": "Unauthorized"}, status_code=HTTPStatus.UNAUTHORIZED) + if scope.get("method") == "OPTIONS": + await self.app(scope, receive, send) + return - request.state.user = config["auth"].get("username") - return await call_next(request) + request = Request(scope) + token = self._extract_bearer(dict(request.headers)) + if not token or not verify_pat(token): + response = JSONResponse({"detail": "Unauthorized"}, status_code=HTTPStatus.UNAUTHORIZED) + await response(scope, receive, send) + return + + scope.setdefault("state", {})["user"] = config["auth"].get("username") + await self.app(scope, receive, send) + + +class RateLimitMiddleware: + """Rate limiting middleware using a sliding window counter. Tracks requests per client IP.""" + + def __init__(self, app: ASGIApp, requests_per_minute: int) -> None: + self.app = app + self.requests_per_minute = requests_per_minute + self._window = 60.0 # seconds + self._counters: dict[str, deque] = {} + + def _get_client_key(self, scope: Scope) -> str: + client = scope.get("client") + if client: + return client[0] + return "unknown" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + if scope.get("method") == "OPTIONS": + await self.app(scope, receive, send) + return + + # Clients usually repeat this request until + # the connection is established, so no rate limit it. + if scope.get("method") == "GET" and scope.get("path", "").endswith("/sse"): + await self.app(scope, receive, send) + return + + client_key = self._get_client_key(scope) + now = time.monotonic() + window_start = now - self._window + + timestamps = self._counters.setdefault(client_key, deque()) + + # Evict timestamps outside the current window + while timestamps and timestamps[0] <= window_start: + timestamps.popleft() + + if len(timestamps) >= self.requests_per_minute: + retry_after = int(self._window - (now - timestamps[0])) + 1 + else: + retry_after = None + timestamps.append(now) + + if retry_after is not None: + response = JSONResponse( + {"detail": f"Too Many Requests, retry after {retry_after} seconds"}, + status_code=HTTPStatus.TOO_MANY_REQUESTS, + headers={"Retry-After": str(retry_after)}, + ) + await response(scope, receive, send) + return + + stale_keys = [k for k, ts in self._counters.items() if not ts or ts[-1] <= window_start] + for k in stale_keys: + del self._counters[k] + + await self.app(scope, receive, send) # Used by mysql protocol diff --git a/mindsdb/api/http/namespaces/default.py b/mindsdb/api/http/namespaces/default.py index cdcf39d387f..4b2e0940ba5 100644 --- a/mindsdb/api/http/namespaces/default.py +++ b/mindsdb/api/http/namespaces/default.py @@ -65,7 +65,10 @@ def post(self): session.permanent = True if config["auth"]["http_auth_type"] in (HTTP_AUTH_TYPE.TOKEN, HTTP_AUTH_TYPE.SESSION_OR_TOKEN): - response["token"] = generate_pat() + if config["auth"]["token"]: + response["token"] = config["auth"]["token"] + else: + response["token"] = generate_pat() return response, 200 diff --git a/mindsdb/api/http/start.py b/mindsdb/api/http/start.py index 9cfb8454c89..f2373ebb114 100644 --- a/mindsdb/api/http/start.py +++ b/mindsdb/api/http/start.py @@ -1,5 +1,6 @@ import gc from importlib import import_module +from contextlib import asynccontextmanager, AsyncExitStack gc.disable() @@ -28,7 +29,7 @@ async def _health_check(request): return JSONResponse({"status": "ok"}) -def _mount_optional_api(name: str, mount_path: str, get_app_fn, routes): +def _mount_optional_api(name: str, mount_path: str, get_app_fn, routes) -> object | None: try: optional_app = get_app_fn() except ImportError as exc: @@ -41,8 +42,11 @@ def _mount_optional_api(name: str, mount_path: str, get_app_fn, routes): ) return - optional_app.add_middleware(PATAuthMiddleware) + if name.upper() != "MCP" or config["api"]["mcp"]["oauth"]["enabled"] is False: + optional_app.add_middleware(PATAuthMiddleware) + routes.append(Mount(mount_path, app=optional_app)) + return optional_app def start(verbose, app: Flask = None, is_restart: bool = False): @@ -58,23 +62,44 @@ def start(verbose, app: Flask = None, is_restart: bool = False): process_cache.init() routes = [] + sub_apps = [] # Health check FIRST - async endpoint that bypasses WSGI worker pool # This ensures health checks respond even when all workers are blocked routes.append(Route("/api/util/ping", _health_check, methods=["GET"])) - _mount_optional_api( - "A2A", - "/a2a", - lambda: import_module("mindsdb.api.a2a").get_a2a_app(), - routes, - ) - _mount_optional_api( - "MCP", - "/mcp", - lambda: import_module("mindsdb.api.mcp").get_mcp_app(), - routes, - ) + for name, path, factory in [ + ("A2A", "/a2a", lambda: import_module("mindsdb.api.a2a").get_a2a_app()), + ("MCP", "/mcp", lambda: import_module("mindsdb.api.mcp").get_mcp_app()), + ]: + mounted = _mount_optional_api(name, path, factory, routes) + if mounted is not None: + sub_apps.append(mounted) + + # RFC 9728: /.well-known/oauth-protected-resource must be at the server root, + # not under the /mcp mount, so we register it here before the Flask fallback. + try: + well_known_routes = import_module("mindsdb.api.mcp").get_mcp_well_known_routes() + routes.extend(well_known_routes) + except ImportError: + pass + except Exception as e: + logger.warning(f"Error during registering of mcp well-known routes: {e}") + + @asynccontextmanager + async def lifespan(_): + """Propagate ASGI lifespan events to mounted sub-apps. + + Starlette's Mount does not forward startup/shutdown lifespan events to + sub-applications automatically. This context manager manually enters the + lifespan context of each collected sub-app so their internal state + (e.g. StreamableHTTPSessionManager task group for MCP) is properly + initialized on startup and torn down on shutdown. + """ + async with AsyncExitStack() as stack: + for sub_app in sub_apps: + await stack.enter_async_context(sub_app.router.lifespan_context(sub_app)) + yield # Root app LAST so it won't shadow the others routes.append( @@ -89,4 +114,10 @@ def start(verbose, app: Flask = None, is_restart: bool = False): ) # Setting logging to None makes uvicorn use the existing logging configuration - uvicorn.run(Starlette(routes=routes, debug=verbose), host=host, port=int(port), log_level=None, log_config=None) + uvicorn.run( + Starlette(routes=routes, lifespan=lifespan, debug=verbose), + host=host, + port=int(port), + log_level=None, + log_config=None, + ) diff --git a/mindsdb/api/mcp/__init__.py b/mindsdb/api/mcp/__init__.py index b5601a16e8b..3473a394e61 100644 --- a/mindsdb/api/mcp/__init__.py +++ b/mindsdb/api/mcp/__init__.py @@ -1,182 +1,3 @@ -import os -from textwrap import dedent -from typing import Any -from contextlib import asynccontextmanager -from collections.abc import AsyncIterator -from dataclasses import dataclass +from mindsdb.api.mcp.app import get_mcp_app, get_mcp_well_known_routes -from mcp.server.fastmcp import FastMCP -from mcp.server.transport_security import TransportSecuritySettings -from starlette.requests import Request -from starlette.responses import JSONResponse - -from mindsdb.api.mysql.mysql_proxy.classes.fake_mysql_proxy import FakeMysqlProxy -from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE as SQL_RESPONSE_TYPE -from mindsdb.interfaces.storage import db -from mindsdb.utilities import log - -logger = log.getLogger(__name__) - - -def _get_transport_security() -> TransportSecuritySettings: - default_hosts = ["localhost:*", "127.0.0.1:*"] - env_hosts = os.environ.get("MINDSDB_MCP_ALLOWED_HOSTS", "") - if env_hosts: - custom_hosts = [h.strip() for h in env_hosts.split(",") if h.strip()] - for host in custom_hosts: - if ":" not in host: - default_hosts.append(f"{host}:*") - default_hosts.append(host) - logger.info(f"MCP transport security allowed hosts: {default_hosts}") - return TransportSecuritySettings(allowed_hosts=default_hosts) - - -@dataclass -class AppContext: - db: Any - - -@asynccontextmanager -async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: - """Manage application lifecycle with type-safe context""" - # Initialize on startup - db.init() - try: - yield AppContext(db=db) - finally: - # TODO: We need better way to handle this in storage/db.py - pass - - -# Configure server with lifespan and transport security -mcp = FastMCP( - "MindsDB", - lifespan=app_lifespan, - dependencies=["mindsdb"], - transport_security=_get_transport_security(), -) - - -# MCP Queries -LISTING_QUERY = "SHOW DATABASES" - - -query_tool_description = dedent("""\ - Executes a SQL query against MindsDB. - - A database must be specified either in the `context` parameter or directly in the query string (e.g., `SELECT * FROM my_database.my_table`). Queries like `SELECT * FROM my_table` will fail without a `context`. - - Args: - query (str): The SQL query to execute. - context (dict, optional): The default database context. For example, `{"db": "my_postgres"}`. - - Returns: - A dictionary describing the result. - - For a successful query with no data to return (e.g., an `UPDATE` statement), the response is `{"type": "ok"}`. - - If the query returns tabular data, the response is a dictionary containing `data` (a list of rows) and `column_names` (a list of column names). For example: `{"type": "table", "data": [[1, "a"], [2, "b"]], "column_names": ["column_a", "column_b"]}`. - - In case of an error, a response is `{"type": "error", "error_message": "the error message"}`. -""") - - -@mcp.tool(name="query", description=query_tool_description) -def query(query: str, context: dict | None = None) -> dict[str, Any]: - """Execute a SQL query against MindsDB - - Args: - query: The SQL query to execute - context: Optional context parameters for the query - - Returns: - Dict containing the query results or error information - """ - - if context is None: - context = {} - - logger.debug(f"Incoming MCP query: {query}") - - mysql_proxy = FakeMysqlProxy() - mysql_proxy.set_context(context) - - try: - result = mysql_proxy.process_query(query) - - if result.type == SQL_RESPONSE_TYPE.OK: - return {"type": SQL_RESPONSE_TYPE.OK} - - if result.type == SQL_RESPONSE_TYPE.TABLE: - return { - "type": SQL_RESPONSE_TYPE.TABLE, - "data": result.result_set.to_lists(json_types=True), - "column_names": [column.alias or column.name for column in result.result_set.columns], - } - else: - return {"type": SQL_RESPONSE_TYPE.ERROR, "error_code": 0, "error_message": "Unknown response type"} - - except Exception as e: - logger.exception("Error processing query:") - return {"type": SQL_RESPONSE_TYPE.ERROR, "error_code": 0, "error_message": str(e)} - - -list_databases_tool_description = ( - "Returns a list of all database connections currently available in MindsDB. " - + "The tool takes no parameters and responds with a list of database names, " - + 'for example: ["my_postgres", "my_mysql", "test_db"].' -) - - -@mcp.tool(name="list_databases", description=list_databases_tool_description) -def list_databases() -> list[str]: - """ - List all databases in MindsDB - - Returns: - list[str]: list of databases - """ - - mysql_proxy = FakeMysqlProxy() - - try: - result = mysql_proxy.process_query(LISTING_QUERY) - if result.type == SQL_RESPONSE_TYPE.ERROR: - return { - "type": "error", - "error_code": result.error_code, - "error_message": result.error_message, - } - - elif result.type == SQL_RESPONSE_TYPE.OK: - return {"type": "ok"} - - elif result.type == SQL_RESPONSE_TYPE.TABLE: - data = result.result_set.to_lists(json_types=True) - data = [val[0] for val in data] - return data - - except Exception as e: - logger.exception("Error while retrieving list of databases") - return { - "type": "error", - "error_code": 0, - "error_message": str(e), - } - - -def _get_status(request: Request) -> JSONResponse: - """ - Status endpoint that returns basic server information. - This endpoint can be used by the frontend to check if the MCP server is running. - """ - - status_info = { - "status": "ok", - "service": "mindsdb-mcp", - } - - return JSONResponse(status_info) - - -def get_mcp_app(): - app = mcp.sse_app() - app.add_route("/status", _get_status, methods=["GET"]) - return app +__all__ = ["get_mcp_app", "get_mcp_well_known_routes"] diff --git a/mindsdb/api/mcp/app.py b/mindsdb/api/mcp/app.py new file mode 100644 index 00000000000..ea810595ac0 --- /dev/null +++ b/mindsdb/api/mcp/app.py @@ -0,0 +1,94 @@ +from contextlib import asynccontextmanager + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware + +from mindsdb.utilities.config import config +from mindsdb.api.common.middleware import RateLimitMiddleware +from mindsdb.api.mcp.mcp_instance import mcp + +# region these imports required for correct initialization +from mindsdb.api.mcp import tools # noqa: F401 +from mindsdb.api.mcp import resources # noqa: F401 +from mindsdb.api.mcp import prompts # noqa: F401 +from mindsdb.api.mcp import completions # noqa: F401 +# endregion + + +def _get_status(request: Request) -> JSONResponse: + return JSONResponse({"status": "ok", "service": "mindsdb-mcp"}) + + +def get_mcp_app(): + sse_starlette = mcp.sse_app() + http_starlette = mcp.streamable_http_app() + + @asynccontextmanager + async def lifespan(_): + """Required for streamable_http to run task group""" + async with http_starlette.router.lifespan_context(http_starlette): + yield + + middleware = [] + + # Preserve AuthenticationMiddleware from http_starlette so that + # RequireAuthMiddleware can read scope["user"] set by BearerAuthBackend. + if mcp._token_verifier is not None: + middleware = [ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(mcp._token_verifier)), + Middleware(AuthContextMiddleware), + ] + + combined_app = Starlette( + routes=list(sse_starlette.routes) + list(http_starlette.routes), + middleware=middleware, + lifespan=lifespan, + ) + + # Rate limit should be added before CORS, so that CORS adds correct headers + if config["api"]["mcp"]["rate_limit"]["enabled"]: + combined_app.add_middleware( + RateLimitMiddleware, + requests_per_minute=config["api"]["mcp"]["rate_limit"]["requests_per_minute"], + ) + + if config["api"]["mcp"]["cors"]["enabled"]: + combined_app.add_middleware( + CORSMiddleware, + allow_origins=config["api"]["mcp"]["cors"]["allow_origins"], + allow_origin_regex=config["api"]["mcp"]["cors"]["allow_origin_regex"], + allow_methods=["GET", "POST", "DELETE", "OPTIONS"], + allow_headers=config["api"]["mcp"]["cors"]["allow_headers"], + expose_headers=["mcp-session-id"], + ) + + combined_app.add_route("/status", _get_status, methods=["GET"]) + + return combined_app + + +def get_mcp_well_known_routes() -> list[Route]: + """Return OAuth protected resource metadata routes for mounting at the server root. + + RFC 9728 requires /.well-known/oauth-protected-resource to be served at the + server root, not under the /mcp sub-path, so start.py registers these separately. + """ + from mcp.server.auth.routes import create_protected_resource_routes + + auth = mcp.settings.auth + if not auth or not auth.resource_server_url: + return [] + + return create_protected_resource_routes( + resource_url=auth.resource_server_url, + authorization_servers=[auth.issuer_url], + scopes_supported=auth.required_scopes, + ) diff --git a/mindsdb/api/mcp/completions.py b/mindsdb/api/mcp/completions.py new file mode 100644 index 00000000000..94bf2abe2cd --- /dev/null +++ b/mindsdb/api/mcp/completions.py @@ -0,0 +1,35 @@ +from mcp.types import Completion, PromptReference, ResourceTemplateReference + +from mindsdb.api.mcp.mcp_instance import mcp +from mindsdb.api.executor.controllers.session_controller import SessionController +from mindsdb.utilities.context import context as ctx +from mindsdb.api.mcp.resources.schema import _get_database_names +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +@mcp.completion() +async def handle_completion(ref, argument, context): + if not isinstance(ref, (ResourceTemplateReference, PromptReference)): + return None + + try: + if argument.name == "database_name": + names = _get_database_names() + return Completion(values=[n for n in names if n.startswith(argument.value)]) + + if argument.name == "table_name": + database_name = (context.arguments or {}).get("database_name") + if not database_name: + return None + ctx.set_default() + session = SessionController() + datanode = session.datahub.get(database_name) + all_tables = datanode.get_tables() + names = [t.TABLE_NAME for t in all_tables] + return Completion(values=[n for n in names if n.startswith(argument.value)]) + except Exception as e: + logger.info(f"Couldn't get completion for parameter {argument.name}: {e}") + + return None diff --git a/mindsdb/api/mcp/mcp_instance.py b/mindsdb/api/mcp/mcp_instance.py new file mode 100644 index 00000000000..fa65ab47711 --- /dev/null +++ b/mindsdb/api/mcp/mcp_instance.py @@ -0,0 +1,36 @@ +from mcp.server.fastmcp import FastMCP +from mcp.server.transport_security import TransportSecuritySettings + +from mindsdb.api.mcp.oauth import build_oauth_components +from mindsdb.utilities.config import config + + +def _create_mcp() -> FastMCP: + token_verifier, auth_settings = build_oauth_components() + + dns_rebinding_protection = config["api"]["mcp"]["dns_rebinding_protection"] + transport_security = TransportSecuritySettings(enable_dns_rebinding_protection=dns_rebinding_protection) + + return FastMCP( + name="MindsDB", + instructions=( + "MindsDB is a data platform that connects to external databases and data sources.\n" + "Use the available resources to discover connected databases and their schema,\n" + "then use the `query` tool to retrieve or manipulate data with SQL.\n" + "\n" + "Workflow:\n" + "1. Read `schema://databases` to list available data sources.\n" + "2. Read `schema://databases/{name}/tables` to explore tables in a source.\n" + "3. Read `schema://databases/{name}/tables/{table}/columns` to inspect columns.\n" + "4. Use the `query` tool to run SQL queries against the data." + ), + dependencies=["mindsdb"], + streamable_http_path="/streamable", + debug=False, + token_verifier=token_verifier, + auth=auth_settings, + transport_security=transport_security, + ) + + +mcp = _create_mcp() diff --git a/mindsdb/api/mcp/oauth.py b/mindsdb/api/mcp/oauth.py new file mode 100644 index 00000000000..0e21ffefe27 --- /dev/null +++ b/mindsdb/api/mcp/oauth.py @@ -0,0 +1,159 @@ +from typing import Any +from urllib.parse import urljoin + +import httpx +from pydantic import AnyHttpUrl +from mcp.server.auth.settings import AuthSettings +from mcp.server.auth.provider import AccessToken, TokenVerifier +from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url + +from mindsdb.utilities.config import config +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +class IntrospectionTokenVerifier(TokenVerifier): + """Token verifier that uses OAuth 2.0 Token Introspection (RFC 7662). + Intended for use when MindsDB acts as a Resource Server and token + issuance is delegated to an external provider (e.g. Keycloak). + + Args: + introspection_endpoint: Full URL of the RFC 7662 introspection endpoint. + server_url: Public URL of this MCP server (e.g. ``http://host:port/mcp/streamable``). + Used to derive the expected ``aud`` (audience) claim value. + client_id: OAuth client ID used to authenticate against the introspection endpoint. + client_secret: OAuth client secret used to authenticate against the introspection endpoint. + """ + + def __init__( + self, + introspection_endpoint: str, + server_url: str, + client_id: str, + client_secret: str, + ): + self.introspection_endpoint = introspection_endpoint + self.server_url = server_url + self.client_id = client_id + self.client_secret = client_secret + self.resource_url = resource_url_from_server_url(server_url) + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify a bearer token via the introspection endpoint. + + Args: + token: Raw bearer token string extracted from the Authorization header. + + Returns: + AccessToken: Populated access token on successful verification. + None: If the token is inactive, the audience is invalid, the endpoint + is unreachable, or any other error occurs. + """ + # to prevent SSRF attacks it must start from https, or be local server + if not self.introspection_endpoint.startswith(("https://", "http://localhost:", "http://127.0.0.1:")): + return None + + timeout = httpx.Timeout(10.0, connect=5.0) + limits = httpx.Limits(max_connections=10, max_keepalive_connections=5) + + async with httpx.AsyncClient( + timeout=timeout, + limits=limits, + verify=True, + follow_redirects=False, + ) as client: + try: + form_data = { + "token": token, + "client_id": self.client_id, + "client_secret": self.client_secret, + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + response = await client.post( + self.introspection_endpoint, + data=form_data, + headers=headers, + ) + + if response.status_code != 200: + return None + + data = response.json() + if not data.get("active", False): + return None + + if not self._validate_resource(data): + return None + + return AccessToken( + token=token, + client_id=data.get("client_id", "unknown"), + scopes=data.get("scope", "").split() if data.get("scope") else [], + expires_at=data.get("exp"), + resource=self.resource_url, + ) + + except Exception as e: + logger.error(f"Error during token verification: {e}") + return None + + def _validate_resource(self, token_data: dict[str, Any]) -> bool: + """Validate that the token was issued for this resource server (RFC 8707). + + Args: + token_data: Parsed JSON response from the introspection endpoint. + + Returns: + bool: True if at least one audience entry matches this server's resource URL, + False if ``aud`` is missing or no entry matches. + """ + if not self.server_url or not self.resource_url: + return False + + aud: list[str] | str | None = token_data.get("aud") + if isinstance(aud, list): + return any(check_resource_allowed(self.resource_url, a) for a in aud) + if isinstance(aud, str): + return check_resource_allowed(self.resource_url, aud) + return False + + +def build_oauth_components() -> tuple[IntrospectionTokenVerifier, AuthSettings] | tuple[None, None]: + """Build token verifier and auth settings from the OAuth config section. + + Returns: + tuple[IntrospectionTokenVerifier, AuthSettings]: Token verifier and auth settings ready + to pass to FastMCP if OAuth is enabled. + tuple[None, None]: If OAuth ``enabled`` is False or not set. + """ + oauth_cfg = config["api"]["mcp"]["oauth"] + if not oauth_cfg.get("enabled", False): + return None, None + + host = config["api"]["http"]["host"] + port = config["api"]["http"]["port"] + mcp_endpoint_url = f"http://{host}:{port}/mcp/streamable" + + issuer_url = oauth_cfg.get("issuer_url", "").rstrip("/") + "/" + client_id = oauth_cfg.get("client_id", "") + client_secret = oauth_cfg.get("client_secret", "") + scope = oauth_cfg.get("scope", "mcp:tools") + + introspection_endpoint = urljoin(issuer_url, "protocol/openid-connect/token/introspect") + + token_verifier = IntrospectionTokenVerifier( + introspection_endpoint=introspection_endpoint, + server_url=mcp_endpoint_url, + client_id=client_id, + client_secret=client_secret, + ) + + auth_settings = AuthSettings( + issuer_url=AnyHttpUrl(issuer_url), + required_scopes=[scope], + resource_server_url=AnyHttpUrl(mcp_endpoint_url), + ) + + return token_verifier, auth_settings diff --git a/mindsdb/api/mcp/prompts/__init__.py b/mindsdb/api/mcp/prompts/__init__.py new file mode 100644 index 00000000000..437673b53d3 --- /dev/null +++ b/mindsdb/api/mcp/prompts/__init__.py @@ -0,0 +1 @@ +from mindsdb.api.mcp.prompts import sample_table # noqa: F401 diff --git a/mindsdb/api/mcp/prompts/sample_table.py b/mindsdb/api/mcp/prompts/sample_table.py new file mode 100644 index 00000000000..2473715aa7d --- /dev/null +++ b/mindsdb/api/mcp/prompts/sample_table.py @@ -0,0 +1,21 @@ +from mcp.types import TextContent + +from mindsdb.api.mcp.mcp_instance import mcp + + +@mcp.prompt(name="sample_table", description="Fetch 5 sample rows from a table and describe its structure.") +def sample_table(database_name: str, table_name: str) -> list[TextContent]: + return [ + TextContent( + type="text", + text=( + f"Use the `query` tool to fetch 5 sample rows from the table `{table_name}` " + f"in database `{database_name}`:\n\n" + f"```sql\n" + f"SELECT * FROM `{database_name}`.`{table_name}` LIMIT 5;\n" + f"```\n\n" + f"After getting the results, briefly describe the table structure " + f"and what kind of data it contains." + ), + ) + ] diff --git a/mindsdb/api/mcp/resources/__init__.py b/mindsdb/api/mcp/resources/__init__.py new file mode 100644 index 00000000000..5cd0b60720d --- /dev/null +++ b/mindsdb/api/mcp/resources/__init__.py @@ -0,0 +1 @@ +from mindsdb.api.mcp.resources import schema # noqa: F401 diff --git a/mindsdb/api/mcp/resources/schema.py b/mindsdb/api/mcp/resources/schema.py new file mode 100644 index 00000000000..6986c7dd420 --- /dev/null +++ b/mindsdb/api/mcp/resources/schema.py @@ -0,0 +1,136 @@ +from pydantic import BaseModel + +from mindsdb.api.mcp.mcp_instance import mcp +from mindsdb.api.executor.controllers.session_controller import SessionController +from mindsdb.utilities.context import context as ctx +from mindsdb.integrations.libs.response import TableResponse, ErrorResponse +from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE + + +class TableInfo(BaseModel): + TABLE_NAME: str + TABLE_TYPE: str + TABLE_SCHEMA: str + + +class ColumnInfo(BaseModel): + COLUMN_NAME: str + MYSQL_DATA_TYPE: str + + +class KnowledgeBaseInfo(BaseModel): + name: str + project: str + metadata_columns: list[str] + content_columns: list[str] + id_column: str + + +def _get_database_names() -> list[str]: + ctx.set_default() + session = SessionController() + databases = session.database_controller.get_list() + return [x["name"] for x in databases if x["type"] == "data"] + + +@mcp.resource( + "schema://databases", + mime_type="application/json", + description=( + "Initial list of connected data source names available for querying. " + "This resource may be cached by the client. " + "To get the current list of databases during a session, use the `query` tool: " + "SHOW DATABASES" + ), +) +def list_databases() -> list[str]: + return _get_database_names() + + +@mcp.resource( + "schema://databases/{database_name}/tables", + mime_type="application/json", + description=( + "Initial list of tables in the specified connected database. " + "This resource may be cached by the client. " + "To get the current list of tables during a session (e.g. after CREATE/DROP TABLE), " + "use the `query` tool: " + "SHOW TABLES FROM {database_name}" + ), +) +def db_tables(database_name: str) -> list[TableInfo]: + ctx.set_default() + session = SessionController() + datanode = session.datahub.get(database_name) + if datanode is None: + raise ValueError(f"Database '{database_name}' is not found.") + all_tables = datanode.get_tables() + all_tables = [ + { + "TABLE_NAME": table.TABLE_NAME, + "TABLE_TYPE": table.TABLE_TYPE, + "TABLE_SCHEMA": table.TABLE_SCHEMA, + } + for table in all_tables + ] + return all_tables + + +@mcp.resource( + "schema://databases/{database_name}/tables/{table_name}/columns", + mime_type="application/json", + description=( + "Initial column names and types for a specific table in a connected database. " + "This resource may be cached by the client. " + "To get the current column list during a session (e.g. after ALTER TABLE), " + "use the `query` tool: " + "SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS " + "WHERE TABLE_SCHEMA = '{database_name}' AND TABLE_NAME = '{table_name}'" + ), +) +def db_table_columns(database_name: str, table_name: str) -> list[ColumnInfo]: + ctx.set_default() + session = SessionController() + handler = session.integration_controller.get_data_handler(database_name) + columns_answer = handler.get_columns(table_name) + + if isinstance(columns_answer, TableResponse): + if columns_answer.type != RESPONSE_TYPE.COLUMNS_TABLE: + raise ValueError( + "Database returned a successful response, but the column list does not match the expected format" + ) + df = columns_answer.fetchall() + response = df[["COLUMN_NAME", "MYSQL_DATA_TYPE"]].to_dict(orient="records") + return response + if isinstance(columns_answer, ErrorResponse): + raise ValueError(columns_answer.error_message) + raise ValueError(f"Unexpected handler response type: {columns_answer}") + + +@mcp.resource( + "schema://knowledge_bases", + description=( + "Initial list of knowledge bases with their project, column configuration, and ID column. " + "This resource may be cached by the client. " + "To get the current list of knowledge bases during a session, use the `query` tool: " + "SHOW KNOWLEDGE BASES" + ), +) +def list_knowledge_bases() -> list[KnowledgeBaseInfo]: + ctx.set_default() + session = SessionController() + project_names = session.datahub.get_projects_names() + result = [] + for project_name in project_names: + kbs = session.kb_controller.list(project_name) + for kb in kbs: + result.append( + { + "name": kb.get("name"), + "project": kb.get("project"), + "metadata_columns": kb.get("metadata_columns"), + "content_columns": kb.get("content_columns"), + "id_column": kb.get("id_column"), + } + ) + return result diff --git a/mindsdb/api/mcp/tools/__init__.py b/mindsdb/api/mcp/tools/__init__.py new file mode 100644 index 00000000000..a07edf06817 --- /dev/null +++ b/mindsdb/api/mcp/tools/__init__.py @@ -0,0 +1 @@ +from mindsdb.api.mcp.tools import query # noqa: F401 diff --git a/mindsdb/api/mcp/tools/query.py b/mindsdb/api/mcp/tools/query.py new file mode 100644 index 00000000000..42026e32b1f --- /dev/null +++ b/mindsdb/api/mcp/tools/query.py @@ -0,0 +1,60 @@ +from textwrap import dedent +from typing import Annotated + +from pydantic import Field + +from mindsdb.api.mcp.mcp_instance import mcp +from mindsdb.api.mcp.types import ErrorResponse, QueryResponseAnswer, response_adapter +from mindsdb.api.mysql.mysql_proxy.mysql_proxy import SQLAnswer +from mindsdb.api.mysql.mysql_proxy.classes.fake_mysql_proxy import FakeMysqlProxy +from mindsdb.utilities.context import context as ctx +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +query_tool_description = dedent("""\ + Execute a SQL query against MindsDB and return the result. + + Queries use MySQL syntax. Use fully qualified names (`database`.`table`) or set `context` to specify + the default database. Use backticks (`) to quote identifiers that are reserved words or contain + special characters. + + Returns one of: + - `{"type": "ok"}` — for statements with no output (INSERT, UPDATE, etc.) + - `{"type": "table", "column_names": [...], "data": [[...], ...]}` — for SELECT results + - `{"type": "error", "error_message": "..."}` — on failure +""") + + +@mcp.tool(name="query", description=query_tool_description) +def query( + query: Annotated[str, Field(description="SQL query to execute against MindsDB.")], + context: Annotated[ + dict | None, + Field( + description=( + 'Default database context, e.g. {"db": "my_postgres"}. ' + "Required if the query does not use fully qualified table names." + ) + ), + ] = None, +) -> QueryResponseAnswer: + ctx.set_default() + + if context is None: + context = {} + + logger.debug(f"Incoming MCP query: {query}") + + mysql_proxy = FakeMysqlProxy() + mysql_proxy.set_context(context) + + try: + result: SQLAnswer = mysql_proxy.process_query(query) + query_response: dict = result.dump_http_response() + except Exception as e: + logger.exception("Error processing query:") + return ErrorResponse(type="error", error_code=0, error_message=str(e)) + + return response_adapter.validate_python(query_response) diff --git a/mindsdb/api/mcp/types.py b/mindsdb/api/mcp/types.py new file mode 100644 index 00000000000..0275742116f --- /dev/null +++ b/mindsdb/api/mcp/types.py @@ -0,0 +1,25 @@ +from typing import Annotated, Literal, Union + +from pydantic import BaseModel, Field, TypeAdapter + + +class OkResponse(BaseModel): + type: Literal["ok"] + affected_rows: int | None = None + + +class ErrorResponse(BaseModel): + type: Literal["error"] + error_code: int + error_message: str + + +class TableResponse(BaseModel): + type: Literal["table"] + column_names: list[str] + data: list[list] + + +QueryResponseAnswer = Annotated[Union[OkResponse, ErrorResponse, TableResponse], Field(discriminator="type")] + +response_adapter = TypeAdapter(QueryResponseAnswer) diff --git a/mindsdb/integrations/handlers/file_handler/file_handler.py b/mindsdb/integrations/handlers/file_handler/file_handler.py index c6c66408caa..6a1fc443ee4 100644 --- a/mindsdb/integrations/handlers/file_handler/file_handler.py +++ b/mindsdb/integrations/handlers/file_handler/file_handler.py @@ -7,11 +7,15 @@ from mindsdb_sql_parser.ast import CreateTable, DropTables, Insert, Select, Identifier from mindsdb_sql_parser.ast.base import ASTNode +from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE from mindsdb.api.executor.utilities.sql import query_dfs from mindsdb.integrations.libs.base import DatabaseHandler -from mindsdb.integrations.libs.response import RESPONSE_TYPE -from mindsdb.integrations.libs.response import HandlerResponse as Response -from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse +from mindsdb.integrations.libs.response import ( + RESPONSE_TYPE, + HandlerResponse as Response, + HandlerStatusResponse as StatusResponse, + INF_SCHEMA_COLUMNS_NAMES_SET, +) from mindsdb.utilities import log @@ -211,16 +215,23 @@ def get_tables(self) -> Response: def get_columns(self, table_name) -> Response: file_meta = self.file_controller.get_file_meta(table_name) + if file_meta is None: + result = Response( + RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame([], columns=list(INF_SCHEMA_COLUMNS_NAMES_SET)) + ) + result.to_columns_table_response(map_type_fn=lambda _: MYSQL_DATA_TYPE.TEXT) + return result result = Response( RESPONSE_TYPE.TABLE, data_frame=pd.DataFrame( [ { - "Field": x["name"].strip() if isinstance(x, dict) else x.strip(), - "Type": "str", + "COLUMN_NAME": x["name"].strip() if isinstance(x, dict) else x.strip(), + "DATA_TYPE": "str", } for x in file_meta["columns"] ] ), ) + result.to_columns_table_response(map_type_fn=lambda _: MYSQL_DATA_TYPE.TEXT) return result diff --git a/mindsdb/integrations/libs/response.py b/mindsdb/integrations/libs/response.py index c559e875869..3af33b444fa 100644 --- a/mindsdb/integrations/libs/response.py +++ b/mindsdb/integrations/libs/response.py @@ -405,16 +405,21 @@ def to_columns_table_response(self, map_type_fn: Callable) -> None: if self._data is None: return self._data.columns = [name.upper() for name in self._data.columns] + + for required_column in (INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME, INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE): + if required_column not in self._data.columns: + raise ValueError( + f"Missed required for INFORMATION_SCHEMA.COLUMNS column {required_column}. " + f"Columns set: {self._data.columns}" + ) + for column_name in INF_SCHEMA_COLUMNS_NAMES_SET: + if column_name not in self._data.columns: + self._data[column_name] = None + self._data[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE] = self._data[INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE].apply( map_type_fn ) - # region validate df - current_columns_set = set(self._data.columns) - if INF_SCHEMA_COLUMNS_NAMES_SET != current_columns_set: - raise ValueError(f"Columns set for INFORMATION_SCHEMA.COLUMNS is wrong: {list(current_columns_set)}") - # endregion - self._data = self._data.astype( { INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME: "string", @@ -475,13 +480,16 @@ def normalize_response(response) -> TableResponse | OkResponse | ErrorResponse: if mysql_types is None: mysql_types = [None] * len(columns) - return TableResponse( + table_response = TableResponse( data=response.data_frame, columns=[ Column(name=column_name, type=mysql_type) for column_name, mysql_type in zip(columns, mysql_types) ], data_generator=iter([]), # empty generator for legacy responses ) + if response.resp_type == RESPONSE_TYPE.COLUMNS_TABLE: + table_response.type = RESPONSE_TYPE.COLUMNS_TABLE + return table_response # Unknown type - return as-is (shouldn't happen normally) return response @@ -538,16 +546,21 @@ def to_columns_table_response(self, map_type_fn: Callable) -> None: raise ValueError(f"Cannot convert {self.resp_type} to {RESPONSE_TYPE.COLUMNS_TABLE}") self.data_frame.columns = [name.upper() for name in self.data_frame.columns] + + for required_column in (INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME, INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE): + if required_column not in self.data_frame.columns: + raise ValueError( + f"Missed required for INFORMATION_SCHEMA.COLUMNS column {required_column}. " + f"Columns set: {self.data_frame.columns}" + ) + for column_name in INF_SCHEMA_COLUMNS_NAMES_SET: + if column_name not in self.data_frame.columns: + self.data_frame[column_name] = None + self.data_frame[INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE] = self.data_frame[ INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE ].apply(map_type_fn) - # region validate df - current_columns_set = set(self.data_frame.columns) - if INF_SCHEMA_COLUMNS_NAMES_SET != current_columns_set: - raise ValueError(f"Columns set for INFORMATION_SCHEMA.COLUMNS is wrong: {list(current_columns_set)}") - # endregion - self.data_frame = self.data_frame.astype( { INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME: "string", diff --git a/mindsdb/utilities/config.py b/mindsdb/utilities/config.py index 7aa74f2b1e7..b7187dff930 100644 --- a/mindsdb/utilities/config.py +++ b/mindsdb/utilities/config.py @@ -13,6 +13,49 @@ # NOTE do not `import from mindsdb` here +def get_bool_env_var(env_name: str) -> bool: + """Read an environment variable and return its value as a boolean. + + Args: + env_name (str): name of the environment variable to read. + + Returns: + bool: True or False, or None if the variable is not set or empty. + + Raises: + ValueError: if the value is set but does not match any known boolean representation. + """ + value = os.environ.get(env_name) + if value is None or value == "": + return None + match value.lower(): + case "1" | "true" | "on" | "yes" | "y": + value = True + case "0" | "false" | "off" | "no" | "n": + value = False + case _: + raise ValueError(f"Expected a boolean value for the environment variable '{env_name}', but got '{value}'") + return value + + +def get_list_env_var(env_name: str) -> list[str]: + """Read an environment variable and return its value as a list of strings. + + The value is expected to be a comma-separated string. Whitespace around + each item is stripped, and empty items are ignored. + + Args: + env_name (str): name of the environment variable to read. + + Returns: + list[str]: list of non-empty strings, or None if the variable is not set or empty. + """ + value = os.environ.get(env_name) + if value is None or value.strip() == "": + return None + return [item.strip() for item in value.split(",") if item.strip()] + + def _merge_key_recursive(target_dict, source_dict, key): if key not in target_dict: target_dict[key] = source_dict[key] @@ -155,6 +198,7 @@ def __new__(cls, *args, **kwargs) -> "Config": "http_permanent_session_lifetime": datetime.timedelta(days=31), "username": "mindsdb", "password": "", + "token": None, # MINDSDB_AUTH_TOKEN }, "logging": { "handlers": { @@ -199,6 +243,26 @@ def __new__(cls, *args, **kwargs) -> "Config": "host": "0.0.0.0", # API server binds to all interfaces by default "port": "8000", }, + "mcp": { + "cors": { + "enabled": True, + "allow_origins": [], + "allow_origin_regex": r"https?://(localhost|127\.0\.0\.1)(:\d+)?", + "allow_headers": ["*"], + }, + "rate_limit": { + "enabled": False, + "requests_per_minute": 60, + }, + "oauth": { + "enabled": False, # MINDSDB_MCP_OAUTH_ENABLED + "issuer_url": "", # MINDSDB_MCP_OAUTH_ISSUER_URL + "client_id": "", # MINDSDB_MCP_OAUTH_CLIENT_ID + "client_secret": "", # MINDSDB_MCP_OAUTH_CLIENT_SECRET + "scope": "mcp:tools", # MINDSDB_MCP_OAUTH_SCOPE + }, + "dns_rebinding_protection": False, # MINDSDB_MCP_DNS_REBINDING_PROTECTION + }, }, "cache": {"type": "local"}, "ml_task_queue": {"type": "local"}, @@ -249,7 +313,10 @@ def prepare_env_config(self) -> None: """Collect config values from env vars to self._env_config""" self._env_config = { "logging": {"handlers": {"console": {}, "file": {}}}, - "api": {"http": {}}, + "api": { + "http": {}, + "mcp": {"cors": {}, "rate_limit": {}, "oauth": {}}, + }, "auth": {}, "paths": {}, "permanent_storage": {}, @@ -315,6 +382,10 @@ def prepare_env_config(self) -> None: elif http_auth_type != "": raise ValueError(f"Wrong value of env var MINDSDB_HTTP_AUTH_TYPE={http_auth_type}") + mindsdb_auth_token = os.environ.get("MINDSDB_AUTH_TOKEN", "") + if mindsdb_auth_token != "": + self._env_config["auth"]["token"] = mindsdb_auth_token + # region logging if os.environ.get("MINDSDB_LOG_LEVEL", "") != "": self._env_config["logging"]["handlers"]["console"]["level"] = os.environ["MINDSDB_LOG_LEVEL"] @@ -401,20 +472,16 @@ def prepare_env_config(self) -> None: if "default_reranking_model" not in self._env_config: self._env_config["default_reranking_model"] = {} self._env_config["default_reranking_model"].update(reranker_config) - if os.environ.get("MINDSDB_DATA_CATALOG_ENABLED", "").lower() in ("1", "true"): + if get_bool_env_var("MINDSDB_DATA_CATALOG_ENABLED") is True: self._env_config["data_catalog"] = {"enabled": True} - if os.environ.get("MINDSDB_NO_STUDIO", "").lower() in ("1", "true"): + if get_bool_env_var("MINDSDB_NO_STUDIO") is True: self._env_config["gui"]["open_on_start"] = False self._env_config["gui"]["autoupdate"] = False - mindsdb_gui_autoupdate = os.environ.get("MINDSDB_GUI_AUTOUPDATE", "").lower() - if mindsdb_gui_autoupdate in ("0", "false"): - self._env_config["gui"]["autoupdate"] = False - elif mindsdb_gui_autoupdate in ("1", "true"): - self._env_config["gui"]["autoupdate"] = True - elif mindsdb_gui_autoupdate != "": - raise ValueError(f"Wrong value of env var MINDSDB_GUI_AUTOUPDATE={mindsdb_gui_autoupdate}") + mindsdb_gui_autoupdate = get_bool_env_var("MINDSDB_GUI_AUTOUPDATE") + if mindsdb_gui_autoupdate is not None: + self._env_config["gui"]["autoupdate"] = mindsdb_gui_autoupdate if os.environ.get("MINDSDB_PID_FILE_CONTENT", "") != "": try: @@ -430,6 +497,46 @@ def prepare_env_config(self) -> None: elif mindsdb_byom_enabled != "": raise ValueError(f"Wrong value of env var MINDSDB_BYOM_ENABLED={mindsdb_byom_enabled}") + # region MCP config + mindsdb_mcp_enabled = get_bool_env_var("MINDSDB_MCP_CORS_ENABLED") + if mindsdb_mcp_enabled is not None: + self._env_config["api"]["mcp"]["cors"]["enabled"] = mindsdb_mcp_enabled + mindsdb_mcp_allow_origins = get_list_env_var("MINDSDB_MCP_ALLOW_ORIGINS") + if isinstance(mindsdb_mcp_allow_origins, list): + self._env_config["api"]["mcp"]["cors"]["allow_origins"] = mindsdb_mcp_allow_origins + mindsdb_mcp_allow_headers = get_list_env_var("MINDSDB_MCP_ALLOW_HEADERS") + if isinstance(mindsdb_mcp_allow_headers, list): + self._env_config["api"]["mcp"]["cors"]["allow_headers"] = mindsdb_mcp_allow_headers + mindsdb_mcp_allow_origin_regex = os.environ.get("MINDSDB_MCP_ALLOW_ORIGIN_REGEXP", "") + if mindsdb_mcp_allow_origin_regex != "": + self._env_config["api"]["mcp"]["cors"]["allow_origin_regex"] = mindsdb_mcp_allow_origin_regex + mindsdb_mcp_rate_limit_enabled = get_bool_env_var("MINDSDB_MCP_RATE_LIMIT_ENABLED") + if mindsdb_mcp_rate_limit_enabled is not None: + self._env_config["api"]["mcp"]["rate_limit"]["enabled"] = mindsdb_mcp_rate_limit_enabled + mindsdb_mcp_rate_limit_rpm = os.environ.get("MINDSDB_MCP_RATE_LIMIT_RPM", "") + if mindsdb_mcp_rate_limit_rpm != "": + self._env_config["api"]["mcp"]["rate_limit"]["requests_per_minute"] = int(mindsdb_mcp_rate_limit_rpm) + + mindsdb_mcp_oauth_enabled = get_bool_env_var("MINDSDB_MCP_OAUTH_ENABLED") + if mindsdb_mcp_oauth_enabled is not None: + self._env_config["api"]["mcp"]["oauth"]["enabled"] = mindsdb_mcp_oauth_enabled + mindsdb_mcp_oauth_issuer_url = os.environ.get("MINDSDB_MCP_OAUTH_ISSUER_URL", "") + if mindsdb_mcp_oauth_issuer_url != "": + self._env_config["api"]["mcp"]["oauth"]["issuer_url"] = mindsdb_mcp_oauth_issuer_url + mindsdb_mcp_oauth_client_id = os.environ.get("MINDSDB_MCP_OAUTH_CLIENT_ID", "") + if mindsdb_mcp_oauth_client_id != "": + self._env_config["api"]["mcp"]["oauth"]["client_id"] = mindsdb_mcp_oauth_client_id + mindsdb_mcp_oauth_client_secret = os.environ.get("MINDSDB_MCP_OAUTH_CLIENT_SECRET", "") + if mindsdb_mcp_oauth_client_secret != "": + self._env_config["api"]["mcp"]["oauth"]["client_secret"] = mindsdb_mcp_oauth_client_secret + mindsdb_mcp_oauth_scope = os.environ.get("MINDSDB_MCP_OAUTH_SCOPE", "") + if mindsdb_mcp_oauth_scope != "": + self._env_config["api"]["mcp"]["oauth"]["scope"] = mindsdb_mcp_oauth_scope + mindsdb_mcp_dns_rebinding_protection = get_bool_env_var("MINDSDB_MCP_DNS_REBINDING_PROTECTION") + if mindsdb_mcp_dns_rebinding_protection is not None: + self._env_config["api"]["mcp"]["dns_rebinding_protection"] = mindsdb_mcp_dns_rebinding_protection + # endregion + def fetch_auto_config(self) -> bool: """Load dict readed from config.auto.json to `auto_config`. Do it only if `auto_config` was not loaded before or config.auto.json been changed. @@ -592,6 +699,7 @@ def parse_cmd_args(self) -> None: agent=None, project=None, update_gui=False, + mcp_stdio=False, ) return @@ -618,6 +726,7 @@ def parse_cmd_args(self) -> None: parser.add_argument("--project-name", type=str, default=None, help="MindsDB project name") parser.add_argument("--update-gui", action="store_true", default=False, help="Update GUI and exit") + parser.add_argument("--mcp-stdio", action="store_true", default=False, help="Run MCP with STDIO transport") self._cmd_args = parser.parse_args() diff --git a/mindsdb/utilities/log.py b/mindsdb/utilities/log.py index 8c76ad9d4ea..eb174f54c2a 100644 --- a/mindsdb/utilities/log.py +++ b/mindsdb/utilities/log.py @@ -205,6 +205,7 @@ def get_handlers_config(process_name: str) -> dict: "class": "mindsdb.utilities.log.StreamSanitizingHandler", "formatter": console_handler_config.get("formatter", "default"), "level": console_handler_config_level, + "stream": console_handler_config.get("stream", "ext://sys.stderr"), } file_handler_config = app_config["logging"]["handlers"]["file"] diff --git a/requirements/requirements-agents.txt b/requirements/requirements-agents.txt index 233cdc35198..dbf6acc1096 100644 --- a/requirements/requirements-agents.txt +++ b/requirements/requirements-agents.txt @@ -5,7 +5,7 @@ transformers >= 4.42.4 # Required for KB mindsdb-evaluator == 0.0.21 -mcp~=1.10.1 # Required for MCP server +mcp~=1.26.0 # Required for MCP server # A2A requirements httpx==0.28.1 diff --git a/tests/unit/api/mcp/__init__.py b/tests/unit/api/mcp/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/api/mcp/test_completions.py b/tests/unit/api/mcp/test_completions.py new file mode 100644 index 00000000000..ab03ecff049 --- /dev/null +++ b/tests/unit/api/mcp/test_completions.py @@ -0,0 +1,135 @@ +""" +Unit tests for the MCP completion handler (mindsdb/api/mcp/completions.py). +""" + +import asyncio +from unittest.mock import MagicMock, patch + +from mcp.types import PromptReference, ResourceTemplateReference +from mcp.shared.memory import create_connected_server_and_client_session + +from mindsdb.api.mcp.mcp_instance import mcp + +# --------------------------------------------------------------------------- +# Patch targets +# --------------------------------------------------------------------------- + +_PATCH_GET_DB_NAMES = "mindsdb.api.mcp.completions._get_database_names" +_PATCH_CTX = "mindsdb.api.mcp.completions.ctx" +_PATCH_SESSION = "mindsdb.api.mcp.completions.SessionController" + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run(coro): + return asyncio.run(coro) + + +def _complete(ref, argument: dict, context_arguments: dict | None = None) -> list[str]: + """Run a completion request and return the list of completion values.""" + + async def _inner(): + async with create_connected_server_and_client_session(mcp) as client: + result = await client.complete( + ref=ref, + argument=argument, + context_arguments=context_arguments, + ) + return result.completion.values + + return _run(_inner()) + + +_PROMPT_REF = PromptReference(type="ref/prompt", name="sample_table") +_RESOURCE_REF = ResourceTemplateReference( + type="ref/resource", + uri="schema://databases/{database_name}/tables", +) + + +def _make_table_mock(name: str) -> MagicMock: + t = MagicMock() + t.TABLE_NAME = name + return t + + +class TestDatabaseNameCompletion: + def test_returns_matching_databases(self): + with patch(_PATCH_GET_DB_NAMES, return_value=["pg_prod", "pg_staging", "mysql_db"]): + values = _complete(_PROMPT_REF, {"name": "database_name", "value": "pg"}) + + assert values == ["pg_prod", "pg_staging"] + + def test_prefix_filters_case_sensitively(self): + with patch(_PATCH_GET_DB_NAMES, return_value=["Postgres", "postgres"]): + values = _complete(_PROMPT_REF, {"name": "database_name", "value": "post"}) + + assert values == ["postgres"] + + def test_empty_prefix_returns_all_databases(self): + db_names = ["pg", "mysql", "mongo"] + with patch(_PATCH_GET_DB_NAMES, return_value=db_names): + values = _complete(_PROMPT_REF, {"name": "database_name", "value": ""}) + + assert values == db_names + + def test_no_match_returns_empty_list(self): + with patch(_PATCH_GET_DB_NAMES, return_value=["pg", "mysql"]): + values = _complete(_PROMPT_REF, {"name": "database_name", "value": "oracle"}) + + assert values == [] + + +class TestTableNameCompletion: + def test_returns_matching_tables(self): + with patch(_PATCH_SESSION) as SC: + SC.return_value.datahub.get.return_value.get_tables.return_value = [ + _make_table_mock("orders"), + _make_table_mock("order_items"), + _make_table_mock("users"), + ] + + # match 2/3 + values = _complete( + _RESOURCE_REF, + {"name": "table_name", "value": "ord"}, + context_arguments={"database_name": "pg"}, + ) + + SC.return_value.datahub.get.assert_called_with("pg") + assert values == ["orders", "order_items"] + + # match all + values = _complete( + _RESOURCE_REF, + {"name": "table_name", "value": ""}, + context_arguments={"database_name": "pg"}, + ) + + assert values == ["orders", "order_items", "users"] + + # match 0 + values = _complete( + _RESOURCE_REF, + {"name": "table_name", "value": "qwerty"}, + context_arguments={"database_name": "pg"}, + ) + + assert values == [] + + def test_missing_database_name_context_returns_empty(self): + """When database_name is not in context_arguments, return empty.""" + with patch(_PATCH_SESSION): + values = _complete( + _RESOURCE_REF, + {"name": "table_name", "value": "ord"}, + context_arguments=None, + ) + + assert values == [] + + def test_unknown_argument_name_returns_empty(self): + values = _complete(_PROMPT_REF, {"name": "unknown_param", "value": "foo"}) + assert values == [] diff --git a/tests/unit/api/mcp/test_prompts.py b/tests/unit/api/mcp/test_prompts.py new file mode 100644 index 00000000000..2e7ea7b5d60 --- /dev/null +++ b/tests/unit/api/mcp/test_prompts.py @@ -0,0 +1,45 @@ +""" +Unit tests for MCP prompts (mindsdb/api/mcp/prompts/*). + +mcp.get_prompt() is async; tests run it with asyncio.run(). +""" + +import json +import asyncio + +from mindsdb.api.mcp.mcp_instance import mcp + + +def _run(coro): + return asyncio.run(coro) + + +def _get_sample_table_prompt(database_name: str, table_name: str): + """Call sample_table prompt and return the GetPromptResult.""" + return _run(mcp.get_prompt("sample_table", {"database_name": database_name, "table_name": table_name})) + + +def _get_first_message_text(prompt: object) -> str: + """Return the text content of the first message.""" + raw = prompt.messages[0].content.text + # FastMCP serialises the TextContent to JSON inside the PromptMessage + return json.loads(raw)["text"] + + +class TestPrompt: + def test_sample_table_exists(self): + # sample_table exists and has description + prompts = _run(mcp.list_prompts()) + prompt = next(p for p in prompts if p.name == "sample_table") + assert prompt.description # non-empty + + def test_sample_table_content(self): + # test content of the prompt + result = _get_sample_table_prompt("MyDB", "mytable") + assert len(result.messages) == 1 + assert result.messages[0].role == "user" + assert result.messages[0].content.type == "text" + + text = _get_first_message_text(result) + assert "`MyDB`.`mytable`" in text + assert "limit 5" in text.lower() diff --git a/tests/unit/api/mcp/test_query_tool.py b/tests/unit/api/mcp/test_query_tool.py new file mode 100644 index 00000000000..bd4d0bcd430 --- /dev/null +++ b/tests/unit/api/mcp/test_query_tool.py @@ -0,0 +1,129 @@ +""" +Unit tests for the MCP tools (mindsdb/api/mcp/tools/*). +""" + +import asyncio +import json +from unittest.mock import patch + + +_PATCH_PROXY = "mindsdb.api.mcp.tools.query.FakeMysqlProxy" + + +def _run(coro): + """Run an async coroutine synchronously.""" + return asyncio.run(coro) + + +def _call_tool(sql: str, context=None): + """Call the MCP query tool synchronously and return parsed JSON.""" + args = {"query": sql} + if context is not None: + args["context"] = context + + from mindsdb.api.mcp.mcp_instance import mcp + + content, _ = _run(mcp.call_tool("query", args)) + return json.loads(content[0].text) + + +def _make_proxy_ok(mock_proxy_cls, affected_rows=0): + """Configure mock proxy to return an OK response.""" + mock_proxy_cls.return_value.process_query.return_value.dump_http_response.return_value = { + "type": "ok", + "affected_rows": affected_rows, + } + return mock_proxy_cls.return_value + + +def _make_proxy_table(mock_proxy_cls, column_names, data): + """Configure mock proxy to return a table response.""" + mock_proxy_cls.return_value.process_query.return_value.dump_http_response.return_value = { + "type": "table", + "column_names": column_names, + "data": data, + } + return mock_proxy_cls.return_value + + +def _make_proxy_error(mock_proxy_cls, error_message, error_code=0): + """Configure mock proxy to return an error response.""" + mock_proxy_cls.return_value.process_query.return_value.dump_http_response.return_value = { + "type": "error", + "error_code": error_code, + "error_message": error_message, + } + return mock_proxy_cls.return_value + + +class TestResponseTypes: + def test_select_returns_table_type(self): + expected_data = [[1, "alice"], [2, "bob"]] + columns_list = ["id", "name"] + with patch(_PATCH_PROXY) as MockProxy: + _make_proxy_table(MockProxy, columns_list, expected_data) + result = _call_tool("SELECT * FROM mydb.users") + + assert result["type"] == "table" + assert result["column_names"] == columns_list + assert result["data"] == expected_data + + def test_select_empty_result(self): + columns_list = ["id", "name"] + with patch(_PATCH_PROXY) as MockProxy: + _make_proxy_table(MockProxy, columns_list, []) + result = _call_tool("SELECT * FROM mydb.users WHERE 1=0") + + assert result["type"] == "table" + assert result["column_names"] == columns_list + assert result["data"] == [] + + def test_insert_returns_ok_type(self): + with patch(_PATCH_PROXY) as MockProxy: + _make_proxy_ok(MockProxy, affected_rows=1) + result = _call_tool("INSERT INTO mydb.t (id) VALUES (1)") + + assert result["type"] == "ok" + assert result["affected_rows"] == 1 + + def test_proxy_error_response_returns_error_type(self): + error_message = "Table 'x' doesn't exist" + with patch(_PATCH_PROXY) as MockProxy: + _make_proxy_error(MockProxy, error_message, error_code=123) + result = _call_tool("SELECT * FROM mydb.x") + + assert result["type"] == "error" + assert result["error_message"] == error_message + assert result["error_code"] == 123 + + def test_exception_in_process_query_returns_error_type(self): + error_message = "connection refused" + with patch(_PATCH_PROXY) as MockProxy: + MockProxy.return_value.process_query.side_effect = Exception(error_message) + result = _call_tool("SELECT 1") + + assert result["type"] == "error" + assert result["error_message"] == error_message + + +class TestContextParameter: + def test_context_is_passed_to_set_context(self): + with patch(_PATCH_PROXY) as MockProxy: + proxy = _make_proxy_ok(MockProxy) + _call_tool("SELECT 1", context={"db": "my_postgres"}) + + proxy.set_context.assert_called_once_with({"db": "my_postgres"}) + + def test_omitted_context_defaults_to_empty_dict(self): + with patch(_PATCH_PROXY) as MockProxy: + proxy = _make_proxy_ok(MockProxy) + _call_tool("SELECT 1") # no context argument + + proxy.set_context.assert_called_once_with({}) + + def test_explicit_none_context_defaults_to_empty_dict(self): + with patch(_PATCH_PROXY) as MockProxy: + proxy = _make_proxy_ok(MockProxy) + _call_tool("SELECT 1", context=None) + + proxy.set_context.assert_called_once_with({}) diff --git a/tests/unit/api/mcp/test_resources.py b/tests/unit/api/mcp/test_resources.py new file mode 100644 index 00000000000..6bac3891875 --- /dev/null +++ b/tests/unit/api/mcp/test_resources.py @@ -0,0 +1,177 @@ +""" +Unit tests for MCP resources (mindsdb/api/mcp/resources/*) +""" + +import asyncio +import json +from unittest.mock import MagicMock, patch + +import pandas as pd + +from mindsdb.integrations.libs.response import TableResponse as HandlerTableResponse +from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE + + +_PATCH_SESSION = "mindsdb.api.mcp.resources.schema.SessionController" +_PATCH_TABLE_RESPONSE = "mindsdb.api.mcp.resources.schema.TableResponse" +_PATCH_RESPONSE_TYPE = "mindsdb.api.mcp.resources.schema.RESPONSE_TYPE" + + +def _run(coro): + return asyncio.run(coro) + + +def _read(uri: str) -> list: + """Read a resource and return parsed JSON payload.""" + from mindsdb.api.mcp.mcp_instance import mcp + + contents = list(_run(mcp.read_resource(uri))) + return json.loads(contents[0].content) + + +def _make_table_mock(name: str, table_type: str = "BASE TABLE", schema: str = "public") -> MagicMock: + t = MagicMock() + t.TABLE_NAME = name + t.TABLE_TYPE = table_type + t.TABLE_SCHEMA = schema + return t + + +def _make_columns_table_response(rows: list[dict]) -> MagicMock: + """Build a mock HandlerTableResponse with COLUMNS_TABLE type.""" + tr = MagicMock(spec=HandlerTableResponse) + tr.type = RESPONSE_TYPE.COLUMNS_TABLE + tr.fetchall.return_value = pd.DataFrame(rows) + return tr + + +def _make_kb(name, project, metadata_cols=None, content_cols=None, id_col="id"): + return { + "name": name, + "project": project, + "metadata_columns": metadata_cols or [], + "content_columns": content_cols or ["body"], + "id_column": id_col, + } + + +class TestListDatabases: + def test_returns_only_data_type_databases(self): + from mindsdb.api.mcp.mcp_instance import mcp + + with patch(_PATCH_SESSION) as SC: + SC.return_value.database_controller.get_list.return_value = [ + {"name": "pg_prod", "type": "data"}, + {"name": "mindsdb", "type": "project"}, + {"name": "mysql_db", "type": "data"}, + ] + + result = list(_run(mcp.read_resource("schema://databases"))) + + assert len(result) == 1 + assert json.loads(result[0].content) == ["pg_prod", "mysql_db"] + assert result[0].mime_type == "application/json" + + def test_filters_out_all_non_data_types(self): + with patch(_PATCH_SESSION) as SC: + SC.return_value.database_controller.get_list.return_value = [ + {"name": "mindsdb", "type": "project"}, + {"name": "files", "type": "files"}, + ] + result = _read("schema://databases") + + assert result == [] + + +class TestDbTables: + def test_returns_table_names(self): + with patch(_PATCH_SESSION) as SC: + SC.return_value.datahub.get.return_value.get_tables.return_value = [ + _make_table_mock("orders"), + _make_table_mock("users"), + ] + result = _read("schema://databases/mydb/tables") + + SC.return_value.datahub.get.assert_called_once_with("mydb") + + names = [t["TABLE_NAME"] for t in result] + assert names == ["orders", "users"] + assert set(result[0].keys()) == {"TABLE_NAME", "TABLE_TYPE", "TABLE_SCHEMA"} + + def test_returns_table_type_and_schema(self): + with patch(_PATCH_SESSION) as SC: + SC.return_value.datahub.get.return_value.get_tables.return_value = [ + _make_table_mock("orders", table_type="VIEW", schema="myschema"), + ] + result = _read("schema://databases/mydb/tables") + + assert result[0]["TABLE_TYPE"] == "VIEW" + assert result[0]["TABLE_SCHEMA"] == "myschema" + + def test_empty_database_returns_empty_list(self): + with patch(_PATCH_SESSION) as SC: + SC.return_value.datahub.get.return_value.get_tables.return_value = [] + result = _read("schema://databases/emptydb/tables") + + assert result == [] + + +class TestDbTableColumns: + def test_returns_column_names_and_types(self): + rows = [ + {"COLUMN_NAME": "id", "MYSQL_DATA_TYPE": "int"}, + {"COLUMN_NAME": "email", "MYSQL_DATA_TYPE": "varchar(255)"}, + ] + with ( + patch(_PATCH_SESSION) as SC, + patch(_PATCH_TABLE_RESPONSE, HandlerTableResponse), + patch(_PATCH_RESPONSE_TYPE, RESPONSE_TYPE), + ): + SC.return_value.integration_controller.get_data_handler.return_value.get_columns.return_value = ( + _make_columns_table_response(rows) + ) + + result = _read("schema://databases/mydb/tables/orders/columns") + SC.return_value.integration_controller.get_data_handler.assert_called_once_with("mydb") + SC.return_value.integration_controller.get_data_handler.return_value.get_columns.assert_called_once_with( + "orders" + ) + + assert result[0] == {"COLUMN_NAME": "id", "MYSQL_DATA_TYPE": "int"} + assert result[1] == {"COLUMN_NAME": "email", "MYSQL_DATA_TYPE": "varchar(255)"} + + +class TestListKnowledgeBases: + def test_returns_knowledge_bases_from_all_projects(self): + with patch(_PATCH_SESSION) as SC: + SC.return_value.datahub.get_projects_names.return_value = ["mindsdb", "my_project"] + SC.return_value.kb_controller.list.side_effect = [ + [_make_kb("kb1", "mindsdb")], + [_make_kb("kb2", "my_project")], + ] + result = _read("schema://knowledge_bases") + + assert len(result) == 2 + assert result[0]["name"] == "kb1" + assert result[1]["name"] == "kb2" + + def test_returns_correct_kb_fields(self): + kb = _make_kb( + "docs_kb", + "mindsdb", + metadata_cols=["source", "date"], + content_cols=["body"], + id_col="doc_id", + ) + with patch(_PATCH_SESSION) as SC: + SC.return_value.datahub.get_projects_names.return_value = ["mindsdb"] + SC.return_value.kb_controller.list.return_value = [kb] + result = _read("schema://knowledge_bases") + + assert result[0] == { + "name": "docs_kb", + "project": "mindsdb", + "metadata_columns": ["source", "date"], + "content_columns": ["body"], + "id_column": "doc_id", + } diff --git a/tests/unit/handlers/test_file.py b/tests/unit/handlers/test_file.py index 9df2ee28415..7c54c8cbbc7 100644 --- a/tests/unit/handlers/test_file.py +++ b/tests/unit/handlers/test_file.py @@ -17,8 +17,8 @@ ) from mindsdb.integrations.handlers.file_handler.file_handler import FileHandler -from mindsdb.integrations.libs.response import RESPONSE_TYPE - +from mindsdb.integrations.libs.response import RESPONSE_TYPE, INF_SCHEMA_COLUMNS_NAMES_SET, INF_SCHEMA_COLUMNS_NAMES +from mindsdb.api.mysql.mysql_proxy.libs.constants.mysql import MYSQL_DATA_TYPE from mindsdb.integrations.utilities.files.file_reader import ( FileReader, FileProcessingError, @@ -406,8 +406,25 @@ def test_get_columns(): file_handler = FileHandler(file_controller=MockFileController()) response = file_handler.get_columns("mock") - assert response.type == RESPONSE_TYPE.TABLE - - expected_df = pandas.DataFrame([{"Field": x, "Type": "str"} for x in file_records[0][2]]) - - assert response.data_frame.equals(expected_df) + assert response.type == RESPONSE_TYPE.COLUMNS_TABLE + + data = [] + for name in file_records[0][2]: + row = {} + for key_name in INF_SCHEMA_COLUMNS_NAMES_SET: + if key_name == INF_SCHEMA_COLUMNS_NAMES.COLUMN_NAME: + row[key_name] = name + elif key_name == INF_SCHEMA_COLUMNS_NAMES.DATA_TYPE: + row[key_name] = "str" + elif key_name == INF_SCHEMA_COLUMNS_NAMES.MYSQL_DATA_TYPE: + row[key_name] = MYSQL_DATA_TYPE.TEXT + else: + row[key_name] = None + data.append(row) + + expected_df = pandas.DataFrame(data) + assert set(response.data_frame.columns) == set(expected_df.columns) + expected_df = expected_df[response.data_frame.columns] + + # Use 'compare' to ignore dtypes (object != string) + assert response.data_frame.compare(expected_df).empty From 5bcf37b70684e5676b004653bdbc5469b3803e27 Mon Sep 17 00:00:00 2001 From: Gable Date: Wed, 25 Mar 2026 13:58:25 -0400 Subject: [PATCH 088/125] Added CLA signers from the CLa branch --- .../signatures/cla.json | 1048 +++++++++++++++++ 1 file changed, 1048 insertions(+) diff --git a/assets/contributions-agreement/signatures/cla.json b/assets/contributions-agreement/signatures/cla.json index dc0e2328551..b3dabe85607 100644 --- a/assets/contributions-agreement/signatures/cla.json +++ b/assets/contributions-agreement/signatures/cla.json @@ -5831,6 +5831,1054 @@ "created_at": "2023-10-30T12:46:04Z", "repoId": 143328315, "pullRequestNo": 8163 + }, + { + "name": "minakshisharma197", + "id": 184736207, + "comment_id": 2413433683, + "created_at": "2024-10-15T09:55:40Z", + "repoId": 143328315, + "pullRequestNo": 9865 + }, + { + "name": "divyakhatiyan", + "id": 141419850, + "comment_id": 2417330560, + "created_at": "2024-10-16T16:28:34Z", + "repoId": 143328315, + "pullRequestNo": 9899 + }, + { + "name": "Sekhar-Kumar-Dash", + "id": 119131588, + "comment_id": 2419495274, + "created_at": "2024-10-17T13:05:15Z", + "repoId": 143328315, + "pullRequestNo": 9914 + }, + { + "name": "kom-senapati", + "id": 92045934, + "comment_id": 2423485137, + "created_at": "2024-10-19T02:28:49Z", + "repoId": 143328315, + "pullRequestNo": 9807 + }, + { + "name": "RiyanaD", + "id": 117534139, + "comment_id": 2420766574, + "created_at": "2024-10-17T22:54:53Z", + "repoId": 143328315, + "pullRequestNo": 9427 + }, + { + "name": "narengogi", + "id": 47327611, + "comment_id": 2296396377, + "created_at": "2024-08-19T11:55:54Z", + "repoId": 143328315, + "pullRequestNo": 9641 + }, + { + "name": "PatLittle", + "id": 31454591, + "comment_id": 2425743649, + "created_at": "2024-10-21T06:49:47Z", + "repoId": 143328315, + "pullRequestNo": 9962 + }, + { + "name": "panoskyriakis", + "id": 134383572, + "comment_id": 2317914456, + "created_at": "2024-08-29T14:39:56Z", + "repoId": 143328315, + "pullRequestNo": 9654 + }, + { + "name": "lucas-koontz", + "id": 7515210, + "comment_id": 2428585608, + "created_at": "2024-10-22T08:19:54Z", + "repoId": 143328315, + "pullRequestNo": 9976 + }, + { + "name": "Tryxns", + "id": 10586708, + "comment_id": 2433530462, + "created_at": "2024-10-23T21:51:00Z", + "repoId": 143328315, + "pullRequestNo": 9975 + }, + { + "name": "DhanushNehru", + "id": 22955675, + "comment_id": 2438155935, + "created_at": "2024-10-25T15:40:09Z", + "repoId": 143328315, + "pullRequestNo": 10047 + }, + { + "name": "TalaatHasanin", + "id": 105648065, + "comment_id": 2439488990, + "created_at": "2024-10-26T10:54:04Z", + "repoId": 143328315, + "pullRequestNo": 9726 + }, + { + "name": "AkashJana18", + "id": 103350981, + "comment_id": 2442254462, + "created_at": "2024-10-28T17:52:47Z", + "repoId": 143328315, + "pullRequestNo": 10073 + }, + { + "name": "prajwal-pai77", + "id": 108796209, + "comment_id": 2445980761, + "created_at": "2024-10-30T06:33:47Z", + "repoId": 143328315, + "pullRequestNo": 10039 + }, + { + "name": "JanumalaAkhilendra", + "id": 82641474, + "comment_id": 2446791257, + "created_at": "2024-10-30T11:43:16Z", + "repoId": 143328315, + "pullRequestNo": 10051 + }, + { + "name": "herjanice", + "id": 72483795, + "comment_id": 2370891577, + "created_at": "2024-09-24T10:33:26Z", + "repoId": 143328315, + "pullRequestNo": 9727 + }, + { + "name": "mabderrahim", + "id": 20402768, + "comment_id": 2377340466, + "created_at": "2024-09-26T15:48:00Z", + "repoId": 143328315, + "pullRequestNo": 9727 + }, + { + "name": "mohamed-abderrahim3", + "id": 183199390, + "comment_id": 2380593605, + "created_at": "2024-09-28T10:19:38Z", + "repoId": 143328315, + "pullRequestNo": 9727 + }, + { + "name": "chuangyeshuo", + "id": 14370480, + "comment_id": 2449017804, + "created_at": "2024-10-31T05:05:38Z", + "repoId": 143328315, + "pullRequestNo": 10099 + }, + { + "name": "md-abid-hussain", + "id": 101964499, + "comment_id": 2449303679, + "created_at": "2024-10-31T08:26:54Z", + "repoId": 143328315, + "pullRequestNo": 10100 + }, + { + "name": "poisonvine", + "id": 179939949, + "comment_id": 2408223847, + "created_at": "2024-10-11T23:08:39Z", + "repoId": 143328315, + "pullRequestNo": 9833 + }, + { + "name": "code-vine", + "id": 95056519, + "comment_id": 2408235943, + "created_at": "2024-10-11T23:31:03Z", + "repoId": 143328315, + "pullRequestNo": 9833 + }, + { + "name": "poisonvine", + "id": 179939949, + "comment_id": 2463687190, + "created_at": "2024-11-08T03:30:34Z", + "repoId": 143328315, + "pullRequestNo": 9833 + }, + { + "name": "vishwamartur", + "id": 64204611, + "comment_id": 2480506920, + "created_at": "2024-11-16T10:24:05Z", + "repoId": 143328315, + "pullRequestNo": 10176 + }, + { + "name": "UTSAVS26", + "id": 119779889, + "comment_id": 2482548112, + "created_at": "2024-11-18T10:15:35Z", + "repoId": 143328315, + "pullRequestNo": 10182 + }, + { + "name": "fshabashev", + "id": 6548211, + "comment_id": 2482924022, + "created_at": "2024-11-18T12:36:59Z", + "repoId": 143328315, + "pullRequestNo": 10153 + }, + { + "name": "GTgyani206", + "id": 128274569, + "comment_id": 2407637789, + "created_at": "2024-10-11T15:20:07Z", + "repoId": 143328315, + "pullRequestNo": 9832 + }, + { + "name": "QuantumPlumber", + "id": 44450703, + "comment_id": 2521508302, + "created_at": "2024-12-05T21:39:15Z", + "repoId": 143328315, + "pullRequestNo": 10243 + }, + { + "name": "Abdusshh", + "id": 101020733, + "comment_id": 2525127867, + "created_at": "2024-12-07T13:40:48Z", + "repoId": 143328315, + "pullRequestNo": 10253 + }, + { + "name": "cliffordp", + "id": 1812179, + "comment_id": 2540449382, + "created_at": "2024-12-13T03:21:48Z", + "repoId": 143328315, + "pullRequestNo": 10285 + }, + { + "name": "abhirajadhikary06", + "id": 171187625, + "comment_id": 2563775672, + "created_at": "2024-12-27T14:55:52Z", + "repoId": 143328315, + "pullRequestNo": 10331 + }, + { + "name": "jbrass", + "id": 125982, + "comment_id": 2587312474, + "created_at": "2025-01-13T14:50:21Z", + "repoId": 143328315, + "pullRequestNo": 10355 + }, + { + "name": "dj013", + "id": 47425755, + "comment_id": 2593267189, + "created_at": "2025-01-15T15:43:10Z", + "repoId": 143328315, + "pullRequestNo": 10371 + }, + { + "name": "juliette0704", + "id": 91728573, + "comment_id": 2609377887, + "created_at": "2025-01-23T10:01:31Z", + "repoId": 143328315, + "pullRequestNo": 10395 + }, + { + "name": "ivancastanop", + "id": 107499323, + "comment_id": 2598203208, + "created_at": "2025-01-17T11:55:12Z", + "repoId": 143328315, + "pullRequestNo": 10379 + }, + { + "name": "rdonato", + "id": 128521, + "comment_id": 2643683251, + "created_at": "2025-02-07T18:22:51Z", + "repoId": 143328315, + "pullRequestNo": 10444 + }, + { + "name": "SoNiC-HeRE", + "id": 96797205, + "comment_id": 2654003700, + "created_at": "2025-02-12T15:10:05Z", + "repoId": 143328315, + "pullRequestNo": 10460 + }, + { + "name": "guspan-tanadi", + "id": 36249910, + "comment_id": 2675814807, + "created_at": "2025-02-21T23:28:45Z", + "repoId": 143328315, + "pullRequestNo": 10465 + }, + { + "name": "arashaomrani", + "id": 20032520, + "comment_id": 2705110135, + "created_at": "2025-03-06T22:46:52Z", + "repoId": 143328315, + "pullRequestNo": 10544 + }, + { + "name": "kevinrawal", + "id": 84058124, + "comment_id": 2708288010, + "created_at": "2025-03-08T13:33:56Z", + "repoId": 143328315, + "pullRequestNo": 10550 + }, + { + "name": "MR901", + "id": 20877166, + "comment_id": 2788354723, + "created_at": "2025-04-09T05:54:32Z", + "repoId": 143328315, + "pullRequestNo": 10681 + }, + { + "name": "pnewsam", + "id": 22651415, + "comment_id": 2813745881, + "created_at": "2025-04-17T18:35:55Z", + "repoId": 143328315, + "pullRequestNo": 10736 + }, + { + "name": "emmanuel-ferdman", + "id": 35470921, + "comment_id": 2816053850, + "created_at": "2025-04-18T19:17:39Z", + "repoId": 143328315, + "pullRequestNo": 10739 + }, + { + "name": "Konstantinos-10", + "id": 161840728, + "comment_id": 2833463268, + "created_at": "2025-04-27T13:35:09Z", + "repoId": 143328315, + "pullRequestNo": 10761 + }, + { + "name": "NikosLaspias", + "id": 148558723, + "comment_id": 2834255670, + "created_at": "2025-04-28T07:38:11Z", + "repoId": 143328315, + "pullRequestNo": 10760 + }, + { + "name": "jzs1997", + "id": 29564670, + "comment_id": 2840686847, + "created_at": "2025-04-30T03:07:12Z", + "repoId": 143328315, + "pullRequestNo": 10776 + }, + { + "name": "HarshaVardhanMannem", + "id": 144146034, + "comment_id": 2896453670, + "created_at": "2025-05-21T03:28:49Z", + "repoId": 143328315, + "pullRequestNo": 10861 + }, + { + "name": "arun-prasath2005", + "id": 84761066, + "comment_id": 2906488930, + "created_at": "2025-05-24T06:10:22Z", + "repoId": 143328315, + "pullRequestNo": 10882 + }, + { + "name": "vmanikanta07", + "id": 117996904, + "comment_id": 2906811274, + "created_at": "2025-05-24T12:37:42Z", + "repoId": 143328315, + "pullRequestNo": 10885 + }, + { + "name": "omerc7", + "id": 32813109, + "comment_id": 2908711653, + "created_at": "2025-05-26T06:34:46Z", + "repoId": 143328315, + "pullRequestNo": 10895 + }, + { + "name": "trickster026", + "id": 212937700, + "comment_id": 2910591816, + "created_at": "2025-05-26T20:34:08Z", + "repoId": 143328315, + "pullRequestNo": 10903 + }, + { + "name": "ivanvza", + "id": 8543825, + "comment_id": 2911844022, + "created_at": "2025-05-27T09:31:36Z", + "repoId": 143328315, + "pullRequestNo": 10900 + }, + { + "name": "Joystonm", + "id": 116254639, + "comment_id": 2965183033, + "created_at": "2025-06-12T05:37:40Z", + "repoId": 143328315, + "pullRequestNo": 11070 + }, + { + "name": "noname4life", + "id": 77653287, + "comment_id": 2983573198, + "created_at": "2025-06-18T10:07:09Z", + "repoId": 143328315, + "pullRequestNo": 11117 + }, + { + "name": "D1m7asis", + "id": 80602676, + "comment_id": 2985345244, + "created_at": "2025-06-18T18:42:15Z", + "repoId": 143328315, + "pullRequestNo": 11124 + }, + { + "name": "Alex-xd", + "id": 11256006, + "comment_id": 2999207900, + "created_at": "2025-06-24T07:50:23Z", + "repoId": 143328315, + "pullRequestNo": 11160 + }, + { + "name": "PriyanshuPz", + "id": 112266318, + "comment_id": 3000590454, + "created_at": "2025-06-24T13:51:44Z", + "repoId": 143328315, + "pullRequestNo": 11163 + }, + { + "name": "rawathemant246", + "id": 99639231, + "comment_id": 2999067598, + "created_at": "2025-06-24T06:59:35Z", + "repoId": 143328315, + "pullRequestNo": 11159 + }, + { + "name": "aryanmalik-iet", + "id": 187411120, + "comment_id": 3007270696, + "created_at": "2025-06-26T06:24:35Z", + "repoId": 143328315, + "pullRequestNo": 11186 + }, + { + "name": "iabhi4", + "id": 61010675, + "comment_id": 3017197726, + "created_at": "2025-06-29T22:22:14Z", + "repoId": 143328315, + "pullRequestNo": 11212 + }, + { + "name": "dotWee", + "id": 8060356, + "comment_id": 3072932250, + "created_at": "2025-07-15T09:45:05Z", + "repoId": 143328315, + "pullRequestNo": 11300 + }, + { + "name": "buallen", + "id": 54055907, + "comment_id": 3078683990, + "created_at": "2025-07-16T13:40:32Z", + "repoId": 143328315, + "pullRequestNo": 11234 + }, + { + "name": "Raahim-Lone", + "id": 175012415, + "comment_id": 3120439531, + "created_at": "2025-07-25T21:35:44Z", + "repoId": 143328315, + "pullRequestNo": 11365 + }, + { + "name": "kaizenjinco", + "id": 78314961, + "comment_id": 3124537097, + "created_at": "2025-07-27T16:53:30Z", + "repoId": 143328315, + "pullRequestNo": 11367 + }, + { + "name": "huang-x-h", + "id": 381860, + "comment_id": 3132498852, + "created_at": "2025-07-29T13:16:29Z", + "repoId": 143328315, + "pullRequestNo": 11126 + }, + { + "name": "aperepel", + "id": 119367, + "comment_id": 3137657308, + "created_at": "2025-07-30T20:03:35Z", + "repoId": 143328315, + "pullRequestNo": 11385 + }, + { + "name": "abhayasr", + "id": 108477628, + "comment_id": 3164476409, + "created_at": "2025-08-07T14:39:49Z", + "repoId": 143328315, + "pullRequestNo": 11291 + }, + { + "name": "logan-mo", + "id": 63550599, + "comment_id": 3167373652, + "created_at": "2025-08-08T10:27:53Z", + "repoId": 143328315, + "pullRequestNo": 11414 + }, + { + "name": "kylediaz", + "id": 35979917, + "comment_id": 3180690963, + "created_at": "2025-08-12T19:21:02Z", + "repoId": 143328315, + "pullRequestNo": 11427 + }, + { + "name": "Kenxpx", + "id": 155082290, + "comment_id": 3194287003, + "created_at": "2025-08-17T10:15:06Z", + "repoId": 143328315, + "pullRequestNo": 11450 + }, + { + "name": "Nancy9ice", + "id": 103530451, + "comment_id": 3197557060, + "created_at": "2025-08-18T16:11:20Z", + "repoId": 143328315, + "pullRequestNo": 11453 + }, + { + "name": "Matvey-Kuk", + "id": 3284841, + "comment_id": 3197947416, + "created_at": "2025-08-18T18:18:26Z", + "repoId": 143328315, + "pullRequestNo": 11452 + }, + { + "name": "louisneal", + "id": 47094728, + "comment_id": 3222541351, + "created_at": "2025-08-26T04:06:55Z", + "repoId": 143328315, + "pullRequestNo": 11478 + }, + { + "name": "sejubar", + "id": 154475559, + "comment_id": 3240009269, + "created_at": "2025-08-31T09:59:19Z", + "repoId": 143328315, + "pullRequestNo": 11495 + }, + { + "name": "sudsmenon", + "id": 11342520, + "comment_id": 3250743797, + "created_at": "2025-09-03T20:48:18Z", + "repoId": 143328315, + "pullRequestNo": 11510 + }, + { + "name": "TaniyaKatigar", + "id": 214086943, + "comment_id": 3262560837, + "created_at": "2025-09-06T16:30:40Z", + "repoId": 143328315, + "pullRequestNo": 11530 + }, + { + "name": "GeorgeGithiri5", + "id": 46107866, + "comment_id": 3269367783, + "created_at": "2025-09-09T07:49:06Z", + "repoId": 143328315, + "pullRequestNo": 11541 + }, + { + "name": "gauiPPP", + "id": 43440362, + "comment_id": 3284159007, + "created_at": "2025-09-12T07:46:21Z", + "repoId": 143328315, + "pullRequestNo": 11554 + }, + { + "name": "morningman", + "id": 2899462, + "comment_id": 3293544413, + "created_at": "2025-09-15T19:07:52Z", + "repoId": 143328315, + "pullRequestNo": 11574 + }, + { + "name": "sadiqkhzn", + "id": 24961132, + "comment_id": 3312201690, + "created_at": "2025-09-19T13:26:49Z", + "repoId": 143328315, + "pullRequestNo": 11596 + }, + { + "name": "yumosx", + "id": 141902143, + "comment_id": 3322908961, + "created_at": "2025-09-23T08:21:07Z", + "repoId": 143328315, + "pullRequestNo": 11605 + }, + { + "name": "aimurphy", + "id": 36110273, + "comment_id": 3335211124, + "created_at": "2025-09-25T17:38:00Z", + "repoId": 143328315, + "pullRequestNo": 11618 + }, + { + "name": "richardokonicha", + "id": 48168290, + "comment_id": 3346750889, + "created_at": "2025-09-29T12:48:00Z", + "repoId": 143328315, + "pullRequestNo": 11552 + }, + { + "name": "vigbav36", + "id": 90998381, + "comment_id": 3361788337, + "created_at": "2025-10-02T15:24:35Z", + "repoId": 143328315, + "pullRequestNo": 11666 + }, + { + "name": "yashisthebatman", + "id": 149709821, + "comment_id": 3364470461, + "created_at": "2025-10-03T06:48:03Z", + "repoId": 143328315, + "pullRequestNo": 11676 + }, + { + "name": "survivant", + "id": 191879, + "comment_id": 3369115643, + "created_at": "2025-10-05T15:02:15Z", + "repoId": 143328315, + "pullRequestNo": 11684 + }, + { + "name": "Sai-Sravya-Thumati", + "id": 64857617, + "comment_id": 3370705793, + "created_at": "2025-10-06T09:31:16Z", + "repoId": 143328315, + "pullRequestNo": 11686 + }, + { + "name": "cclauss", + "id": 3709715, + "comment_id": 3364277206, + "created_at": "2025-10-03T05:08:38Z", + "repoId": 143328315, + "pullRequestNo": 11673 + }, + { + "name": "ParasNingune", + "id": 153178176, + "comment_id": 3388187853, + "created_at": "2025-10-10T03:48:32Z", + "repoId": 143328315, + "pullRequestNo": 11703 + }, + { + "name": "HarshitR2004", + "id": 159914116, + "comment_id": 3388359328, + "created_at": "2025-10-10T05:37:12Z", + "repoId": 143328315, + "pullRequestNo": 11704 + }, + { + "name": "Nirzak", + "id": 11460645, + "comment_id": 3393522813, + "created_at": "2025-10-11T17:20:41Z", + "repoId": 143328315, + "pullRequestNo": 11726 + }, + { + "name": "faizan842", + "id": 91795555, + "comment_id": 3407632893, + "created_at": "2025-10-15T17:55:57Z", + "repoId": 143328315, + "pullRequestNo": 11748 + }, + { + "name": "AhmadYasser1", + "id": 77586860, + "comment_id": 3419161297, + "created_at": "2025-10-19T02:48:49Z", + "repoId": 143328315, + "pullRequestNo": 11766 + }, + { + "name": "Nikhil172913832", + "id": 140622713, + "comment_id": 3443931056, + "created_at": "2025-10-24T16:13:14Z", + "repoId": 143328315, + "pullRequestNo": 11786 + }, + { + "name": "jiaqicheng1998", + "id": 65794980, + "comment_id": 3459506446, + "created_at": "2025-10-29T03:48:36Z", + "repoId": 143328315, + "pullRequestNo": 11793 + }, + { + "name": "Aashish079", + "id": 106550372, + "comment_id": 3461223031, + "created_at": "2025-10-29T12:19:16Z", + "repoId": 143328315, + "pullRequestNo": 11812 + }, + { + "name": "guddu-debasis", + "id": 167549811, + "comment_id": 3463419567, + "created_at": "2025-10-29T19:15:44Z", + "repoId": 143328315, + "pullRequestNo": 11821 + }, + { + "name": "jeis4wpi", + "id": 42679190, + "comment_id": 3467642515, + "created_at": "2025-10-30T11:55:54Z", + "repoId": 143328315, + "pullRequestNo": 11822 + }, + { + "name": "ak4shravikumar", + "id": 189372043, + "comment_id": 3469119609, + "created_at": "2025-10-30T17:15:30Z", + "repoId": 143328315, + "pullRequestNo": 11828 + }, + { + "name": "rajesh-adk-137", + "id": 89499267, + "comment_id": 3470873094, + "created_at": "2025-10-31T00:51:14Z", + "repoId": 143328315, + "pullRequestNo": 11835 + }, + { + "name": "KrishThakur23", + "id": 214495511, + "comment_id": 3475330781, + "created_at": "2025-11-01T01:05:56Z", + "repoId": 143328315, + "pullRequestNo": 11841 + }, + { + "name": "ritoban23", + "id": 124308320, + "comment_id": 3476917215, + "created_at": "2025-11-01T22:16:42Z", + "repoId": 143328315, + "pullRequestNo": 11843 + }, + { + "name": "bala-ceg", + "id": 70808619, + "comment_id": 3478836423, + "created_at": "2025-11-03T04:05:40Z", + "repoId": 143328315, + "pullRequestNo": 11844 + }, + { + "name": "HamoonDBA", + "id": 3939424, + "comment_id": 3499521731, + "created_at": "2025-11-06T21:49:51Z", + "repoId": 143328315, + "pullRequestNo": 11858 + }, + { + "name": "md-ziauddin", + "id": 29926473, + "comment_id": 3533762471, + "created_at": "2025-11-14T17:15:19Z", + "repoId": 143328315, + "pullRequestNo": 11888 + }, + { + "name": "suman-X", + "id": 137594910, + "comment_id": 3534136586, + "created_at": "2025-11-14T18:54:22Z", + "repoId": 143328315, + "pullRequestNo": 11890 + }, + { + "name": "suman-X", + "id": 137594910, + "comment_id": 3534230691, + "created_at": "2025-11-14T19:21:59Z", + "repoId": 143328315, + "pullRequestNo": 11890 + }, + { + "name": "SyedaAnshrahGillani", + "id": 90501474, + "comment_id": 3616952272, + "created_at": "2025-12-05T13:33:42Z", + "repoId": 143328315, + "pullRequestNo": 11973 + }, + { + "name": "neversettle17-101", + "id": 41864816, + "comment_id": 3620426556, + "created_at": "2025-12-06T13:56:57Z", + "repoId": 143328315, + "pullRequestNo": 11975 + }, + { + "name": "duskobogdanovski", + "id": 21080468, + "comment_id": 3656079267, + "created_at": "2025-12-15T14:55:07Z", + "repoId": 143328315, + "pullRequestNo": 12013 + }, + { + "name": "kelvinvelasquez-SDE", + "id": 112011775, + "comment_id": 3675658408, + "created_at": "2025-12-19T16:19:32Z", + "repoId": 143328315, + "pullRequestNo": 12029 + }, + { + "name": "PPeitsch", + "id": 88450637, + "comment_id": 3704693294, + "created_at": "2026-01-02T07:50:33Z", + "repoId": 143328315, + "pullRequestNo": 12048 + }, + { + "name": "SachinMyadam", + "id": 110909093, + "comment_id": 3716118688, + "created_at": "2026-01-06T20:02:10Z", + "repoId": 143328315, + "pullRequestNo": 12054 + }, + { + "name": "xuwei95", + "id": 18109811, + "comment_id": 3723114411, + "created_at": "2026-01-08T10:01:14Z", + "repoId": 143328315, + "pullRequestNo": 12063 + }, + { + "name": "Nandha-kumar-S", + "id": 85221220, + "comment_id": 3727602927, + "created_at": "2026-01-09T07:52:10Z", + "repoId": 143328315, + "pullRequestNo": 12082 + }, + { + "name": "Sweetdevil144", + "id": 117591942, + "comment_id": 3761427133, + "created_at": "2026-01-16T19:12:39Z", + "repoId": 143328315, + "pullRequestNo": 12110 + }, + { + "name": "Sriram-B-Srivatsa", + "id": 144884365, + "comment_id": 3765374596, + "created_at": "2026-01-18T14:51:54Z", + "repoId": 143328315, + "pullRequestNo": 12113 + }, + { + "name": "zhaojinxin409", + "id": 5874804, + "comment_id": 3771260955, + "created_at": "2026-01-20T06:34:45Z", + "repoId": 143328315, + "pullRequestNo": 12122 + }, + { + "name": "murataslan1", + "id": 78961478, + "comment_id": 3784602307, + "created_at": "2026-01-22T14:06:33Z", + "repoId": 143328315, + "pullRequestNo": 12004 + }, + { + "name": "C1ARKGABLE", + "id": 13039858, + "comment_id": 3792661007, + "created_at": "2026-01-23T21:53:19Z", + "repoId": 143328315, + "pullRequestNo": 11988 + }, + { + "name": "AndrewFarley", + "id": 470163, + "comment_id": 3801391357, + "created_at": "2026-01-26T19:40:00Z", + "repoId": 143328315, + "pullRequestNo": 12123 + }, + { + "name": "007slm", + "id": 1670036, + "comment_id": 3803635367, + "created_at": "2026-01-27T07:48:21Z", + "repoId": 143328315, + "pullRequestNo": 12155 + }, + { + "name": "C0staTin", + "id": 12409467, + "comment_id": 3812795861, + "created_at": "2026-01-28T17:36:00Z", + "repoId": 143328315, + "pullRequestNo": 12151 + }, + { + "name": "Amogh-2404", + "id": 114862749, + "comment_id": 3814926744, + "created_at": "2026-01-29T02:00:24Z", + "repoId": 143328315, + "pullRequestNo": 12167 + }, + { + "name": "themavik", + "id": 179817126, + "comment_id": 3936291923, + "created_at": "2026-02-20T17:50:39Z", + "repoId": 143328315, + "pullRequestNo": 12213 + }, + { + "name": "ianu82", + "id": 86010258, + "comment_id": 3973995110, + "created_at": "2026-02-27T16:55:27Z", + "repoId": 143328315, + "pullRequestNo": 12251 + }, + { + "name": "Mirza-Samad-Ahmed-Baig", + "id": 89132160, + "comment_id": 4054729064, + "created_at": "2026-03-13T12:24:17Z", + "repoId": 143328315, + "pullRequestNo": 12290 + }, + { + "name": "Krishnav1237", + "id": 147693159, + "comment_id": 4061239564, + "created_at": "2026-03-14T19:45:50Z", + "repoId": 143328315, + "pullRequestNo": 12294 + }, + { + "name": "StefanTrsunov", + "id": 91495981, + "comment_id": 4070493719, + "created_at": "2026-03-16T20:45:48Z", + "repoId": 143328315, + "pullRequestNo": 12297 + }, + { + "name": "Tzsapphire", + "id": 209363831, + "comment_id": 4106737895, + "created_at": "2026-03-22T18:27:23Z", + "repoId": 143328315, + "pullRequestNo": 12317 + }, + { + "name": "jnMetaCode", + "id": 12096460, + "comment_id": 4111619407, + "created_at": "2026-03-23T15:43:15Z", + "repoId": 143328315, + "pullRequestNo": 12279 } ] } \ No newline at end of file From ff4fb12223209f541573697c26732936a5fe7569 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Thu, 26 Mar 2026 12:45:06 +0300 Subject: [PATCH 089/125] Fix SQLQuery on_error step number handling (#12331) --- mindsdb/api/executor/sql_query/sql_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/api/executor/sql_query/sql_query.py b/mindsdb/api/executor/sql_query/sql_query.py index aae2902f713..03526f7ea22 100644 --- a/mindsdb/api/executor/sql_query/sql_query.py +++ b/mindsdb/api/executor/sql_query/sql_query.py @@ -314,7 +314,7 @@ def execute_query(self): except Exception as e: if self.run_query is not None: # set error and place where it stopped - self.run_query.on_error(e, step.step_num, self.steps_data) + self.run_query.on_error(e, step.step_num if 'step' in locals() else -1, self.steps_data) raise e else: # mark running query as completed From dc45ae559962eeb92c7cdda43f044878f56db2c2 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 2 Mar 2026 12:13:27 +0100 Subject: [PATCH 090/125] Add hubpost oauth connection helper --- .../hubspot_handler/hubspot_handler.py | 2 +- .../handlers/hubspot_handler/hubspot_oauth.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py index 72af5db2a40..8148db84ea0 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py @@ -154,7 +154,7 @@ def connect(self) -> HubSpot: if not access_token or not isinstance(access_token, str): raise ValueError("Invalid access_token provided") - logger.info("Connecting to HubSpot using access token") + logger.info("Connecting to HubSpot using PAT credentials") self.connection = HubSpot(access_token=access_token) elif "client_id" in self.connection_data and "client_secret" in self.connection_data: diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py new file mode 100644 index 00000000000..6daee96d042 --- /dev/null +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py @@ -0,0 +1,32 @@ + +from hubspot import HubSpot +from hubspot.auth.oauth import ApiException as OAuthApiException +from hubspot.auth.oauth import OAuthApi +from hubspot.auth.oauth import TokenResponse + + +def oauth__connect(client_id: str, client_secret: str) -> HubSpot: + """ + Connect to HubSpot using OAuth credentials. + + Args: + client_id (str): The client ID from your HubSpot app. + client_secret (str): The client secret from your HubSpot app. + + Returns: + HubSpot: An authenticated HubSpot client instance. + + Raises: + ValueError: If authentication fails or credentials are invalid. + """ + try: + oauth_api = OAuthApi() + token_response: TokenResponse = oauth_api.create_token( + grant_type="client_credentials", + client_id=client_id, + client_secret=client_secret, + ) + access_token = token_response.access_token + return HubSpot(access_token=access_token) + except OAuthApiException as e: + raise ValueError(f"OAuth authentication failed: {e}") \ No newline at end of file From a27e710da8ec588d20698b637397dc810179c50a Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 2 Mar 2026 15:10:20 +0100 Subject: [PATCH 091/125] Update the oauth flow --- .../hubspot_handler/connection_args.py | 43 ++++- .../hubspot_handler/hubspot_handler.py | 49 ++++-- .../handlers/hubspot_handler/hubspot_oauth.py | 153 ++++++++++++++---- 3 files changed, 193 insertions(+), 52 deletions(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/connection_args.py b/mindsdb/integrations/handlers/hubspot_handler/connection_args.py index 795f84e8863..4d3c8c026d5 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/connection_args.py +++ b/mindsdb/integrations/handlers/hubspot_handler/connection_args.py @@ -1,29 +1,62 @@ from collections import OrderedDict -from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE +from mindsdb.integrations.libs.const import ( + HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE, +) connection_args = OrderedDict( access_token={ "type": ARG_TYPE.STR, - "description": "The access token for the HubSpot API. Required for direct access token authentication.", + "description": ( + "The access token for the HubSpot API. " + "Required for direct access token authentication." + ), "required": False, "label": "Access Token", }, client_id={ "type": ARG_TYPE.STR, - "description": "The client ID (consumer key) from your HubSpot app for OAuth authentication.", + "description": ( + "The client ID (consumer key) from your HubSpot app " + "for OAuth authentication." + ), "required": False, "label": "Client ID", }, client_secret={ "type": ARG_TYPE.PWD, - "description": "The client secret (consumer secret) from your HubSpot app for OAuth authentication.", + "description": ( + "The client secret (consumer secret) from your HubSpot app " + "for OAuth authentication." + ), "secret": True, "required": False, "label": "Client Secret", }, + scopes={ + "type": ARG_TYPE.STR, + "description": ( + "Space-separated OAuth scopes requested " + "for token generation." + ), + "required": False, + "label": "Scopes", + }, + redirect_uri={ + "type": ARG_TYPE.STR, + "description": ( + "Optional OAuth callback URI. Defaults to " + "http://localhost:47334/verify-auth." + ), + "required": False, + "label": "Redirect URI", + }, ) connection_args_example = OrderedDict( - access_token="your_access_token", client_id="your_client_id", client_secret="your_client_secret" + access_token="your_access_token", + client_id="your_client_id", + client_secret="your_client_secret", + scopes="crm.objects.contacts.read crm.objects.companies.read", + redirect_uri="http://localhost:47334/verify-auth", ) diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py index 8148db84ea0..cf8ea6d2c60 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py @@ -34,6 +34,9 @@ from mindsdb.utilities import log from mindsdb_sql_parser import parse_sql +from mindsdb.integrations.handlers.hubspot_handler.hubspot_oauth import HubSpotOAuth2Manager +from mindsdb.integrations.utilities.handlers.auth_utilities.exceptions import AuthException + logger = log.getLogger(__name__) @@ -118,6 +121,7 @@ def __init__(self, name: str, **kwargs: Any) -> None: connection_data = kwargs.get("connection_data", {}) self.connection_data = connection_data self.kwargs = kwargs + self.handler_storage = kwargs.get("handler_storage") self.connection: Optional[HubSpot] = None self.is_connected: bool = False @@ -149,39 +153,48 @@ def connect(self) -> HubSpot: return self.connection try: - if "access_token" in self.connection_data: - access_token = self.connection_data["access_token"] - if not access_token or not isinstance(access_token, str): + access_token = self.connection_data.get("access_token") + client_id = self.connection_data.get("client_id") + client_secret = self.connection_data.get("client_secret") + + if access_token: + if not isinstance(access_token, str) or not access_token.strip(): raise ValueError("Invalid access_token provided") - logger.info("Connecting to HubSpot using PAT credentials") + logger.info("Connecting to HubSpot using access token") self.connection = HubSpot(access_token=access_token) - elif "client_id" in self.connection_data and "client_secret" in self.connection_data: - client_id = self.connection_data["client_id"] - client_secret = self.connection_data["client_secret"] - - if not client_id or not client_secret: - raise ValueError("Invalid OAuth credentials provided") - + elif client_id and client_secret: logger.info("Connecting to HubSpot using OAuth credentials") - self.connection = HubSpot(client_id=client_id, client_secret=client_secret) + oauth_manager = HubSpotOAuth2Manager( + handler_storage=self.handler_storage, + client_id=client_id, + client_secret=client_secret, + scopes=self.connection_data.get("scopes"), + redirect_uri=self.connection_data.get("redirect_uri"), + ) + self.connection = HubSpot(access_token=oauth_manager.get_access_token()) + else: raise ValueError( "Authentication credentials missing. Provide either 'access_token' " - "or both 'client_id' and 'client_secret' for OAuth authentication." + "or OAuth credentials: 'client_id' and 'client_secret'." ) self.is_connected = True logger.info("Successfully connected to HubSpot API") return self.connection - except ValueError: - logger.error("Failed to connect to HubSpot API") + except AuthException: + self.connection = None + self.is_connected = False + logger.info("HubSpot OAuth authorization required") raise except Exception as e: - logger.error("Failed to connect to HubSpot API") - raise ValueError(f"Connection to HubSpot failed: {str(e)}") + self.connection = None + self.is_connected = False + logger.error("Failed to connect to HubSpot API: %s", e) + raise ValueError(f"Connection to HubSpot failed: {e}") from e def disconnect(self) -> None: """Close connection and cleanup resources.""" @@ -215,6 +228,8 @@ def check_connection(self) -> StatusResponse: response.error_message = error_msg response.success = False + except AuthException: + raise except Exception as e: error_msg = _extract_hubspot_error_message(e) logger.error(f"HubSpot connection check failed: {error_msg}") diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py index 6daee96d042..1cd1b78704c 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py @@ -1,32 +1,125 @@ +import time +from typing import Optional +from flask import request from hubspot import HubSpot -from hubspot.auth.oauth import ApiException as OAuthApiException -from hubspot.auth.oauth import OAuthApi -from hubspot.auth.oauth import TokenResponse - - -def oauth__connect(client_id: str, client_secret: str) -> HubSpot: - """ - Connect to HubSpot using OAuth credentials. - - Args: - client_id (str): The client ID from your HubSpot app. - client_secret (str): The client secret from your HubSpot app. - - Returns: - HubSpot: An authenticated HubSpot client instance. - - Raises: - ValueError: If authentication fails or credentials are invalid. - """ - try: - oauth_api = OAuthApi() - token_response: TokenResponse = oauth_api.create_token( - grant_type="client_credentials", - client_id=client_id, - client_secret=client_secret, - ) - access_token = token_response.access_token - return HubSpot(access_token=access_token) - except OAuthApiException as e: - raise ValueError(f"OAuth authentication failed: {e}") \ No newline at end of file +from hubspot.utils.oauth import get_auth_url + +from mindsdb.utilities import log +from mindsdb.integrations.utilities.handlers.auth_utilities.exceptions import AuthException + +logger = log.getLogger(__name__) + +_STORAGE_KEY = "hubspot_oauth_tokens" +_DEFAULT_REDIRECT_PATH = "/verify-auth" +_TOKEN_EXPIRY_BUFFER = 0.95 + + +class HubSpotOAuth2Manager: + """ + Manages HubSpot OAuth2 authorization_code flow for MindsDB. + + On the first connect (no stored token, no code), raises AuthException + with the HubSpot authorization URL so MindsDB can redirect the user. + Once the user authorizes and the callback code is passed via connection_data, + the code is exchanged for tokens which are persisted in handler_storage. + Subsequent connects use the stored token, refreshing it automatically when expired. + """ + + def __init__( + self, + handler_storage, + client_id: str, + client_secret: str, + scopes: Optional[str] = None, + redirect_uri: Optional[str] = None, + code: Optional[str] = None, + ) -> None: + self.handler_storage = handler_storage + self.client_id = client_id + self.client_secret = client_secret + self.scopes = tuple(scopes.split()) if scopes else () + self.redirect_uri = redirect_uri + self.code = code + + def get_access_token(self) -> str: + """ + Return a valid HubSpot access token. + Raises: + AuthException: User authorization required; auth_url is attached. + """ + stored = self.handler_storage.encrypted_json_get(_STORAGE_KEY) + + if stored: + if time.time() < stored.get("expires_at", 0): + return stored["access_token"] + + if stored.get("refresh_token"): + try: + return self._refresh_token(stored["refresh_token"]) + except Exception as e: + logger.warning("HubSpot token refresh failed, reauthorization required: %s", e) + + runtime_code = self._get_runtime_code() + if runtime_code: + return self._exchange_code(runtime_code) + + auth_url = get_auth_url( + scope=self.scopes, + client_id=self.client_id, + redirect_uri=self._get_redirect_uri(), + ) + raise AuthException( + f"HubSpot authorization required. Please visit: {auth_url}", + auth_url=auth_url, + ) + + def _get_runtime_code(self) -> Optional[str]: + """Return the OAuth authorization code from explicit value or active request context.""" + if self.code: + return self.code + try: + return request.args.get("code") + except RuntimeError: + return None + + def _exchange_code(self, code: str) -> str: + """Exchange an authorization code for access and refresh tokens.""" + response = HubSpot().auth.oauth.tokens_api.create( + grant_type="authorization_code", + code=code, + redirect_uri=self._get_redirect_uri(), + client_id=self.client_id, + client_secret=self.client_secret, + ) + return self._persist_tokens(response) + + def _refresh_token(self, refresh_token: str) -> str: + """Obtain a new access token using the stored refresh token.""" + response = HubSpot().auth.oauth.tokens_api.create( + grant_type="refresh_token", + refresh_token=refresh_token, + redirect_uri=self._get_redirect_uri(), + client_id=self.client_id, + client_secret=self.client_secret, + ) + return self._persist_tokens(response) + + def _persist_tokens(self, token_response) -> str: + """Save token data to encrypted handler storage and return the access token.""" + tokens = { + "access_token": token_response.access_token, + "refresh_token": token_response.refresh_token, + "expires_at": time.time() + token_response.expires_in * _TOKEN_EXPIRY_BUFFER, + } + self.handler_storage.encrypted_json_set(_STORAGE_KEY, tokens) + return tokens["access_token"] + + def _get_redirect_uri(self) -> str: + if self.redirect_uri: + return self.redirect_uri + try: + origin = request.headers.get("ORIGIN", "http://localhost:47334") + except RuntimeError: + origin = "http://localhost:47334" + return origin + _DEFAULT_REDIRECT_PATH From e43b8195257540eadca29032f1335f7020f276a4 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 2 Mar 2026 16:06:18 +0100 Subject: [PATCH 092/125] Update the authentication code flow --- .../hubspot_handler/connection_args.py | 33 +++++++------------ .../hubspot_handler/hubspot_handler.py | 9 +++-- .../handlers/hubspot_handler/hubspot_oauth.py | 9 ++--- 3 files changed, 20 insertions(+), 31 deletions(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/connection_args.py b/mindsdb/integrations/handlers/hubspot_handler/connection_args.py index 4d3c8c026d5..e4a8b3c8be7 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/connection_args.py +++ b/mindsdb/integrations/handlers/hubspot_handler/connection_args.py @@ -7,50 +7,41 @@ connection_args = OrderedDict( access_token={ "type": ARG_TYPE.STR, - "description": ( - "The access token for the HubSpot API. " - "Required for direct access token authentication." - ), + "description": ("The access token for the HubSpot API. Required for direct access token authentication."), "required": False, "label": "Access Token", }, client_id={ "type": ARG_TYPE.STR, - "description": ( - "The client ID (consumer key) from your HubSpot app " - "for OAuth authentication." - ), + "description": ("The client ID (consumer key) from your HubSpot app for OAuth authentication."), "required": False, "label": "Client ID", }, client_secret={ "type": ARG_TYPE.PWD, - "description": ( - "The client secret (consumer secret) from your HubSpot app " - "for OAuth authentication." - ), + "description": ("The client secret (consumer secret) from your HubSpot app for OAuth authentication."), "secret": True, "required": False, "label": "Client Secret", }, - scopes={ + scope={ "type": ARG_TYPE.STR, - "description": ( - "Space-separated OAuth scopes requested " - "for token generation." - ), + "description": ("Space-separated OAuth scopes requested for token generation."), "required": False, "label": "Scopes", }, redirect_uri={ "type": ARG_TYPE.STR, - "description": ( - "Optional OAuth callback URI. Defaults to " - "http://localhost:47334/verify-auth." - ), + "description": ("Optional OAuth callback URI. Defaults to http://localhost:47334/verify-auth."), "required": False, "label": "Redirect URI", }, + code={ + "type": ARG_TYPE.STR, + "description": "OAuth authorization code returned by HubSpot after user consent. Only used within UI flow.", + "required": False, + "label": "Authorization Code", + }, ) connection_args_example = OrderedDict( diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py index cf8ea6d2c60..5d33cb7f2b8 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py @@ -170,8 +170,9 @@ def connect(self) -> HubSpot: handler_storage=self.handler_storage, client_id=client_id, client_secret=client_secret, - scopes=self.connection_data.get("scopes"), + scopes=self.connection_data.get("scope"), redirect_uri=self.connection_data.get("redirect_uri"), + code=self.connection_data.get("code"), ) self.connection = HubSpot(access_token=oauth_manager.get_access_token()) @@ -228,8 +229,10 @@ def check_connection(self) -> StatusResponse: response.error_message = error_msg response.success = False - except AuthException: - raise + except AuthException as error: + response.error_message = str(error) + response.redirect_url = error.auth_url + return response except Exception as e: error_msg = _extract_hubspot_error_message(e) logger.error(f"HubSpot connection check failed: {error_msg}") diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py index 1cd1b78704c..c4c1ccd44ce 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py @@ -13,17 +13,12 @@ _STORAGE_KEY = "hubspot_oauth_tokens" _DEFAULT_REDIRECT_PATH = "/verify-auth" _TOKEN_EXPIRY_BUFFER = 0.95 +_DEFAULT_SCOPES = ("oauth",) class HubSpotOAuth2Manager: """ Manages HubSpot OAuth2 authorization_code flow for MindsDB. - - On the first connect (no stored token, no code), raises AuthException - with the HubSpot authorization URL so MindsDB can redirect the user. - Once the user authorizes and the callback code is passed via connection_data, - the code is exchanged for tokens which are persisted in handler_storage. - Subsequent connects use the stored token, refreshing it automatically when expired. """ def __init__( @@ -38,7 +33,7 @@ def __init__( self.handler_storage = handler_storage self.client_id = client_id self.client_secret = client_secret - self.scopes = tuple(scopes.split()) if scopes else () + self.scopes = tuple(scopes.split()) if scopes else _DEFAULT_SCOPES self.redirect_uri = redirect_uri self.code = code From 06139dec8e35fc401cc636077e02c9d36026d569 Mon Sep 17 00:00:00 2001 From: Konstantin Sivakov Date: Mon, 16 Mar 2026 11:39:13 +0100 Subject: [PATCH 093/125] Improve the oauth handling --- .../handlers/hubspot_handler/connection_args.py | 10 ++++++++-- .../handlers/hubspot_handler/hubspot_handler.py | 3 +++ .../handlers/hubspot_handler/hubspot_oauth.py | 12 +++++++----- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/connection_args.py b/mindsdb/integrations/handlers/hubspot_handler/connection_args.py index e4a8b3c8be7..9154946b884 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/connection_args.py +++ b/mindsdb/integrations/handlers/hubspot_handler/connection_args.py @@ -26,9 +26,15 @@ }, scope={ "type": ARG_TYPE.STR, - "description": ("Space-separated OAuth scopes requested for token generation."), + "description": "Space-separated required OAuth scopes (scope URL param). Defaults to 'oauth'.", "required": False, - "label": "Scopes", + "label": "Required Scopes", + }, + optional_scope={ + "type": ARG_TYPE.STR, + "description": "Space-separated optional OAuth scopes.", + "required": False, + "label": "Optional Scopes", }, redirect_uri={ "type": ARG_TYPE.STR, diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py index 5d33cb7f2b8..1ca17e1691d 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py @@ -171,9 +171,12 @@ def connect(self) -> HubSpot: client_id=client_id, client_secret=client_secret, scopes=self.connection_data.get("scope"), + optional_scopes=self.connection_data.get("optional_scope"), redirect_uri=self.connection_data.get("redirect_uri"), code=self.connection_data.get("code"), ) + logger.info("Attempting to obtain access token via OAuth flow") + logger.debug(oauth_manager) self.connection = HubSpot(access_token=oauth_manager.get_access_token()) else: diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py index c4c1ccd44ce..857edb4a047 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py @@ -13,7 +13,6 @@ _STORAGE_KEY = "hubspot_oauth_tokens" _DEFAULT_REDIRECT_PATH = "/verify-auth" _TOKEN_EXPIRY_BUFFER = 0.95 -_DEFAULT_SCOPES = ("oauth",) class HubSpotOAuth2Manager: @@ -27,13 +26,15 @@ def __init__( client_id: str, client_secret: str, scopes: Optional[str] = None, + optional_scopes: Optional[str] = None, redirect_uri: Optional[str] = None, code: Optional[str] = None, ) -> None: self.handler_storage = handler_storage self.client_id = client_id self.client_secret = client_secret - self.scopes = tuple(scopes.split()) if scopes else _DEFAULT_SCOPES + self.scopes = tuple(scopes.split()) if scopes else ("oauth",) + self.optional_scopes = tuple(optional_scopes.split()) if optional_scopes else None self.redirect_uri = redirect_uri self.code = code @@ -44,7 +45,7 @@ def get_access_token(self) -> str: AuthException: User authorization required; auth_url is attached. """ stored = self.handler_storage.encrypted_json_get(_STORAGE_KEY) - + logger.debug(f"Retrieved stored token data: {stored}") if stored: if time.time() < stored.get("expires_at", 0): return stored["access_token"] @@ -61,6 +62,7 @@ def get_access_token(self) -> str: auth_url = get_auth_url( scope=self.scopes, + optional_scope=self.optional_scopes, client_id=self.client_id, redirect_uri=self._get_redirect_uri(), ) @@ -80,7 +82,7 @@ def _get_runtime_code(self) -> Optional[str]: def _exchange_code(self, code: str) -> str: """Exchange an authorization code for access and refresh tokens.""" - response = HubSpot().auth.oauth.tokens_api.create( + response = HubSpot().oauth.tokens_api.create( grant_type="authorization_code", code=code, redirect_uri=self._get_redirect_uri(), @@ -91,7 +93,7 @@ def _exchange_code(self, code: str) -> str: def _refresh_token(self, refresh_token: str) -> str: """Obtain a new access token using the stored refresh token.""" - response = HubSpot().auth.oauth.tokens_api.create( + response = HubSpot().oauth.tokens_api.create( grant_type="refresh_token", refresh_token=refresh_token, redirect_uri=self._get_redirect_uri(), From 1740cfee7bfe5dacb5d32c1bbf9fededbfbda24f Mon Sep 17 00:00:00 2001 From: S3TO Date: Wed, 18 Mar 2026 08:54:24 -0500 Subject: [PATCH 094/125] Fix: `check_connection` and enhance the authorization URL (#12298) --- .../hubspot_handler/hubspot_handler.py | 14 ++++++++ .../handlers/hubspot_handler/hubspot_oauth.py | 35 +++++++++++++++++-- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py index 1ca17e1691d..b979a57c797 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_handler.py @@ -174,6 +174,7 @@ def connect(self) -> HubSpot: optional_scopes=self.connection_data.get("optional_scope"), redirect_uri=self.connection_data.get("redirect_uri"), code=self.connection_data.get("code"), + datasource_name=self.name, ) logger.info("Attempting to obtain access token via OAuth flow") logger.debug(oauth_manager) @@ -210,6 +211,19 @@ def check_connection(self) -> StatusResponse: """Checks whether the API client is connected to Hubspot.""" response = StatusResponse(False) + # Defer OAuth code-for-token exchange: CREATE DATABASE runs check_connection + # with ephemeral handler_storage, so tokens written here would be discarded; + # later requests then fail with BAD_AUTH_CODE. Exchange only when a request + if self.connection_data.get("code") and not self.is_connected: + from mindsdb.integrations.handlers.hubspot_handler.hubspot_oauth import _STORAGE_KEY + + if not self.handler_storage.encrypted_json_get(_STORAGE_KEY): + logger.info( + "Deferring HubSpot check_connection because OAuth code exchange must happen in a persistent context." + ) + response.success = True + return response + try: self.connect() diff --git a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py index 857edb4a047..9608055edad 100644 --- a/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py +++ b/mindsdb/integrations/handlers/hubspot_handler/hubspot_oauth.py @@ -1,4 +1,5 @@ import time +import urllib.parse from typing import Optional from flask import request @@ -29,6 +30,7 @@ def __init__( optional_scopes: Optional[str] = None, redirect_uri: Optional[str] = None, code: Optional[str] = None, + datasource_name: Optional[str] = None, ) -> None: self.handler_storage = handler_storage self.client_id = client_id @@ -37,6 +39,7 @@ def __init__( self.optional_scopes = tuple(optional_scopes.split()) if optional_scopes else None self.redirect_uri = redirect_uri self.code = code + self.datasource_name = datasource_name def get_access_token(self) -> str: """ @@ -58,14 +61,40 @@ def get_access_token(self) -> str: runtime_code = self._get_runtime_code() if runtime_code: - return self._exchange_code(runtime_code) - + try: + return self._exchange_code(runtime_code) + except Exception as e: + # OAuth codes are single-use and expire quickly. + # If the exchange fails (BAD_AUTH_CODE), don't retry — prompt re-authorization. + logger.warning("HubSpot code exchange failed (code may be expired/used): %s", e) + + redirect_uri = self._get_redirect_uri() auth_url = get_auth_url( scope=self.scopes, optional_scope=self.optional_scopes, client_id=self.client_id, - redirect_uri=self._get_redirect_uri(), + redirect_uri=redirect_uri, ) + # Fix for HubSpot's strict URL parsing. Python's URL encode translates spaces to `+`, but + # HubSpot's optional_scopes requires `%20` or `,`. + auth_url = auth_url.replace("+", "%20") + + # Append state with datasource info so the frontend can complete the connection + # even when localStorage context is missing (e.g. script-initiated flows). + if self.datasource_name: + state_data = urllib.parse.urlencode( + { + "datasource_name": self.datasource_name, + "integrations_name": "hubspot", + "client_id": self.client_id, + "client_secret": self.client_secret, + "redirect_uri": redirect_uri, + "scope": " ".join(self.scopes) if self.scopes else "oauth", + "optional_scope": " ".join(self.optional_scopes) if self.optional_scopes else "", + } + ) + auth_url += f"&state={urllib.parse.quote(state_data)}" + raise AuthException( f"HubSpot authorization required. Please visit: {auth_url}", auth_url=auth_url, From 206ece58f8c065bf81625b482a99d055d43dd2d6 Mon Sep 17 00:00:00 2001 From: S3TO Date: Thu, 26 Mar 2026 06:48:35 -0500 Subject: [PATCH 095/125] FQE-2165 - Add an endpoint for extracting default vector stores (#12283) --- mindsdb/api/http/namespaces/config.py | 19 +++- .../interfaces/knowledge_base/controller.py | 64 +++++++------ .../default_storage_resolver.py | 90 +++++++++++++++++++ mindsdb/utilities/config.py | 4 + tests/unit/api/http/config_test.py | 9 ++ .../test_default_storage_resolution.py | 79 ++++++++++++++++ tests/unit/utilities/test_config.py | 20 +++++ 7 files changed, 254 insertions(+), 31 deletions(-) create mode 100644 mindsdb/interfaces/knowledge_base/default_storage_resolver.py create mode 100644 tests/unit/api/http/config_test.py create mode 100644 tests/unit/interfaces/knowledge_base/test_default_storage_resolution.py diff --git a/mindsdb/api/http/namespaces/config.py b/mindsdb/api/http/namespaces/config.py index b31e8d9b293..da4412b7891 100644 --- a/mindsdb/api/http/namespaces/config.py +++ b/mindsdb/api/http/namespaces/config.py @@ -16,6 +16,10 @@ from mindsdb.utilities.functions import decrypt, encrypt from mindsdb.utilities.config import Config from mindsdb.integrations.libs.response import HandlerStatusResponse +from mindsdb.interfaces.knowledge_base.default_storage_resolver import ( + get_env_available_engines, + resolve_default_storage_engines, +) logger = log.getLogger(__name__) @@ -34,6 +38,11 @@ def get(self): if value is not None: resp[key] = value + knowledge_bases_config = copy.deepcopy(config["knowledge_bases"]) + knowledge_bases_config.update(resolve_default_storage_engines(config)) + knowledge_bases_config["engines"] = get_env_available_engines() + resp["knowledge_bases"] = knowledge_bases_config + api_status = get_api_status() api_configs = copy.deepcopy(config["api"]) for api_name, api_config in api_configs.items(): @@ -47,12 +56,18 @@ def get(self): def put(self): data = request.json - allowed_arguments = {"auth", "default_llm", "default_embedding_model", "default_reranking_model"} + allowed_arguments = { + "auth", + "default_llm", + "default_embedding_model", + "default_reranking_model", + "knowledge_bases", + } unknown_arguments = list(set(data.keys()) - allowed_arguments) if len(unknown_arguments) > 0: return http_error(HTTPStatus.BAD_REQUEST, "Wrong arguments", f"Unknown argumens: {unknown_arguments}") - nested_keys_to_validate = {"auth"} + nested_keys_to_validate = {"auth", "knowledge_bases"} for key in data.keys(): if key in nested_keys_to_validate: unknown_arguments = list(set(data[key].keys()) - set(Config()[key].keys())) diff --git a/mindsdb/interfaces/knowledge_base/controller.py b/mindsdb/interfaces/knowledge_base/controller.py index aa32a46419e..ad7cdb7fbad 100644 --- a/mindsdb/interfaces/knowledge_base/controller.py +++ b/mindsdb/interfaces/knowledge_base/controller.py @@ -1,4 +1,3 @@ -import os import copy from typing import Dict, List, Optional, Any, Text, Tuple, Union import json @@ -32,6 +31,7 @@ from mindsdb.interfaces.knowledge_base.preprocessing.document_preprocessor import PreprocessorFactory from mindsdb.interfaces.knowledge_base.evaluate import EvaluateBase from mindsdb.interfaces.knowledge_base.executor import KnowledgeBaseQueryExecutor +from mindsdb.interfaces.knowledge_base.default_storage_resolver import resolve_default_storage_engines from mindsdb.interfaces.model.functions import PredictorRecordNotFound from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError from mindsdb.integrations.utilities.sql_utils import FilterCondition, FilterOperator, KeywordSearchArgs @@ -1236,34 +1236,12 @@ def add( # search for the vector database table if storage is None: - cloud_pg_vector = os.environ.get("KB_PGVECTOR_URL") - if cloud_pg_vector: - vector_table_name = name - # Add sparse vector support for pgvector - vector_db_params = {} - # Check both explicit parameter and model configuration - if is_sparse: - vector_db_params["is_sparse"] = True - if vector_size is not None: - vector_db_params["vector_size"] = vector_size - vector_db_name = self._create_persistent_pgvector(vector_db_params) - params["default_vector_storage"] = vector_db_name - else: - # try faiss - module = self.session.integration_controller.get_handler_module("duckdb_faiss") - if module is None or module.Handler is None: - raise ValueError( - "Vector table is not defined. Set it by `storage=vector_db.vector_table`. " - "One of the options is to use pgvector: " - "https://docs.mindsdb.com/integrations/vector-db-integrations/pgvector" - ) - - # create faiss db with same name - vector_table_name = "data" - vector_db_name = self._create_persistent_faiss(name) - # memorize to remove it later - params["default_vector_storage"] = vector_db_name - + vector_db_name, vector_table_name = self._resolve_default_vector_storage( + kb_name=name, + is_sparse=is_sparse, + vector_size=vector_size, + ) + params["default_vector_storage"] = vector_db_name elif len(storage.parts) != 2: raise ValueError("Storage param has to be vector db with table") else: @@ -1466,6 +1444,34 @@ def _create_persistent_chroma(self, kb_name, engine="chromadb"): self.session.integration_controller.add(vector_store_name, engine, connection_args) return vector_store_name + def _resolve_default_vector_storage(self, kb_name: str, is_sparse: bool = False, vector_size: int = None): + resolved_storage = resolve_default_storage_engines(config) + default_engine = resolved_storage["default_storage"] + + if default_engine is None: + raise ValueError( + "Vector table is not defined. Set it by `storage=vector_db.vector_table` or configure " + "`knowledge_bases.storage` as one of: pgvector, faiss." + ) + + if default_engine == "pgvector": + vector_db_params = {} + if is_sparse: + vector_db_params["is_sparse"] = True + if vector_size is not None: + vector_db_params["vector_size"] = vector_size + vector_db_name = self._create_persistent_pgvector(vector_db_params) + return vector_db_name, kb_name + + if default_engine in ("duckdb_faiss", "faiss"): + vector_db_name = self._create_persistent_faiss(kb_name) + return vector_db_name, kb_name + + raise ValueError( + f"Automatic default storage creation is not supported for engine '{default_engine}'. " + "Set `storage=vector_db.vector_table` explicitly." + ) + def _check_embedding_model(self, project_name, params: dict = None, kb_name="") -> dict: """check embedding model for knowledge base, return embedding model info""" diff --git a/mindsdb/interfaces/knowledge_base/default_storage_resolver.py b/mindsdb/interfaces/knowledge_base/default_storage_resolver.py new file mode 100644 index 00000000000..93a4c364054 --- /dev/null +++ b/mindsdb/interfaces/knowledge_base/default_storage_resolver.py @@ -0,0 +1,90 @@ +import os +from typing import Any + +from mindsdb.utilities.config import config + + +def _normalize_engine_name(engine: str | None) -> str | None: + if engine is None: + return None + normalized = engine.strip().lower() + if normalized in ("duckdb_faiss", "faiss"): + return "faiss" + if normalized == "pgvector": + return "pgvector" + return normalized or None + + +def _get_env_available_engines() -> list[str]: + engines: list[str] = ["faiss"] + if os.environ.get("KB_PGVECTOR_URL"): + engines.append("pgvector") + return engines + + +def get_env_available_engines() -> list[str]: + return _get_env_available_engines() + + +def get_knowledge_base_storage_config(config_obj=None) -> str | None: + config_obj = config_obj or config + storage = config_obj.get("knowledge_bases", {}).get("storage", None) + + if storage is None: + return None + + if isinstance(storage, list): + if len(storage) == 0: + return None + storage = storage[0] + + if not isinstance(storage, str): + raise ValueError("knowledge_bases.storage must be a string value") + + return _normalize_engine_name(storage) + + +def _unique_default_first(default: str | None, ordered: list[str]) -> list[str]: + """Return `ordered` with `default` first if set, dropping later duplicates.""" + out: list[str] = [] + seen: set[str] = set() + for engine in ([default] if default else []) + ordered: + if engine not in seen: + seen.add(engine) + out.append(engine) + return out + + +def resolve_default_storage_engines(config_obj=None) -> dict[str, Any]: + configured = get_knowledge_base_storage_config(config_obj) + pgvector_enabled = os.environ.get("KB_PGVECTOR_URL") is not None + available = _get_env_available_engines() + + if configured and configured not in available: + available = [configured, *available] + + default = configured + if default is None: + default = "pgvector" if pgvector_enabled else None + if default is None and available: + default = available[0] + + candidates = _unique_default_first(default, available) + available_set = set(available) + resolved_storage = [ + { + "engine": name, + "available": name in available_set, + "default": name == default, + "source": "config" if configured == name else "fallback", + } + for name in candidates + ] + + return { + "storage": configured, + "resolved_storage": resolved_storage, + "default_storage": default, + "available_vector_engines": available, + "pgvector_enabled": pgvector_enabled, + } diff --git a/mindsdb/utilities/config.py b/mindsdb/utilities/config.py index b7187dff930..b534a4c5a98 100644 --- a/mindsdb/utilities/config.py +++ b/mindsdb/utilities/config.py @@ -290,6 +290,7 @@ def __new__(cls, *args, **kwargs) -> "Config": "knowledge_bases": { "disable_autobatch": False, "disable_pgvector_autobatch": True, + "storage": None, }, } # endregion @@ -323,6 +324,7 @@ def prepare_env_config(self) -> None: "ml_task_queue": {}, "gui": {}, "byom": {}, + "knowledge_bases": {}, } # region storage root path @@ -537,6 +539,8 @@ def prepare_env_config(self) -> None: self._env_config["api"]["mcp"]["dns_rebinding_protection"] = mindsdb_mcp_dns_rebinding_protection # endregion + # Keep env-based KB defaults out of config.auto.json overrides. + def fetch_auto_config(self) -> bool: """Load dict readed from config.auto.json to `auto_config`. Do it only if `auto_config` was not loaded before or config.auto.json been changed. diff --git a/tests/unit/api/http/config_test.py b/tests/unit/api/http/config_test.py new file mode 100644 index 00000000000..672d7d31cd0 --- /dev/null +++ b/tests/unit/api/http/config_test.py @@ -0,0 +1,9 @@ +def test_get_config_returns_knowledge_bases_storage(client): + response = client.get("/api/config/") + + assert response.status_code == 200 + payload = response.get_json() + assert "knowledge_bases" in payload + assert "storage" in payload["knowledge_bases"] + assert "available_vector_engines" in payload["knowledge_bases"] + assert "pgvector_enabled" in payload["knowledge_bases"] diff --git a/tests/unit/interfaces/knowledge_base/test_default_storage_resolution.py b/tests/unit/interfaces/knowledge_base/test_default_storage_resolution.py new file mode 100644 index 00000000000..6543ef28f4a --- /dev/null +++ b/tests/unit/interfaces/knowledge_base/test_default_storage_resolution.py @@ -0,0 +1,79 @@ +import os +from types import SimpleNamespace +from unittest.mock import MagicMock +from unittest.mock import patch + +from mindsdb.interfaces.knowledge_base.controller import KnowledgeBaseController +from mindsdb.interfaces.knowledge_base.default_storage_resolver import resolve_default_storage_engines +from mindsdb.utilities.config import config + + +def _make_controller(handler_meta_by_name): + integration_controller = MagicMock() + integration_controller.get_handler_meta.side_effect = lambda name: handler_meta_by_name.get(name) + integration_controller.get.return_value = None + + session = SimpleNamespace(integration_controller=integration_controller) + return KnowledgeBaseController(session), integration_controller + + +def test_resolve_default_vector_storage_uses_pgvector_from_config(): + previous_storage = config["knowledge_bases"].get("storage", None) + controller, _ = _make_controller({"pgvector": {"import": {"success": True}}}) + + try: + config.update({"knowledge_bases": {"storage": "pgvector"}}) + vector_db_name = "kb_pgvector_store" + controller._create_persistent_pgvector = MagicMock(return_value=vector_db_name) + + vector_db, vector_table = controller._resolve_default_vector_storage("kb_docs") + + assert vector_db == vector_db_name + assert vector_table == "kb_docs" + controller._create_persistent_pgvector.assert_called_once_with({}) + finally: + config.update({"knowledge_bases": {"storage": previous_storage}}) + + +def test_resolve_default_vector_storage_uses_faiss_from_config(): + previous_storage = config["knowledge_bases"].get("storage", None) + controller, _ = _make_controller({"duckdb_faiss": {"import": {"success": True}}}) + + try: + config.update({"knowledge_bases": {"storage": "faiss"}}) + + vector_db_name = "store_kb_docs" + controller._create_persistent_faiss = MagicMock(return_value=vector_db_name) + + vector_db, vector_table = controller._resolve_default_vector_storage("kb_docs") + + assert vector_db == vector_db_name + assert vector_table == "kb_docs" + controller._create_persistent_faiss.assert_called_once_with("kb_docs") + finally: + config.update({"knowledge_bases": {"storage": previous_storage}}) + + +def test_create_persistent_pgvector_reuses_existing_store(): + controller, integration_controller = _make_controller({}) + integration_controller.get.return_value = {"name": "kb_pgvector_store"} + + vector_store_name = controller._create_persistent_pgvector({"is_sparse": True, "vector_size": 30522}) + + assert vector_store_name == "kb_pgvector_store" + integration_controller.add.assert_not_called() + + +def test_resolver_uses_pgvector_url_fallback_when_storage_is_empty(): + previous_storage = config["knowledge_bases"].get("storage", None) + controller, _ = _make_controller({}) + + try: + config.update({"knowledge_bases": {"storage": None}}) + with patch.dict(os.environ, {"KB_PGVECTOR_URL": "postgresql://user:pass@host/db"}, clear=False): + resolved = resolve_default_storage_engines(config) + assert resolved["default_storage"] == "pgvector" + assert resolved["available_vector_engines"] == ["faiss", "pgvector"] + assert resolved["pgvector_enabled"] is True + finally: + config.update({"knowledge_bases": {"storage": previous_storage}}) diff --git a/tests/unit/utilities/test_config.py b/tests/unit/utilities/test_config.py index 88113161409..d5bd93d46b7 100644 --- a/tests/unit/utilities/test_config.py +++ b/tests/unit/utilities/test_config.py @@ -39,3 +39,23 @@ def test_invalid_mindsdb_db_con_raises_error(self): error_message = str(exc_info.value) assert "Invalid MINDSDB_DB_CON value" in error_message assert invalid_db_con in error_message + + def test_knowledge_bases_storage_env_does_not_override_storage_config(self): + Config._Config__instance = None + + with tempfile.TemporaryDirectory() as tmpdir: + config_file = Path(tmpdir) / "config.json" + config_file.write_text(json.dumps({})) + + with patch.dict( + os.environ, + { + "MINDSDB_CONFIG_PATH": str(config_file), + "MINDSDB_STORAGE_DIR": tmpdir, + "KNOWLEDGE_BASES_STORAGE": "faiss, pgvector", + }, + clear=False, + ): + cfg = Config() + + assert cfg["knowledge_bases"]["storage"] is None From f059876a46af89ea33cf0176d3e21c62d39a807d Mon Sep 17 00:00:00 2001 From: Kan Lu <54055907+buallen@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:56:56 +0800 Subject: [PATCH 096/125] feat: Show all schemas by default in database tree view (#11234) --- mindsdb/api/http/namespaces/tree.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindsdb/api/http/namespaces/tree.py b/mindsdb/api/http/namespaces/tree.py index e9e1ee25fa7..87e03225dfd 100644 --- a/mindsdb/api/http/namespaces/tree.py +++ b/mindsdb/api/http/namespaces/tree.py @@ -39,7 +39,8 @@ def get(self, db_name): if isinstance(with_schemas, str): with_schemas = with_schemas.lower() in ("1", "true") else: - with_schemas = False + # Show all schemas by default for better UX + with_schemas = True db_name = db_name.lower() databases = ca.database_controller.get_dict() if db_name not in databases: From ac5e474a329fbe49c9d2e17eb4d33a2b72e68829 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 26 Mar 2026 15:11:24 +0300 Subject: [PATCH 097/125] Bump flask from 3.0.3 to 3.1.3 in /requirements (#12238) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Max Stepanov --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 737d9e8b18c..1d17d9c4f64 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,5 @@ packaging -flask == 3.0.3 +flask == 3.1.3 werkzeug == 3.1.6 flask-restx >= 1.3.0, < 2.0.0 pandas == 2.2.3 From e9b04cf999ac7cb3947869a29e9e13b25ece60c9 Mon Sep 17 00:00:00 2001 From: Raahim Lone Date: Thu, 26 Mar 2026 09:24:00 -0400 Subject: [PATCH 098/125] Docs - BYOM fix predict example (#12033) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michael Olayemi Olawepo <154475559+sejubar@users.noreply.github.com> Co-authored-by: andrew Co-authored-by: April I. Murphy <36110273+aimurphy@users.noreply.github.com> Co-authored-by: Minura Punchihewa <49385643+MinuraPunchihewa@users.noreply.github.com> Co-authored-by: Konstantin Sivakov Co-authored-by: martyna-mindsdb <109554435+martyna-mindsdb@users.noreply.github.com> Co-authored-by: Sebastián Tobón Hernández --- docs/integrations/ai-engines/byom.mdx | 63 +++++++++++++++------------ 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/docs/integrations/ai-engines/byom.mdx b/docs/integrations/ai-engines/byom.mdx index 736af426317..ffdc564ff57 100644 --- a/docs/integrations/ai-engines/byom.mdx +++ b/docs/integrations/ai-engines/byom.mdx @@ -25,7 +25,7 @@ Let's briefly go over the files that need to be uploaded: ```py class CustomPredictor(): - ​ + def train(self, df, target_col, args=None): return '' @@ -39,38 +39,41 @@ Let's briefly go over the files that need to be uploaded: ```py import os import pandas as pd - ​ + from sklearn.cross_decomposition import PLSRegression from sklearn import preprocessing - ​ + class CustomPredictor(): - ​ + def train(self, df, target_col, args=None): print(args, '1111') - ​ + self.target_col = target_col y = df[self.target_col] x = df.drop(columns=self.target_col) x_cols = list(x.columns) - ​ + x_scaler = preprocessing.StandardScaler().fit(x) y_scaler = preprocessing.StandardScaler().fit(y.values.reshape(-1, 1)) - ​ + xs = x_scaler.transform(x) ys = y_scaler.transform(y.values.reshape(-1, 1)) - ​ + pls = PLSRegression(n_components=1) pls.fit(xs, ys) - ​ + + self.pls = pls + self.y_scaler = y_scaler + T = pls.x_scores_ W = pls.x_weights_ P = pls.x_loadings_ R = pls.x_rotations_ - ​ + self.x_cols = x_cols self.x_scaler = x_scaler self.P = P - ​ + def calc_limit(df): res = None for column in df.columns: @@ -89,32 +92,32 @@ Let's briefly go over the files that need to be uploaded: except: res = tbl return res - ​ + trdf = pd.DataFrame() trdf[self.target_col] = y.values trdf['T1'] = T.squeeze() limit = calc_limit(trdf).reset_index() - ​ + self.limit = limit - ​ + return "Trained predictor ready to be stored" - ​ + def predict(self, df): - ​ - yt = df[self.target_col].values + + xt = df[self.x_cols] - ​ + xt = self.x_scaler.transform(xt) - ​ + excess_cols = list(set(df.columns) - set(self.x_cols)) - ​ + pred_df = df[excess_cols].copy() - ​ - pred_df[self.target_col] = yt + + ys_pred = self.pls.predict(xt) + y_pred = self.y_scaler.inverse_transform(ys_pred).ravel() + pred_df[self.target_col] = y_pred + pred_df['T1'] = (xt @ self.P).squeeze() - ​ - pred_df = pd.merge(pred_df, self.limit[[self.target_col, 'lower', 'upper']], how='left', on=self.target_col) - ​ return pred_df ``` @@ -195,12 +198,14 @@ USING ENGINE = 'custom_model_engine'; ``` -Let's query for predictions by joining the custom model with the data table. +Let's query for predictions by joining the custom model with the data table. Please note that when querying for predictions, do not include the target column in the `input` data selection. ```sql -SELECT input.feature_column, model_target_column -FROM my_integration.my_table as input -JOIN custom_model as model; +SELECT + input.feature_column, + model.target AS predicted_target +FROM my_integration.my_table AS input +JOIN custom_model AS model; ``` From 8d56367310cf764738ea5b72c48c191e2e6bb275 Mon Sep 17 00:00:00 2001 From: Andrey Date: Thu, 26 Mar 2026 17:03:11 +0300 Subject: [PATCH 099/125] KB unit tests (#12321) --- .../duckdb_faiss_handler.py | 2 +- .../duckdb_faiss_table.py | 4 +- .../duckdb_faiss_handler/faiss_index.py | 18 +- tests/unit/executor/test_knowledge_base.py | 164 +++++++++++++++++- 4 files changed, 173 insertions(+), 15 deletions(-) diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py index 1cabd09ae79..22153163ae8 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_handler.py @@ -246,7 +246,7 @@ def get_tables(self) -> Response: # -- table methods -- - def create_index(self, table_name: str, type: str = "ivf_file", nlist: int = None, train_count: int = None): + def create_index(self, table_name: str, type: str = None, nlist: int = None, train_count: int = None): with self.open_table(table_name) as table: table.create_index(type=type, nlist=nlist, train_count=train_count) diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py index b8a5324ad63..526fd2b3ff7 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py @@ -80,9 +80,7 @@ def _sync(self, dump_faiss=True): if self.handler._use_handler_storage: self.handler.handler_storage.folder_sync(self.table_name) - def create_index(self, type: str = "ivf_file", nlist: int = None, train_count: int = None): - if type not in ("ivf", "ivf_file"): - raise NotImplementedError("Only ivf or ivf_file indexes are supported") + def create_index(self, type: str = None, nlist: int = None, train_count: int = None): self.faiss_index.create_index(type, nlist=nlist, train_count=train_count) # index was already saved. don't dump it twice self._sync(dump_faiss=False) diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py index 3673a982f6f..66ab6a64a22 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py @@ -503,11 +503,11 @@ def get_size(self): return self.index.ntotal def check_required_disk_space(self, index_type): - available = psutil.disk_usage(self.path).free + base_path = Path(self.path).parent + available = psutil.disk_usage(str(base_path)).free # current size of index index_size = 0 - base_path = Path(self.path).parent for item in base_path.iterdir(): if item.is_dir() or not item.name.startswith("faiss_index"): continue @@ -526,7 +526,7 @@ def check_required_disk_space(self, index_type): to_free_gb = round((index_size * (k - 1)) / 1024**3, 2) raise ValueError(f"Unable run indexing FAISS not enough disk space, get free at least : {to_free_gb} Gb") - def create_index(self, index_type, nlist=None, train_count=None): + def create_index(self, index_type=None, nlist=None, train_count=None): """ Create or recreate IVF index @@ -536,6 +536,18 @@ def create_index(self, index_type, nlist=None, train_count=None): """ + if index_type is None: + if os.name == "nt": + index_type = "ivf" + else: + index_type = "ivf_file" + + elif index_type not in ("ivf", "ivf_file"): + raise NotImplementedError("Only ivf or ivf_file indexes are supported") + + if index_type == "ivf_file" and os.name == "nt": + raise ValueError("'ivf_file' index is not supported on Windows. Try to use 'ivf' instead") + # index might not fit into RAM, extract data to files base_path = Path(self.path).parent dump_path = base_path / "dump" diff --git a/tests/unit/executor/test_knowledge_base.py b/tests/unit/executor/test_knowledge_base.py index 1155bd1d1e0..991166e45ab 100644 --- a/tests/unit/executor/test_knowledge_base.py +++ b/tests/unit/executor/test_knowledge_base.py @@ -1,6 +1,7 @@ import time import json import tempfile +import datetime as dt from unittest.mock import patch, MagicMock import threading @@ -31,12 +32,13 @@ def task_monitor(): worker.join() -def dummy_embeddings(string, dimension=None): +def dummy_embeddings(string, dimension=None, base=None): # Imitates embedding generation: create vectors which are similar for similar words in inputs if dimension is None: dimension = 25**2 embeds = [0] * dimension - base = 25 + if base is None: + base = 25 string = string.lower().replace(",", " ").replace(".", " ") for word in string.split(): @@ -59,9 +61,9 @@ def dummy_embeddings(string, dimension=None): return embeds -def set_embedding(mock_embedding, dimension=None): +def set_embedding(mock_embedding, dimension=None, base=None): def resp_f(input, *args, **kwargs): - return [dummy_embeddings(s, dimension) for s in input] + return [dummy_embeddings(s, dimension, base) for s in input] mock_embedding().embeddings.side_effect = resp_f @@ -165,7 +167,7 @@ def _get_ral_table(self): return pd.DataFrame(data, columns=["ral", "english", "italian"]) -class TestKB(BaseTestKB): +class TestKBNOAutoBatch(BaseTestKB): def setup_method(self): super().setup_method() from mindsdb.utilities.config import config @@ -479,7 +481,6 @@ def test_kb_partitions(self, mock_handler, mock_embedding): set_embedding(mock_embedding) df = self._get_ral_table() - self.save_file("ral", df) df = pd.concat([df] * 30) # unique ids @@ -626,8 +627,8 @@ def test_kb_algebra(self, mock_embedding): for size in ("big", "middle", "small"): for shape in ("square", "triangle", "circle"): i += 1 - lines.append([i, i, f"{color} {size} {shape}", color, size, shape]) - df = pd.DataFrame(lines, columns=["id", "num", "content", "color", "size", "shape"]) + lines.append([i, i, f"{color} {size} {shape}", color, size, shape, dt.date(2000, 1, i)]) + df = pd.DataFrame(lines, columns=["id", "num", "content", "color", "size", "shape", "valid_date"]) self.save_file("items", df) @@ -736,6 +737,43 @@ def test_kb_algebra(self, mock_embedding): else: assert "small" in content + # -- metadata: like, not like + for query in ("trian%", "%riangl%", "%angle"): + ret = self.run_sql(f"select * from kb_alg where shape like '{query}'") + + # only triangle + assert set(ret["shape"]) == {"triangle"} + + # -- metadata: '>=', '>', '<=', '<' + + ret = self.run_sql("select * from kb_alg where color > 'red'") + # only white + assert set(ret["color"]) == {"white"} + + ret = self.run_sql("select * from kb_alg where color < 'red'") + # only green + assert set(ret["color"]) == {"green"} + + ret = self.run_sql("select * from kb_alg where color <= 'red' and color > 'green'") + # only red + assert set(ret["color"]) == {"red"} + + # filter by int + ret = self.run_sql("select * from kb_alg where num >= 10") + assert ret["num"].min() == 10 + + # filter by date + ret = self.run_sql("select * from kb_alg where valid_date >= '2000-01-15'") + assert ret["valid_date"].min() > "2000-01-14" and ret["valid_date"].min() < "2000-01-16" + + ret = self.run_sql("select * from kb_alg where valid_date < '2000-01-15'") + assert ret["valid_date"].max() > "2000-01-13" and ret["valid_date"].min() < "2000-01-15" + + # -- filter by id and content + ret = self.run_sql("select * from kb_alg where content = 'green' and id < 22") + assert ret["color"][0] == "green" + assert ret["id"].max() < 22 + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") def test_select_allowed_columns(self, mock_embedding): set_embedding(mock_embedding) @@ -1031,6 +1069,11 @@ def test_alter_kb(self, mock_embedding, mock_get_scores): assert kb.params["reranking_model"]["provider"] == "ollama" assert "api_key" not in kb.params["reranking_model"] + # disable reranking model and ensure config is cleared + self.run_sql("ALTER KNOWLEDGE BASE kb1 USING reranking_model = false") + kb = self.db.KnowledgeBase.query.filter_by(name="kb1").first() + assert kb.params["reranking_model"] == {} + @patch("mindsdb.integrations.utilities.rag.rerankers.base_reranker.BaseLLMReranker.get_scores") @patch("mindsdb.interfaces.knowledge_base.llm_client.OpenAI") def test_ollama(self, mock_openai, mock_get_scores): @@ -1204,6 +1247,111 @@ def test_update(self, mock_embedding): assert len(ret) == 1 assert ret["chunk_content"][0] == "dog" + @patch("mindsdb.integrations.utilities.rag.rerankers.base_reranker.BaseLLMReranker.get_scores") + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_reranking(self, mock_embedding, mock_get_scores): + set_embedding(mock_embedding) + + self._create_kb( + "kb_ral", + content_columns=["english"], + reranking_model={ + "provider": "openai", + "model_name": "gpt-3", + "api_key": "embed-key-1", + }, + ) + + df = self._get_ral_table() + self.save_file("ral", df) + + self.run_sql( + """ + insert into kb_ral + select * from files.ral + """ + ) + + # rank from greater to lower + mock_get_scores.side_effect = lambda query, docs: [1 - i / 4 for i in range(len(docs))] + ret = self.run_sql("select * from kb_ral where content='white'") + assert "white" in ret["chunk_content"].iloc[0] + + # reverse rank: from lower to greater. the most semantic result have to be moved back + mock_get_scores.side_effect = lambda query, docs: [i / 4 for i in range(len(docs))] + ret = self.run_sql("select * from kb_ral where content='white'") + assert "white" not in ret["chunk_content"].iloc[0] + + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_hybrid_search(self, mock_embedding): + df = self._get_ral_table() + self.save_file("ral", df) + + set_embedding(mock_embedding) + + self._create_kb("kb_hybrid", content_columns=["english"]) + + self.run_sql("insert into kb_hybrid select * from files.ral") + + # changing embedding config, making semantic search irrelevant + set_embedding(mock_embedding, base=20) + + # white is not at the top + ret = self.run_sql("select * from kb_hybrid where content='white'") + assert "white" not in ret["chunk_content"].iloc[0] + + # but it is when hybrid search is used + ret = self.run_sql(""" + select * from kb_hybrid where content='white' + and hybrid_search_alpha = 0 + """) + assert "white" in ret["chunk_content"].iloc[0] + + # checking alpha=0.5 + ret = self.run_sql(""" + select * from kb_hybrid where content='white' + and hybrid_search = true + """) + assert "white" in ret["chunk_content"].iloc[0] + + # @pytest.mark.slow + @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") + def test_create_index(self, mock_embedding): + set_embedding(mock_embedding) + + df = self._get_ral_table() + + df = pd.concat([df] * 30) + # unique ids + df["id"] = list(map(str, range(len(df)))) + self.save_file("ral", df) + + # create kb, fill it + self._create_kb("kb_ral", content_columns=["english"]) + + self.run_sql("insert into kb_ral select * from files.ral") + + # create index default index (ivf_file, for windows it is ivf) + self.run_sql( + """ + CREATE INDEX ON KNOWLEDGE_BASE kb_ral WITH (nlist=1) + """ + ) + + # check kb works after index was created + ret = self.run_sql("select * from kb_ral where content='white'") + assert "white" in ret["chunk_content"].iloc[0] + + # specified index + self.run_sql( + """ + CREATE INDEX ON KNOWLEDGE_BASE kb_ral + WITH (nlist=1, type='ivf', train_count=50) + """ + ) + ret = self.run_sql("select * from kb_ral where content='white'") + assert "white" in ret["chunk_content"].iloc[0] + class TestKBAutoBatch(BaseTestKB): @patch("mindsdb.interfaces.knowledge_base.controller.LLMClient") From 63c99c9f9b52db1e2b15d9816104650bd579a6ad Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 26 Mar 2026 18:35:17 +0300 Subject: [PATCH 100/125] ruff --- mindsdb/api/executor/sql_query/sql_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/api/executor/sql_query/sql_query.py b/mindsdb/api/executor/sql_query/sql_query.py index 03526f7ea22..0ec9e58a872 100644 --- a/mindsdb/api/executor/sql_query/sql_query.py +++ b/mindsdb/api/executor/sql_query/sql_query.py @@ -314,7 +314,7 @@ def execute_query(self): except Exception as e: if self.run_query is not None: # set error and place where it stopped - self.run_query.on_error(e, step.step_num if 'step' in locals() else -1, self.steps_data) + self.run_query.on_error(e, step.step_num if "step" in locals() else -1, self.steps_data) raise e else: # mark running query as completed From bdee9de42c7acc2156410c7219fb2b0c5602fdb4 Mon Sep 17 00:00:00 2001 From: Sriram-B-Srivatsa <144884365+Sriram-B-Srivatsa@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:55:57 +0530 Subject: [PATCH 101/125] feat(ollama): add support for temperature parameter (#12114) Co-authored-by: Andrey --- .../handlers/ollama_handler/__about__.py | 16 +- .../handlers/ollama_handler/__init__.py | 12 +- .../handlers/ollama_handler/ollama_handler.py | 151 ++++++++++-------- .../tests/test_ollama_handler.py | 45 ++++++ 4 files changed, 145 insertions(+), 79 deletions(-) create mode 100644 mindsdb/integrations/handlers/ollama_handler/tests/test_ollama_handler.py diff --git a/mindsdb/integrations/handlers/ollama_handler/__about__.py b/mindsdb/integrations/handlers/ollama_handler/__about__.py index d379f39e148..37799994782 100644 --- a/mindsdb/integrations/handlers/ollama_handler/__about__.py +++ b/mindsdb/integrations/handlers/ollama_handler/__about__.py @@ -1,9 +1,9 @@ -__title__ = 'MindsDB Ollama handler' -__package_name__ = 'mindsdb_ollama_handler' -__version__ = '0.0.1' +__title__ = "MindsDB Ollama handler" +__package_name__ = "mindsdb_ollama_handler" +__version__ = "0.0.1" __description__ = "MindsDB handler for Ollama" -__author__ = 'MindsDB Inc' -__github__ = 'https://github.com/mindsdb/mindsdb' -__pypi__ = 'https://pypi.org/project/mindsdb/' -__license__ = 'MIT' -__copyright__ = 'Copyright 2023- mindsdb' +__author__ = "MindsDB Inc" +__github__ = "https://github.com/mindsdb/mindsdb" +__pypi__ = "https://pypi.org/project/mindsdb/" +__license__ = "MIT" +__copyright__ = "Copyright 2023- mindsdb" diff --git a/mindsdb/integrations/handlers/ollama_handler/__init__.py b/mindsdb/integrations/handlers/ollama_handler/__init__.py index 806f750edb9..eea6a1903d6 100644 --- a/mindsdb/integrations/handlers/ollama_handler/__init__.py +++ b/mindsdb/integrations/handlers/ollama_handler/__init__.py @@ -1,19 +1,19 @@ from mindsdb.integrations.libs.const import HANDLER_TYPE from .__about__ import __version__ as version, __description__ as description + try: from .ollama_handler import OllamaHandler as Handler + import_error = None except Exception as e: Handler = None import_error = e -title = 'Ollama' -name = 'ollama' +title = "Ollama" +name = "ollama" type = HANDLER_TYPE.ML -icon_path = 'icon.png' +icon_path = "icon.png" permanent = False -__all__ = [ - 'Handler', 'version', 'name', 'type', 'title', 'description', 'import_error', 'icon_path' -] +__all__ = ["Handler", "version", "name", "type", "title", "description", "import_error", "icon_path"] diff --git a/mindsdb/integrations/handlers/ollama_handler/ollama_handler.py b/mindsdb/integrations/handlers/ollama_handler/ollama_handler.py index 5b03b2a1f68..74923d640a8 100644 --- a/mindsdb/integrations/handlers/ollama_handler/ollama_handler.py +++ b/mindsdb/integrations/handlers/ollama_handler/ollama_handler.py @@ -14,38 +14,40 @@ class OllamaHandler(BaseMLEngine): @staticmethod def create_validation(target, args=None, **kwargs): - if 'using' not in args: + if "using" not in args: raise Exception("Ollama engine requires a USING clause! Refer to its documentation for more details.") else: - args = args['using'] + args = args["using"] - if 'model_name' not in args: - raise Exception('`model_name` must be provided in the USING clause.') + if "model_name" not in args: + raise Exception("`model_name` must be provided in the USING clause.") # check ollama service health - connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL) - status = requests.get(connection + '/api/tags').status_code + connection = args.get("ollama_serve_url", OllamaHandler.DEFAULT_SERVE_URL) + status = requests.get(connection + "/api/tags").status_code if status != 200: - raise Exception(f"Ollama service is not working (status `{status}`). Please double check it is running and try again.") # noqa + raise Exception( + f"Ollama service is not working (status `{status}`). Please double check it is running and try again." + ) # noqa def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None: - """ Pull LLM artifacts with Ollama API. """ + """Pull LLM artifacts with Ollama API.""" # arg setter - args = args['using'] - args['target'] = target - connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL) + args = args["using"] + args["target"] = target + connection = args.get("ollama_serve_url", OllamaHandler.DEFAULT_SERVE_URL) def _model_check(): - """ Checks model has been pulled and that it works correctly. """ + """Checks model has been pulled and that it works correctly.""" responses = {} - for endpoint in ['generate', 'embeddings']: + for endpoint in ["generate", "embeddings"]: try: code = requests.post( - connection + f'/api/{endpoint}', + connection + f"/api/{endpoint}", json={ - 'model': args['model_name'], - 'prompt': 'Hello.', - } + "model": args["model_name"], + "prompt": "Hello.", + }, ).status_code responses[endpoint] = code except Exception: @@ -57,19 +59,21 @@ def _model_check(): if 200 not in responses.values(): # pull model (blocking operation) and serve # TODO: point to the engine storage folder instead of default location - connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL) - requests.post(connection + '/api/pull', json={'name': args['model_name']}) + connection = args.get("ollama_serve_url", OllamaHandler.DEFAULT_SERVE_URL) + requests.post(connection + "/api/pull", json={"name": args["model_name"]}) # try one last time responses = _model_check() if 200 not in responses.values(): - raise Exception(f"Ollama model `{args['model_name']}` is not working correctly. Please try pulling this model manually, check it works correctly and try again.") # noqa + raise Exception( + f"Ollama model `{args['model_name']}` is not working correctly. Please try pulling this model manually, check it works correctly and try again." + ) # noqa supported_modes = {k: True if v == 200 else False for k, v in responses.items()} # check if a mode has been provided and if it is valid runnable_modes = [mode for mode, supported in supported_modes.items() if supported] - if 'mode' in args: - if args['mode'] not in runnable_modes: + if "mode" in args: + if args["mode"] not in runnable_modes: raise Exception(f"Mode `{args['mode']}` is not supported by the model `{args['model_name']}`.") # if a mode has not been provided, check if the model supports only one mode @@ -77,11 +81,11 @@ def _model_check(): # if it supports multiple modes, set the default mode to 'generate' else: if len(runnable_modes) == 1: - args['mode'] = runnable_modes[0] + args["mode"] = runnable_modes[0] else: - args['mode'] = 'generate' + args["mode"] = "generate" - self.model_storage.json_set('args', args) + self.model_storage.json_set("args", args) def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame: """ @@ -93,50 +97,63 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame pd.DataFrame: The DataFrame containing row-wise text completions. """ # setup - pred_args = args.get('predict_params', {}) - args = self.model_storage.json_get('args') - model_name, target_col = args['model_name'], args['target'] - prompt_template = pred_args.get('prompt_template', - args.get('prompt_template', 'Answer the following question: {{{{text}}}}')) + pred_args = args.get("predict_params", {}) + args = self.model_storage.json_get("args") + model_name, target_col = args["model_name"], args["target"] + prompt_template = pred_args.get( + "prompt_template", args.get("prompt_template", "Answer the following question: {{{{text}}}}") + ) # prepare prompts prompts, empty_prompt_ids = get_completed_prompts(prompt_template, df) - df['__mdb_prompt'] = prompts + df["__mdb_prompt"] = prompts # setup endpoint - endpoint = args.get('mode', 'generate') + endpoint = args.get("mode", "generate") # call llm completions = [] for i, row in df.iterrows(): if i not in empty_prompt_ids: - connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL) + temperature = pred_args.get("temperature", args.get("temperature")) + + # Options dictionary + options = {} + if temperature is not None: + try: + options["temperature"] = float(temperature) + except ValueError: + pass + + # Calling API with the new options + connection = args.get("ollama_serve_url", OllamaHandler.DEFAULT_SERVE_URL) raw_output = requests.post( - connection + f'/api/{endpoint}', + connection + f"/api/{endpoint}", json={ - 'model': model_name, - 'prompt': row['__mdb_prompt'], - } + "model": model_name, + "prompt": row["__mdb_prompt"], + "options": options, # options passed here + }, ) - lines = raw_output.content.decode().split('\n') # stream of output tokens + lines = raw_output.content.decode().split("\n") # stream of output tokens values = [] for line in lines: - if line != '': + if line != "": info = json.loads(line) - if 'response' in info: - token = info['response'] + if "response" in info: + token = info["response"] values.append(token) - elif 'embedding' in info: - embedding = info['embedding'] + elif "embedding" in info: + embedding = info["embedding"] values.append(embedding) - if endpoint == 'embeddings': + if endpoint == "embeddings": completions.append(values) else: - completions.append(''.join(values)) + completions.append("".join(values)) else: - completions.append('') + completions.append("") # consolidate output data = pd.DataFrame(completions) @@ -144,28 +161,32 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame return data def describe(self, attribute: Optional[str] = None) -> pd.DataFrame: - args = self.model_storage.json_get('args') - model_name, target_col = args['model_name'], args['target'] - prompt_template = args.get('prompt_template', 'Answer the following question: {{{{text}}}}') + args = self.model_storage.json_get("args") + model_name, target_col = args["model_name"], args["target"] + prompt_template = args.get("prompt_template", "Answer the following question: {{{{text}}}}") if attribute == "features": - return pd.DataFrame([[target_col, prompt_template]], columns=['target_column', 'mindsdb_prompt_template']) + return pd.DataFrame([[target_col, prompt_template]], columns=["target_column", "mindsdb_prompt_template"]) # get model info else: - connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL) - model_info = requests.post(connection + '/api/show', json={'name': model_name}).json() - return pd.DataFrame([[ - model_name, - model_info.get('license', 'N/A'), - model_info.get('modelfile', 'N/A'), - model_info.get('parameters', 'N/A'), - model_info.get('template', 'N/A'), - ]], + connection = args.get("ollama_serve_url", OllamaHandler.DEFAULT_SERVE_URL) + model_info = requests.post(connection + "/api/show", json={"name": model_name}).json() + return pd.DataFrame( + [ + [ + model_name, + model_info.get("license", "N/A"), + model_info.get("modelfile", "N/A"), + model_info.get("parameters", "N/A"), + model_info.get("template", "N/A"), + ] + ], columns=[ - 'model_type', - 'license', - 'modelfile', - 'parameters', - 'ollama_base_template', - ]) + "model_type", + "license", + "modelfile", + "parameters", + "ollama_base_template", + ], + ) diff --git a/mindsdb/integrations/handlers/ollama_handler/tests/test_ollama_handler.py b/mindsdb/integrations/handlers/ollama_handler/tests/test_ollama_handler.py new file mode 100644 index 00000000000..b06caaae4e6 --- /dev/null +++ b/mindsdb/integrations/handlers/ollama_handler/tests/test_ollama_handler.py @@ -0,0 +1,45 @@ +import unittest +from unittest.mock import patch, Mock +import pandas as pd +from mindsdb.integrations.handlers.ollama_handler.ollama_handler import OllamaHandler + + +class TestOllamaHandler(unittest.TestCase): + def setUp(self): + # Mock the storage to return valid model configuration + mock_storage = Mock() + mock_storage.json_get.return_value = { + "model_name": "tinyllama", + "target": "response", + "ollama_serve_url": "http://localhost:11434", + } + + # Initialize handler with mocked storage + self.handler = OllamaHandler(name="test_ollama", model_storage=mock_storage, engine_storage={}) + + @patch("mindsdb.integrations.handlers.ollama_handler.ollama_handler.requests.post") + def test_temperature_passing(self, mock_post): + """ + Test that the temperature parameter is correctly extracted from args + and passed to the Ollama API options. + """ + # Setup mock response + mock_response = Mock() + mock_response.content = b'{"response": "Test response"}' + mock_post.return_value = mock_response + + # Create input dataframe + df = pd.DataFrame({"text": ["Hello"]}) + + # Execute prediction with temperature argument + self.handler.predict(df, args={"predict_params": {"temperature": 0.5}}) + + # Verify API call payload + call_args = mock_post.call_args[1]["json"] + + self.assertIn("options", call_args) + self.assertEqual(call_args["options"]["temperature"], 0.5) + + +if __name__ == "__main__": + unittest.main() From 9e2392afa00c1a0e6a33f9510a273246e9c2dbaf Mon Sep 17 00:00:00 2001 From: "Vignesh S.M" <90998381+vigbav36@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:10:05 +0530 Subject: [PATCH 102/125] Feature - Freshdesk Integration (#11683) Co-authored-by: Andrey --- .../handlers/freshdesk_handler/README.md | 65 +++++ .../handlers/freshdesk_handler/__about__.py | 9 + .../handlers/freshdesk_handler/__init__.py | 30 ++ .../freshdesk_handler/connection_args.py | 22 ++ .../freshdesk_handler/freshdesk_handler.py | 97 +++++++ .../freshdesk_handler/freshdesk_tables.py | 263 ++++++++++++++++++ .../handlers/freshdesk_handler/icon.svg | 1 + .../freshdesk_handler/requirements.txt | 1 + .../freshdesk_handler/tests/__init__.py | 1 + .../tests/test_freshdesk_handler.py | 207 ++++++++++++++ 10 files changed, 696 insertions(+) create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/README.md create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/__about__.py create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/__init__.py create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/connection_args.py create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/freshdesk_handler.py create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/freshdesk_tables.py create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/icon.svg create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/requirements.txt create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/tests/__init__.py create mode 100644 mindsdb/integrations/handlers/freshdesk_handler/tests/test_freshdesk_handler.py diff --git a/mindsdb/integrations/handlers/freshdesk_handler/README.md b/mindsdb/integrations/handlers/freshdesk_handler/README.md new file mode 100644 index 00000000000..82190ed9426 --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/README.md @@ -0,0 +1,65 @@ +# Freshdesk Integration + +This documentation describes the integration of MindsDB with [Freshdesk](https://www.Freshdesk.com/), which provides software-as-a-service products related to customer support, sales, and other customer communications. + +The integration allows MindsDB to access data from Freshdesk and enhance it with AI capabilities. + +## Prerequisites + +Before proceeding, ensure the following prerequisites are met: + +1. Install MindsDB locally via [Docker](https://docs.mindsdb.com/setup/self-hosted/docker) or [Docker Desktop](https://docs.mindsdb.com/setup/self-hosted/docker-desktop). +2. To connect Freshdesk to MindsDB, install the required dependencies following [this instruction](https://docs.mindsdb.com/setup/self-hosted/docker#install-dependencies). + +## Connection + +Establish a connection to Freshdesk from MindsDB by executing the following SQL command and providing its [handler name](https://github.com/mindsdb/mindsdb/tree/main/mindsdb/integrations/handlers/Freshdesk_handler) as an engine. + +```sql +CREATE DATABASE freshdesk_datasource +WITH + ENGINE = 'freshdesk', + PARAMETERS = { + "api_key":"your_api_key_here", + "domain": "yourcompany.freshdesk.com" + }; +``` + +Required connection parameters include the following: + +* `api_key`: The API key for the Freshdesk account. +* `domain`: The Freshdesk domain (e.g., yourcompany.freshdesk.com). + + +For enabling, generating and deleting API access, refer [Managing access to the Freshdesk API](https://support.Freshdesk.com/hc/en-us/articles/4408889192858-Managing-access-to-the-Freshdesk-API) + + +## Usage + +Retrieve data from a specified table by providing the integration and table names: + +```sql +SELECT * +FROM freshdesk_datasource.table_name +LIMIT 10; +``` + +Retrieve data for a specific ticket by providing the id: + +```sql +SELECT * +FROM freshdesk_datasource.tickets +where id=""; +``` + + + +The above examples utilize `freshdesk_datasource` as the datasource name, which is defined in the `CREATE DATABASE` command. + + +## Supported Tables + +The Freshdesk integration supports the following tables: + +* `agents` : The table lists all the agents. +* `tickets` : The table lists all the tickets. \ No newline at end of file diff --git a/mindsdb/integrations/handlers/freshdesk_handler/__about__.py b/mindsdb/integrations/handlers/freshdesk_handler/__about__.py new file mode 100644 index 00000000000..da26a08424e --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/__about__.py @@ -0,0 +1,9 @@ +__title__ = "MindsDB Freshdesk handler" +__package_name__ = "mindsdb_freshdesk_handler" +__version__ = "0.0.1" +__description__ = "MindsDB handler for Freshdesk" +__author__ = "Vignesh S M" +__github__ = "https://github.com/mindsdb/mindsdb" +__pypi__ = "https://pypi.org/project/mindsdb/" +__license__ = "MIT" +__copyright__ = "Copyright 2023 - mindsdb" diff --git a/mindsdb/integrations/handlers/freshdesk_handler/__init__.py b/mindsdb/integrations/handlers/freshdesk_handler/__init__.py new file mode 100644 index 00000000000..ca82e320104 --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/__init__.py @@ -0,0 +1,30 @@ +from mindsdb.integrations.libs.const import HANDLER_TYPE + +from .__about__ import __version__ as version, __description__ as description +from .connection_args import connection_args, connection_args_example + +try: + from .freshdesk_handler import FreshdeskHandler as Handler + + import_error = None # noqa +except Exception as e: + Handler = None + import_error = e + +title = "Freshdesk" +name = "freshdesk" +type = HANDLER_TYPE.DATA +icon_path = "icon.svg" + +__all__ = [ + "Handler", + "version", + "name", + "type", + "title", + "description", + "import_error", + "icon_path", + "connection_args_example", + "connection_args", +] diff --git a/mindsdb/integrations/handlers/freshdesk_handler/connection_args.py b/mindsdb/integrations/handlers/freshdesk_handler/connection_args.py new file mode 100644 index 00000000000..f279bd7d459 --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/connection_args.py @@ -0,0 +1,22 @@ +from collections import OrderedDict + +from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE + + +connection_args = OrderedDict( + api_key={ + "type": ARG_TYPE.STR, + "description": "Freshdesk API key", + "required": True, + "label": "api_key", + "secret": True, + }, + domain={ + "type": ARG_TYPE.STR, + "description": "Freshdesk domain (e.g., yourcompany.freshdesk.com)", + "required": True, + "label": "domain", + }, +) + +connection_args_example = OrderedDict(api_key="your_api_key_here", domain="yourcompany.freshdesk.com") diff --git a/mindsdb/integrations/handlers/freshdesk_handler/freshdesk_handler.py b/mindsdb/integrations/handlers/freshdesk_handler/freshdesk_handler.py new file mode 100644 index 00000000000..cc1dde9e9ac --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/freshdesk_handler.py @@ -0,0 +1,97 @@ +from mindsdb_sql_parser import parse_sql + +from mindsdb.integrations.handlers.freshdesk_handler.freshdesk_tables import FreshdeskAgentsTable, FreshdeskTicketsTable + +from mindsdb.integrations.libs.api_handler import APIHandler +from mindsdb.integrations.libs.response import ( + HandlerStatusResponse as StatusResponse, +) +from mindsdb.utilities import log +from freshdesk.v2.api import API + +logger = log.getLogger(__name__) + + +class FreshdeskHandler(APIHandler): + """The Freshdesk handler implementation""" + + def __init__(self, name: str, **kwargs): + """Initialize the freshdesk handler. + + Parameters + ---------- + name : str + name of a handler instance + """ + super().__init__(name) + + connection_data = kwargs.get("connection_data", {}) + self.connection_data = connection_data + self.kwargs = kwargs + self.freshdesk_client: API = None + self.is_connected = False + + self._register_table("agents", FreshdeskAgentsTable(self)) + self._register_table("tickets", FreshdeskTicketsTable(self)) + + def connect(self) -> StatusResponse: + """Set up the connection required by the handler. + + Returns + ------- + StatusResponse + connection object + """ + resp = StatusResponse(False) + try: + if not self.connection_data.get("domain"): + raise ValueError("Missing required parameter: domain") + if not self.connection_data.get("api_key"): + raise ValueError("Missing required parameter: api_key") + + self.freshdesk_client = API(domain=self.connection_data["domain"], api_key=self.connection_data["api_key"]) + # Test the connection by getting new tickets + self.freshdesk_client.tickets.list_new_and_my_open_tickets(page=1, per_page=1) + self.is_connected = True + resp.success = True + except KeyError as ex: + resp.success = False + resp.error_message = f"Missing required connection parameter: {str(ex)}" + self.is_connected = False + except ValueError as ex: + resp.success = False + resp.error_message = str(ex) + self.is_connected = False + except Exception as ex: + resp.success = False + resp.error_message = f"Failed to connect to Freshdesk: {str(ex)}" + self.is_connected = False + return resp + + def check_connection(self) -> StatusResponse: + """Check connection to the handler. + + Returns + ------- + StatusResponse + Status confirmation + """ + response = self.connect() + self.is_connected = response.success + return response + + def native_query(self, query: str) -> StatusResponse: + """Receive and process a raw query. + + Parameters + ---------- + query : str + query in a native format + + Returns + ------- + StatusResponse + Request status + """ + ast = parse_sql(query) + return self.query(ast) diff --git a/mindsdb/integrations/handlers/freshdesk_handler/freshdesk_tables.py b/mindsdb/integrations/handlers/freshdesk_handler/freshdesk_tables.py new file mode 100644 index 00000000000..e896a78cb14 --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/freshdesk_tables.py @@ -0,0 +1,263 @@ +import pandas as pd +from typing import List, Dict, Tuple +from mindsdb.integrations.libs.api_handler import APITable +from mindsdb.integrations.utilities.handlers.query_utilities import ( + SELECTQueryParser, + SELECTQueryExecutor, +) +from mindsdb.utilities import log +from mindsdb_sql_parser import ast +from urllib.parse import quote + +logger = log.getLogger(__name__) + + +class FreshdeskAgentsTable(APITable): + """Freshdesk Agents Table implementation""" + + def select(self, query: ast.Select) -> pd.DataFrame: + """Pulls data from the freshdesk list agents API + + Parameters + ---------- + query : ast.Select + Given SQL SELECT query + + Returns + ------- + pd.DataFrame + Freshdesk agents + """ + + select_statement_parser = SELECTQueryParser(query, "agents", self.get_columns()) + selected_columns, where_conditions, order_by_conditions, result_limit = select_statement_parser.parse_query() + + subset_where_conditions, filter_conditions = self.get_conditions(where_conditions) + + df = self.get_freshdesk_agents(filter_conditions) + + select_statement_executor = SELECTQueryExecutor( + df, selected_columns, subset_where_conditions, order_by_conditions, result_limit + ) + df = select_statement_executor.execute_query() + return df + + def get_conditions(self, where_conditions) -> Tuple: + subset_where_conditions = [] + filter_conditions = {} + + for op, arg1, arg2 in where_conditions: + if arg1 in self.get_columns(): + if arg1 in self.get_api_filter_columns() and op == "=": + filter_conditions[self.get_api_filter_columns()[arg1]] = arg2 + else: + subset_where_conditions.append([op, arg1, arg2]) + return subset_where_conditions, filter_conditions + + def get_freshdesk_agents(self, api_filters): + agents = self.handler.freshdesk_client.agents.list_agents(**api_filters) + response = [] + + for agent in agents: + response.append(self.agent_to_dict(agent)) + + return pd.json_normalize(response, sep="_").reindex(columns=self.get_columns(), fill_value=None) + + def get_columns(self) -> List[str]: + """Gets all columns to be returned in pandas DataFrame responses""" + return [ + "available", + "occasional", + "id", + "ticket_scope", + "created_at", + "updated_at", + "last_active_at", + "available_since", + "type", + "deactivated", + "signature", + "focus_mode", + "contact_active", + "contact_email", + "contact_job_title", + "contact_language", + "contact_last_login_at", + "contact_mobile", + "contact_name", + "contact_phone", + "contact_time_zone", + "contact_created_at", + "contact_updated_at", + ] + + def get_api_filter_columns(self) -> Dict[str, str]: + """Gets all columns that can be used to filter through the API directly""" + return { + "contact_email": "email", + "contact_mobile": "mobile", + "contact_phone": "phone", + "contact_state": "state", + } + + def agent_to_dict(self, agent): + dict = {col: getattr(agent, col, None) for col in self.get_columns()} + dict["contact"] = getattr(agent, "contact", None) + return dict + + +class FreshdeskTicketsTable(APITable): + """Freshdesk Tickets Table implementation""" + + PRIORITY_MAP = {"low": 1, "medium": 2, "high": 3, "urgent": 4} + STATUS_MAP = {"open": 2, "pending": 3, "resolved": 4, "closed": 5} + + def select(self, query: ast.Select) -> pd.DataFrame: + """Pulls data from the freshdesk list tickets API + + Parameters + ---------- + query : ast.Select + Given SQL SELECT query + + Returns + ------- + pd.DataFrame + Freshdesk tickets + """ + + select_statement_parser = SELECTQueryParser(query, "tickets", self.get_columns()) + + selected_columns, where_conditions, order_by_conditions, result_limit = select_statement_parser.parse_query() + + subset_where_conditions, filter_conditions = self.get_conditions(where_conditions) + + df = self.get_freshdesk_tickets(filter_conditions) + + select_statement_executor = SELECTQueryExecutor( + df, selected_columns, subset_where_conditions, order_by_conditions, result_limit + ) + + df = select_statement_executor.execute_query() + return df + + def get_conditions(self, where_conditions) -> Tuple: + subset_where_conditions = [] + search_conditions = [] + + for op, arg1, val in where_conditions: + if arg1 in self.get_api_filter_columns() and op in self.get_operator_map().keys(): + if arg1 == "priority" and isinstance(val, str): + val = self.PRIORITY_MAP.get(val.lower(), val) + if arg1 == "status" and isinstance(val, str): + val = self.STATUS_MAP.get(val.lower(), val) + search_conditions.append((op, arg1, val)) + else: + subset_where_conditions.append([op, arg1, val]) + + return subset_where_conditions, search_conditions + + def get_freshdesk_tickets(self, filter_conditions): + if len(filter_conditions) > 0: + tickets = self.handler.freshdesk_client.tickets.filter_tickets( + query=self.build_freshdesk_api_filter_query(filter_conditions) + ) + else: + tickets = self.handler.freshdesk_client.tickets.list_tickets(filter_name=None) + + response = [] + + for ticket in tickets: + response.append(self.ticket_to_dict(ticket)) + + return pd.json_normalize(response, sep="_").reindex(columns=self.get_columns(), fill_value=None) + + def build_freshdesk_api_filter_query(self, conditions): + """ + Build Freshdesk API filter query string, quoting strings and mapping enums. + """ + + op_map = self.get_operator_map() + parts = [] + + for op, field, value in conditions: + freshdesk_operator = op_map.get(op) + if freshdesk_operator is None: + raise ValueError(f"Unsupported operator: {op}") + + if isinstance(value, str): + escaped_value = value.replace("'", "'") + value_str = f"'{escaped_value}'" + else: + value_str = str(value) + + parts.append(f"{field}:{value_str}") + + query_string = " AND ".join(parts) + return quote(query_string) + + def get_operator_map(self): + """Mapping of sql where operators to freshdesk API query operators""" + return { + "=": ":", + ">": ":>", + "<": ":<", + ">=": ":>", + "<=": ":<", + } + + def ticket_to_dict(self, ticket): + return {col: getattr(ticket, col, None) for col in self.get_columns()} + + def get_columns(self) -> List[str]: + """Gets all columns to be returned in pandas DataFrame responses""" + return [ + "attachments", + "cc_emails", + "company_id", + "custom_fields", + "deleted", + "description", + "description_text", + "due_by", + "email", + "email_config_id", + "facebook_id", + "fr_due_by", + "fr_escalated", + "fwd_emails", + "group_id", + "id", + "is_escalated", + "name", + "phone", + "priority", + "product_id", + "reply_cc_emails", + "requester_id", + "responder_id", + "source", + "spam", + "status", + "subject", + "tags", + "to_emails", + "twitter_id", + "type", + "created_at", + "updated_at", + ] + + def get_api_filter_columns(self) -> Dict[str, str]: + """Gets all columns that can be used to filter through the API directly""" + + return { + "status": "status", + "priority": "priority", + "type": "type", + "group_id": "group_id", + "agent_id": "agent_id", + "created_at": "created_at", + "updated_at": "updated_at", + "fr_due_by": "fr_due_by", + } diff --git a/mindsdb/integrations/handlers/freshdesk_handler/icon.svg b/mindsdb/integrations/handlers/freshdesk_handler/icon.svg new file mode 100644 index 00000000000..ecb5c6e0c6f --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/mindsdb/integrations/handlers/freshdesk_handler/requirements.txt b/mindsdb/integrations/handlers/freshdesk_handler/requirements.txt new file mode 100644 index 00000000000..07a7ac52291 --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/requirements.txt @@ -0,0 +1 @@ +python-freshdesk \ No newline at end of file diff --git a/mindsdb/integrations/handlers/freshdesk_handler/tests/__init__.py b/mindsdb/integrations/handlers/freshdesk_handler/tests/__init__.py new file mode 100644 index 00000000000..f83bf577364 --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/tests/__init__.py @@ -0,0 +1 @@ +# Tests for Freshdesk Handler diff --git a/mindsdb/integrations/handlers/freshdesk_handler/tests/test_freshdesk_handler.py b/mindsdb/integrations/handlers/freshdesk_handler/tests/test_freshdesk_handler.py new file mode 100644 index 00000000000..2aa4cbc79e5 --- /dev/null +++ b/mindsdb/integrations/handlers/freshdesk_handler/tests/test_freshdesk_handler.py @@ -0,0 +1,207 @@ +import unittest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd +from mindsdb_sql_parser import parse_sql + +from mindsdb.integrations.handlers.freshdesk_handler.freshdesk_handler import FreshdeskHandler +from mindsdb.integrations.handlers.freshdesk_handler.freshdesk_tables import ( + FreshdeskAgentsTable, + FreshdeskTicketsTable, +) + + +class TestFreshdeskHandler(unittest.TestCase): + """Test cases for Freshdesk Handler""" + + @classmethod + def setUpClass(cls): + """Set up test fixtures before running tests.""" + cls.kwargs = {"connection_data": {"domain": "test.freshdesk.com", "api_key": "test_api_key_123"}} + cls.handler = FreshdeskHandler("test_freshdesk_handler", **cls.kwargs) + cls.agents_table = FreshdeskAgentsTable(cls.handler) + cls.tickets_table = FreshdeskTicketsTable(cls.handler) + + def setUp(self): + """Set up test fixtures before each test method.""" + # Mock the freshdesk client + self.mock_client = Mock() + self.handler.freshdesk_client = self.mock_client + self.handler.is_connected = True + + def _get_agents_mock_data(self, num_records=3): + """Helper method to create mock agents data.""" + return pd.DataFrame( + { + "id": list(range(1, num_records + 1)), + "available": ([True, False, True] * ((num_records // 3) + 1))[:num_records], + "contact_email": [f"agent{i}@test.com" for i in range(1, num_records + 1)], + "contact_mobile": [f"123456789{i}" for i in range(1, num_records + 1)], + "contact_name": [f"Agent {i}" for i in range(1, num_records + 1)], + } + ) + + def _get_tickets_mock_data(self, num_records=3, custom_data=None): + """Helper method to create mock tickets data.""" + base_data = { + "id": list(range(1, num_records + 1)), + "status": list(range(2, num_records + 2)), + "priority": list(range(1, num_records + 1)), + "subject": [f"Issue {i}" for i in range(1, num_records + 1)], + "group_id": list(range(1, num_records + 1)), + } + + # Override with custom data if provided + if custom_data: + base_data.update(custom_data) + + return pd.DataFrame(base_data) + + def test_agents_table_get_conditions(self): + """Test get_conditions method for agents table.""" + where_conditions = [ + ["=", "contact_email", "test@example.com"], + [">", "id", 100], + ["=", "contact_mobile", "+1234567890"], + ] + + subset_conditions, filter_conditions = self.agents_table.get_conditions(where_conditions) + + # Check that API filter conditions are properly extracted + expected_filter_conditions = {"email": "test@example.com", "mobile": "+1234567890"} + self.assertEqual(filter_conditions, expected_filter_conditions) + + # Check that non-API filter conditions are in subset + expected_subset = [[">", "id", 100]] + self.assertEqual(subset_conditions, expected_subset) + + def test_agents_table_select_basic(self): + """Test basic select query for agents table.""" + mock_df = self._get_agents_mock_data() + + with patch.object(self.agents_table, "get_freshdesk_agents", return_value=mock_df): + query = "SELECT id, available FROM agents" + ast = parse_sql(query) + result = self.agents_table.select(ast) + + self.assertIsInstance(result, pd.DataFrame) + self.assertEqual(len(result), 3) + self.assertIn("id", result.columns) + self.assertIn("available", result.columns) + + def test_agents_table_select_with_where(self): + """Test select query with WHERE clause for agents table.""" + mock_df = self._get_agents_mock_data() + + with patch.object(self.agents_table, "get_freshdesk_agents", return_value=mock_df): + query = "SELECT id, available, contact_email FROM agents WHERE contact_email = 'agent1@test.com'" + ast = parse_sql(query) + result = self.agents_table.select(ast) + + self.assertIsInstance(result, pd.DataFrame) + + def test_tickets_table_get_conditions(self): + """Test get_conditions method for tickets table.""" + where_conditions = [["=", "status", "open"], ["=", "priority", "high"], [">", "id", 100], ["=", "group_id", 5]] + + subset_conditions, search_conditions = self.tickets_table.get_conditions(where_conditions) + + # Check that API filter conditions are properly extracted + expected_search_conditions = [("=", "status", 2), ("=", "priority", 3), ("=", "group_id", 5)] + self.assertEqual(search_conditions, expected_search_conditions) + + # Check that non-API filter conditions are in subset + expected_subset = [[">", "id", 100]] + self.assertEqual(subset_conditions, expected_subset) + + def test_tickets_table_priority_status_mapping(self): + """Test priority and status mapping in get_conditions.""" + where_conditions = [["=", "priority", "urgent"], ["=", "status", "closed"]] + + subset_conditions, search_conditions = self.tickets_table.get_conditions(where_conditions) + + # Check that string values are mapped to numbers + expected_search_conditions = [("=", "priority", 4), ("=", "status", 5)] + self.assertEqual(search_conditions, expected_search_conditions) + + def test_tickets_table_build_freshdesk_api_filter_query(self): + """Test build_freshdesk_api_filter_query method.""" + conditions = [("=", "status", 2), ("=", "priority", 3)] + + result = self.tickets_table.build_freshdesk_api_filter_query(conditions) + + # Should return a URL-encoded query string + self.assertIn("status%3A2", result) + self.assertIn("priority%3A3", result) + self.assertIn("AND", result) + + def test_tickets_table_build_freshdesk_api_filter_query_with_strings(self): + """Test build_freshdesk_api_filter_query with string values.""" + conditions = [("=", "status", "open"), ("=", "priority", "high")] + + result = self.tickets_table.build_freshdesk_api_filter_query(conditions) + + # Should handle string values with quotes + self.assertIn("status%3A%27open%27", result) + self.assertIn("priority%3A%27high%27", result) + + def test_tickets_table_select_basic(self): + """Test basic select query for tickets table.""" + mock_df = self._get_tickets_mock_data() + + with patch.object(self.tickets_table, "get_freshdesk_tickets", return_value=mock_df): + query = "SELECT id, status, subject FROM tickets" + ast = parse_sql(query) + result = self.tickets_table.select(ast) + + self.assertIsInstance(result, pd.DataFrame) + self.assertEqual(len(result), 3) + self.assertIn("id", result.columns) + self.assertIn("status", result.columns) + self.assertIn("subject", result.columns) + + def test_tickets_table_select_with_where(self): + """Test select query with WHERE clause for tickets table.""" + mock_df = self._get_tickets_mock_data() + + with patch.object(self.tickets_table, "get_freshdesk_tickets", return_value=mock_df): + query = "SELECT id, status, subject FROM tickets WHERE status = 'open'" + ast = parse_sql(query) + result = self.tickets_table.select(ast) + + self.assertIsInstance(result, pd.DataFrame) + + def test_tickets_table_select_with_limit(self): + """Test select query with LIMIT clause for tickets table.""" + mock_df = self._get_tickets_mock_data(num_records=5) + + with patch.object(self.tickets_table, "get_freshdesk_tickets", return_value=mock_df): + query = "SELECT id, status, subject FROM tickets LIMIT 3" + ast = parse_sql(query) + result = self.tickets_table.select(ast) + + self.assertIsInstance(result, pd.DataFrame) + self.assertLessEqual(len(result), 3) + + def test_tickets_table_select_with_order_by(self): + """Test select query with ORDER BY clause for tickets table.""" + # Create custom data with different order for testing + custom_data = { + "id": [3, 1, 2], + "status": [4, 2, 3], + "priority": [3, 1, 2], + "subject": ["Issue 3", "Issue 1", "Issue 2"], + "group_id": [3, 1, 2], + } + mock_df = self._get_tickets_mock_data(custom_data=custom_data) + + with patch.object(self.tickets_table, "get_freshdesk_tickets", return_value=mock_df): + query = "SELECT id, status, subject FROM tickets ORDER BY id" + ast = parse_sql(query) + result = self.tickets_table.select(ast) + + self.assertIsInstance(result, pd.DataFrame) + self.assertEqual(len(result), 3) + + +if __name__ == "__main__": + unittest.main() From 25aa0651502f56c2ce5ae119b16a8089d8f74994 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 27 Mar 2026 13:42:15 +0300 Subject: [PATCH 103/125] ruff --- .../handlers/freshdesk_handler/tests/test_freshdesk_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/freshdesk_handler/tests/test_freshdesk_handler.py b/mindsdb/integrations/handlers/freshdesk_handler/tests/test_freshdesk_handler.py index 2aa4cbc79e5..3d91c03ed55 100644 --- a/mindsdb/integrations/handlers/freshdesk_handler/tests/test_freshdesk_handler.py +++ b/mindsdb/integrations/handlers/freshdesk_handler/tests/test_freshdesk_handler.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch import pandas as pd from mindsdb_sql_parser import parse_sql From ba41c3adb5aeaa35659bec14c945f9957c9bc03a Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 27 Mar 2026 14:23:42 +0300 Subject: [PATCH 104/125] import checks --- tests/scripts/check_requirements.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index 08766c9b2cd..319d94fcbcc 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -145,6 +145,8 @@ def get_requirements_with_DEP002(path): CHROMADB_EP002_IGNORE_HANDLER_DEPS = ["onnxruntime"] +FRESHDESK_EP002_IGNORE_HANDLER_DEPS = ["python-freshdesk"] + # The `pyarrow` package is used only if it is installed. # The handler can work without it. SNOWFLAKE_DEP003_IGNORE_HANDLER_DEPS = ["pyarrow"] @@ -160,6 +162,7 @@ def get_requirements_with_DEP002(path): + SOLR_DEP002_IGNORE_HANDLER_DEPS + OPENAI_DEP002_IGNORE_HANDLER_DEPS + CHROMADB_EP002_IGNORE_HANDLER_DEPS + + FRESHDESK_EP002_IGNORE_HANDLER_DEPS ) ) From 4eae7cb8ade4075e5ad3e94bd45d8d47ea8cc0f8 Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 27 Mar 2026 14:25:24 +0300 Subject: [PATCH 105/125] import checks --- tests/scripts/check_requirements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index 319d94fcbcc..8aa0f0f8bf8 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -180,6 +180,7 @@ def get_requirements_with_DEP002(path): "IfxPyDbi", "ingres_sa_dialect", "pyodbc", + "freshdesk", ], # 'tests' is the mindsdb tests folder in the repo root, 'pyarrow' used in snowflake handler "DEP003": DEP003_IGNORE_HANDLER_DEPS, } From 307d855c8a1ad69906485393a030fe858abf5594 Mon Sep 17 00:00:00 2001 From: neversettle <41864816+neversettle17-101@users.noreply.github.com> Date: Fri, 27 Mar 2026 17:24:21 +0530 Subject: [PATCH 106/125] Fix: Create Knowledge base without explicitly providing provider configs is successful for non-openAI scenarios (#11975) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michael Olayemi Olawepo <154475559+sejubar@users.noreply.github.com> Co-authored-by: andrew Co-authored-by: April I. Murphy <36110273+aimurphy@users.noreply.github.com> Co-authored-by: Minura Punchihewa <49385643+MinuraPunchihewa@users.noreply.github.com> Co-authored-by: Konstantin Sivakov Co-authored-by: martyna-mindsdb <109554435+martyna-mindsdb@users.noreply.github.com> Co-authored-by: Sebastián Tobón Hernández --- mindsdb/integrations/utilities/rag/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindsdb/integrations/utilities/rag/settings.py b/mindsdb/integrations/utilities/rag/settings.py index 56a8306295f..ac42e82a884 100644 --- a/mindsdb/integrations/utilities/rag/settings.py +++ b/mindsdb/integrations/utilities/rag/settings.py @@ -696,7 +696,7 @@ def _missing_(cls, value): class RerankerConfig(BaseModel): model: str = DEFAULT_RERANKING_MODEL - base_url: str = DEFAULT_LLM_ENDPOINT + base_url: Optional[str] = None filtering_threshold: float = 0.5 num_docs_to_keep: Optional[int] = None mode: RerankerMode = Field( From bb97917154c73b65fc9e126bf411988947b4827f Mon Sep 17 00:00:00 2001 From: Jorge Torres Date: Fri, 27 Mar 2026 05:52:07 -0700 Subject: [PATCH 107/125] Fix boot error messages (#11626) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Zoran Pandovski Co-authored-by: Michael Olayemi Olawepo <154475559+sejubar@users.noreply.github.com> Co-authored-by: andrew Co-authored-by: April I. Murphy <36110273+aimurphy@users.noreply.github.com> Co-authored-by: Minura Punchihewa <49385643+MinuraPunchihewa@users.noreply.github.com> Co-authored-by: Konstantin Sivakov Co-authored-by: martyna-mindsdb <109554435+martyna-mindsdb@users.noreply.github.com> Co-authored-by: Sebastián Tobón Hernández --- mindsdb/utilities/log.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mindsdb/utilities/log.py b/mindsdb/utilities/log.py index eb174f54c2a..2ae311a61da 100644 --- a/mindsdb/utilities/log.py +++ b/mindsdb/utilities/log.py @@ -4,10 +4,15 @@ import logging import threading from typing import Any +import warnings from logging.config import dictConfig from mindsdb.utilities.config import config as app_config +# Suppress Pydantic warnings for third-party libraries +# TODO: Work on a better solution to this +warnings.filterwarnings("ignore", message="Field.*has conflict with protected namespace.*", category=UserWarning) + logging_initialized = False From 100e10c051ce62d0f0ad358fa7a7b698c045bd80 Mon Sep 17 00:00:00 2001 From: MirzaSamadAhmedBaig <89132160+Mirza-Samad-Ahmed-Baig@users.noreply.github.com> Date: Mon, 30 Mar 2026 12:18:36 +0300 Subject: [PATCH 108/125] Fix OpenBB command execution by removing eval (#12290) --- .../handlers/openbb_handler/openbb_tables.py | 166 +++++++++++------- tests/unit/handlers/test_openbb_tables.py | 96 ++++++++++ 2 files changed, 196 insertions(+), 66 deletions(-) create mode 100644 tests/unit/handlers/test_openbb_tables.py diff --git a/mindsdb/integrations/handlers/openbb_handler/openbb_tables.py b/mindsdb/integrations/handlers/openbb_handler/openbb_tables.py index 69cf5631ecc..e3002d0cd45 100644 --- a/mindsdb/integrations/handlers/openbb_handler/openbb_tables.py +++ b/mindsdb/integrations/handlers/openbb_handler/openbb_tables.py @@ -8,6 +8,7 @@ from typing import Dict, List, Union from pydantic import ValidationError +import ast as py_ast import pandas as pd @@ -15,6 +16,48 @@ class OpenBBtable(APITable): + def _resolve_openbb_command(self, cmd: str): + """Resolve a validated OpenBB command to a callable.""" + if not isinstance(cmd, str): + raise TypeError("OpenBB command must be a string.") + + parts = cmd.split(".") + if len(parts) < 2 or parts[0] != "obb": + raise ValueError("OpenBB command must start with 'obb.'") + + target = self.handler + for part in parts: + if not part.isidentifier() or part.startswith("_"): + raise ValueError(f"Invalid OpenBB command segment: {part}") + target = getattr(target, part) + + if not callable(target): + raise TypeError(f"OpenBB command '{cmd}' is not callable.") + + return target + + def _coerce_param_value(self, value): + """Coerce string literals to Python values while keeping plain strings intact.""" + if not isinstance(value, str): + return value + + candidate = value.strip() + if candidate == "": + return value + + lowered = candidate.lower() + if lowered == "true": + return True + if lowered == "false": + return False + if lowered in ("none", "null"): + return None + + try: + return py_ast.literal_eval(candidate) + except (ValueError, SyntaxError): + return value + def _get_params_from_conditions(self, conditions: List) -> Dict: """Gets aggregate trade data API params from SQL WHERE conditions. @@ -73,24 +116,16 @@ def select(self, query: ast.Select) -> pd.DataFrame: # Ensure that the cmd provided is a valid OpenBB command available_cmds = [f"obb{cmd}" for cmd in list(self.handler.obb.coverage.commands.keys())] if cmd not in available_cmds: - logger.error(f"The command provided is not supported by OpenBB! Choose one of the following: {', '.join(available_cmds)}") - raise Exception(f"The command provided is not supported by OpenBB! Choose one of the following: {', '.join(available_cmds)}") - - args = "" - # If there are parameters create arguments as a string - if params: - for arg, val in params.items(): - args += f"{arg}={val}," + logger.error( + f"The command provided is not supported by OpenBB! Choose one of the following: {', '.join(available_cmds)}" + ) + raise Exception( + f"The command provided is not supported by OpenBB! Choose one of the following: {', '.join(available_cmds)}" + ) - # Remove the additional ',' added at the end - if args: - args = args[:-1] - - # Recreate the OpenBB command with the arguments - openbb_cmd = f"self.handler.{cmd}({args})" - - # Execute the OpenBB command and return the OBBject - openbb_object = eval(openbb_cmd) + # Resolve command safely and invoke with explicit keyword args. + openbb_function = self._resolve_openbb_command(cmd) + openbb_object = openbb_function(**{key: self._coerce_param_value(val) for key, val in params.items()}) # Transform the OBBject into a pandas DataFrame data = openbb_object.to_df() @@ -109,16 +144,12 @@ def select(self, query: ast.Select) -> pd.DataFrame: return data -def create_table_class( - params_metadata, - response_metadata, - obb_function, - func_docs="", - provider=None -): +def create_table_class(params_metadata, response_metadata, obb_function, func_docs="", provider=None): """Creates a table class for the given OpenBB Platform function.""" - mandatory_fields = [key for key in params_metadata['fields'].keys() if params_metadata['fields'][key].is_required() is True] - response_columns = list(response_metadata['fields'].keys()) + mandatory_fields = [ + key for key in params_metadata["fields"].keys() if params_metadata["fields"][key].is_required() is True + ] + response_columns = list(response_metadata["fields"].keys()) class AnyTable(APITable): def _get_params_from_conditions(self, conditions: List) -> Dict: @@ -152,44 +183,43 @@ def select(self, query: ast.Select) -> pd.DataFrame: params = {} if provider is not None: - params['provider'] = provider + params["provider"] = provider filters = [] mandatory_args_set = {key: False for key in mandatory_fields} columns_to_add = {} - strict_filter = arg_params.get('strict_filter', False) + strict_filter = arg_params.get("strict_filter", False) for op, arg1, arg2 in conditions: - if op == 'or': - raise NotImplementedError('OR is not supported') + if op == "or": + raise NotImplementedError("OR is not supported") if arg1 in mandatory_fields: mandatory_args_set[arg1] = True - if ('start_' + arg1 in params_metadata['fields'] and arg1 in response_columns and arg2 is not None): - - if response_metadata['fields'][arg1].annotation == 'datetime': + if "start_" + arg1 in params_metadata["fields"] and arg1 in response_columns and arg2 is not None: + if response_metadata["fields"][arg1].annotation == "datetime": date = parse_local_date(arg2) - interval = arg_params.get('interval', '1d') + interval = arg_params.get("interval", "1d") - if op == '>': - params['start_' + arg1] = date.strftime('%Y-%m-%d') - elif op == '<': - params['end_' + arg1] = date.strftime('%Y-%m-%d') - elif op == '>=': + if op == ">": + params["start_" + arg1] = date.strftime("%Y-%m-%d") + elif op == "<": + params["end_" + arg1] = date.strftime("%Y-%m-%d") + elif op == ">=": date = date - pd.Timedelta(interval) - params['start_' + arg1] = date.strftime('%Y-%m-%d') - elif op == '<=': + params["start_" + arg1] = date.strftime("%Y-%m-%d") + elif op == "<=": date = date + pd.Timedelta(interval) - params['end_' + arg1] = date.strftime('%Y-%m-%d') - elif op == '=': + params["end_" + arg1] = date.strftime("%Y-%m-%d") + elif op == "=": date = date - pd.Timedelta(interval) - params['start_' + arg1] = date.strftime('%Y-%m-%d') + params["start_" + arg1] = date.strftime("%Y-%m-%d") date = date + pd.Timedelta(interval) - params['end_' + arg1] = date.strftime('%Y-%m-%d') + params["end_" + arg1] = date.strftime("%Y-%m-%d") - elif arg1 in params_metadata['fields'] or not strict_filter: - if op == '=': + elif arg1 in params_metadata["fields"] or not strict_filter: + if op == "=": params[arg1] = arg2 columns_to_add[arg1] = arg2 @@ -201,9 +231,9 @@ def select(self, query: ast.Select) -> pd.DataFrame: # Create docstring for the current function text += "\nDocstring:" - for param in params_metadata['fields']: - field = params_metadata['fields'][param] - if getattr(field.annotation, '__origin__', None) is Union: + for param in params_metadata["fields"]: + field = params_metadata["fields"][param] + if getattr(field.annotation, "__origin__", None) is Union: annotation = f"Union[{', '.join(arg.__name__ for arg in field.annotation.__args__)}]" else: annotation = field.annotation.__name__ @@ -215,8 +245,8 @@ def select(self, query: ast.Select) -> pd.DataFrame: try: # Handle limit keyword correctly since it can't be parsed as a WHERE arg (i.e. WHERE limit = 50) - if query.limit is not None and 'limit' in params_metadata['fields']: - params['limit'] = query.limit.value + if query.limit is not None and "limit" in params_metadata["fields"]: + params["limit"] = query.limit.value obbject = obb_function(**params) # Extract data in dataframe format @@ -273,13 +303,13 @@ def select(self, query: ast.Select) -> pd.DataFrame: return result except AttributeError as e: - logger.info(f'Encountered error while executing OpenBB select: {str(e)}') + logger.info(f"Encountered error while executing OpenBB select: {str(e)}") # Create docstring for the current function text = "Docstring:" - for param in params_metadata['fields']: - field = params_metadata['fields'][param] - if getattr(field.annotation, '__origin__', None) is Union: + for param in params_metadata["fields"]: + field = params_metadata["fields"][param] + if getattr(field.annotation, "__origin__", None) is Union: annotation = f"Union[{', '.join(arg.__name__ for arg in field.annotation.__args__)}]" else: annotation = field.annotation.__name__ @@ -290,13 +320,13 @@ def select(self, query: ast.Select) -> pd.DataFrame: raise Exception(f"{str(e)}\n\n{text}.") from e except ValidationError as e: - logger.info(f'Encountered error while executing OpenBB select: {str(e)}') + logger.info(f"Encountered error while executing OpenBB select: {str(e)}") # Create docstring for the current function text = "Docstring:" - for param in params_metadata['fields']: - field = params_metadata['fields'][param] - if getattr(field.annotation, '__origin__', None) is Union: + for param in params_metadata["fields"]: + field = params_metadata["fields"][param] + if getattr(field.annotation, "__origin__", None) is Union: annotation = f"Union[{', '.join(arg.__name__ for arg in field.annotation.__args__)}]" else: annotation = field.annotation.__name__ @@ -307,21 +337,25 @@ def select(self, query: ast.Select) -> pd.DataFrame: raise Exception(f"{str(e)}\n\n{text}.") from e except Exception as e: - logger.info(f'Encountered error while executing OpenBB select: {str(e)}') + logger.info(f"Encountered error while executing OpenBB select: {str(e)}") # TODO: This one doesn't work because it's taken care of from MindsDB side if "Table not found" in str(e): - raise Exception(f"{str(e)}\n\nCheck if the method exists here: {func_docs}.\n\n - If it doesn't you may need to look for the parent module to check whether there's a typo in the naming.\n- If it does you may need to install a new extension to the OpenBB Platform, and you can see what is available at https://my.openbb.co/app/platform/extensions.") from e + raise Exception( + f"{str(e)}\n\nCheck if the method exists here: {func_docs}.\n\n - If it doesn't you may need to look for the parent module to check whether there's a typo in the naming.\n- If it does you may need to install a new extension to the OpenBB Platform, and you can see what is available at https://my.openbb.co/app/platform/extensions." + ) from e if "Missing credential" in str(e): - raise Exception(f"{str(e)}\n\nGo to https://my.openbb.co/app/platform/api-keys to set this API key, for free.") from e + raise Exception( + f"{str(e)}\n\nGo to https://my.openbb.co/app/platform/api-keys to set this API key, for free." + ) from e # Catch all other errors # Create docstring for the current function text = "Docstring:" - for param in params_metadata['fields']: - field = params_metadata['fields'][param] - if getattr(field.annotation, '__origin__', None) is Union: + for param in params_metadata["fields"]: + field = params_metadata["fields"][param] + if getattr(field.annotation, "__origin__", None) is Union: annotation = f"Union[{', '.join(arg.__name__ for arg in field.annotation.__args__)}]" else: annotation = field.annotation.__name__ diff --git a/tests/unit/handlers/test_openbb_tables.py b/tests/unit/handlers/test_openbb_tables.py new file mode 100644 index 00000000000..4817c12501c --- /dev/null +++ b/tests/unit/handlers/test_openbb_tables.py @@ -0,0 +1,96 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pandas as pd +import pytest + +from mindsdb.integrations.handlers.openbb_handler.openbb_tables import OpenBBtable + + +class _DummyOpenBBResponse: + def __init__(self, payload): + self.payload = payload + + def to_df(self): + return pd.DataFrame([self.payload]) + + +class _DummyPrice: + def historical(self, **kwargs): + return _DummyOpenBBResponse(kwargs) + + +class _DummyEquity: + def __init__(self): + self.price = _DummyPrice() + + +class _DummyCoverage: + def __init__(self): + self.commands = {".equity.price.historical": {}} + + +class _DummyObb: + def __init__(self): + self.equity = _DummyEquity() + self.coverage = _DummyCoverage() + + +class _DummyHandler: + def __init__(self): + self.obb = _DummyObb() + + +def test_openbb_command_resolution_returns_callable(): + table = OpenBBtable(_DummyHandler()) + + function = table._resolve_openbb_command("obb.equity.price.historical") + result = function(symbol="AAPL").to_df() + + assert result.iloc[0]["symbol"] == "AAPL" + + +def test_openbb_select_treats_params_as_data(): + table = OpenBBtable(_DummyHandler()) + malicious_value = "__import__('os').system('echo hacked')" + query = SimpleNamespace(where=object()) + + with patch( + "mindsdb.integrations.handlers.openbb_handler.openbb_tables.extract_comparison_conditions", + return_value=[["=", "cmd", "obb.equity.price.historical"], ["=", "symbol", malicious_value]], + ): + result = table.select(query) + + assert result.iloc[0]["symbol"] == malicious_value + + +def test_openbb_command_resolution_rejects_private_segments(): + table = OpenBBtable(_DummyHandler()) + + with pytest.raises(ValueError, match="Invalid OpenBB command segment"): + table._resolve_openbb_command("obb.__class__") + + +def test_openbb_select_coerces_literal_string_params(): + table = OpenBBtable(_DummyHandler()) + query = SimpleNamespace(where=object()) + + with patch( + "mindsdb.integrations.handlers.openbb_handler.openbb_tables.extract_comparison_conditions", + return_value=[ + ["=", "cmd", "obb.equity.price.historical"], + ["=", "limit", "123"], + ["=", "adjusted", "true"], + ["=", "symbol", "'AAPL'"], + ["=", "ids", "[1, 2]"], + ["=", "raw_symbol", "AAPL"], + ], + ): + result = table.select(query) + + row = result.iloc[0] + assert row["limit"] == 123 + assert bool(row["adjusted"]) is True + assert row["symbol"] == "AAPL" + assert row["ids"] == [1, 2] + assert row["raw_symbol"] == "AAPL" From a07e6cb5e5547587b5bbcc025026a6b8860e5e98 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:36:27 +0300 Subject: [PATCH 109/125] Bump requests from 2.32.4 to 2.33.0 in /requirements (#12332) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Max Stepanov --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 1d17d9c4f64..55c81279c0f 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -18,7 +18,7 @@ mindsdb-sql-parser ~= 0.13.8 pydantic == 2.12.5 duckdb == 1.3.0; sys_platform == "win32" duckdb ~= 1.3.2; sys_platform != "win32" -requests == 2.32.4 +requests == 2.33.0 dateparser==1.2.0 dill == 0.3.6 numpy ~= 2.0 From 426d521499ab7b0ee3b479645d21d1c8e06e65d2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:45:35 +0300 Subject: [PATCH 110/125] Bump yaml from 2.8.2 to 2.8.3 in /docs (#12335) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/package-lock.json b/docs/package-lock.json index afd6c9aaa98..bf47d4578e7 100644 --- a/docs/package-lock.json +++ b/docs/package-lock.json @@ -14162,9 +14162,9 @@ "license": "ISC" }, "node_modules/yaml": { - "version": "2.8.2", - "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.2.tgz", - "integrity": "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==", + "version": "2.8.3", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.3.tgz", + "integrity": "sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg==", "license": "ISC", "bin": { "yaml": "bin.mjs" From 4dfa9f89bb637ccf3450e2c11d232b31ea81d670 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:46:13 +0300 Subject: [PATCH 111/125] Bump brace-expansion from 1.1.12 to 1.1.13 in /docs (#12342) Signed-off-by: dependabot[bot] Co-authored-by: Lucas Koontz Co-authored-by: Hamish Fagg Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/build_deploy_staging.yml | 2 +- docs/package-lock.json | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_deploy_staging.yml b/.github/workflows/build_deploy_staging.yml index d0f3b0c27b9..0a5513c4227 100644 --- a/.github/workflows/build_deploy_staging.yml +++ b/.github/workflows/build_deploy_staging.yml @@ -10,7 +10,7 @@ on: types: - closed branches: - - 'develop' + - 'main' - 'releases/*' concurrency: diff --git a/docs/package-lock.json b/docs/package-lock.json index bf47d4578e7..ca5d34748c8 100644 --- a/docs/package-lock.json +++ b/docs/package-lock.json @@ -5043,9 +5043,9 @@ "license": "MIT" }, "node_modules/brace-expansion": { - "version": "1.1.12", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", - "integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.13.tgz", + "integrity": "sha512-9ZLprWS6EENmhEOpjCYW2c8VkmOvckIJZfkr7rBW6dObmfgJ/L1GpSYW5Hpo9lDz4D1+n0Ckz8rU7FwHDQiG/w==", "license": "MIT", "dependencies": { "balanced-match": "^1.0.0", From c675b5b46faa1b44fa1adcd11fdb821ced2933f3 Mon Sep 17 00:00:00 2001 From: Krishna Chaitanya Date: Mon, 30 Mar 2026 06:05:42 -0700 Subject: [PATCH 112/125] fix: initialize missing address attribute in HanaHandler (#12328) --- mindsdb/integrations/handlers/hana_handler/hana_handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindsdb/integrations/handlers/hana_handler/hana_handler.py b/mindsdb/integrations/handlers/hana_handler/hana_handler.py index eb04fd68338..0b1053fc980 100644 --- a/mindsdb/integrations/handlers/hana_handler/hana_handler.py +++ b/mindsdb/integrations/handlers/hana_handler/hana_handler.py @@ -37,6 +37,7 @@ def __init__(self, name: Text, connection_data: Dict, **kwargs: Any) -> None: """ super().__init__(name) self.connection_data = connection_data + self.address = self.connection_data.get('address') self.kwargs = kwargs self.connection = None @@ -94,7 +95,7 @@ def connect(self) -> dbapi.Connection: logger.error(f'Error connecting to SAP HANA, {known_error}!') raise except Exception as unknown_error: - logger.error(f'Unknown error connecting to Teradata, {unknown_error}!') + logger.error(f'Unknown error connecting to SAP HANA, {unknown_error}!') raise def disconnect(self) -> None: From 7fee1dcf290423cc06006708b777c77534f593f6 Mon Sep 17 00:00:00 2001 From: Raahim Lone Date: Mon, 30 Mar 2026 09:06:30 -0400 Subject: [PATCH 113/125] Add bearer token auth support to Confluence handler (#12327) --- .../confluence_api_client.py | 24 ++++++++-- .../confluence_handler/confluence_handler.py | 44 +++++++++++++------ .../confluence_handler/connection_args.py | 29 ++++++++---- 3 files changed, 71 insertions(+), 26 deletions(-) diff --git a/mindsdb/integrations/handlers/confluence_handler/confluence_api_client.py b/mindsdb/integrations/handlers/confluence_handler/confluence_api_client.py index 2ce6dce173d..34d080627c7 100644 --- a/mindsdb/integrations/handlers/confluence_handler/confluence_api_client.py +++ b/mindsdb/integrations/handlers/confluence_handler/confluence_api_client.py @@ -1,17 +1,35 @@ -from typing import List +from typing import List, Optional import requests class ConfluenceAPIClient: - def __init__(self, url: str, username: str, password: str): + def __init__( + self, + url: str, + username: Optional[str] = None, + password: Optional[str] = None, + token: Optional[str] = None, + auth_method: Optional[str] = None, + ): self.url = url self.username = username self.password = password + self.token = token + self.auth_method = auth_method self.session = requests.Session() - self.session.auth = (self.username, self.password) self.session.headers.update({"Accept": "application/json"}) + use_bearer = (auth_method == "bearer") or bool(token) + if use_bearer: + if not token: + raise ValueError("Token must be provided for bearer authentication.") + self.session.headers.update({"Authorization": f"Bearer {token}"}) + else: + if not username or not password: + raise ValueError("Username and password must be provided for basic authentication.") + self.session.auth = (username, password) + def get_spaces( self, ids: List[int] = None, diff --git a/mindsdb/integrations/handlers/confluence_handler/confluence_handler.py b/mindsdb/integrations/handlers/confluence_handler/confluence_handler.py index d1af184b9a5..b78bfb139ac 100644 --- a/mindsdb/integrations/handlers/confluence_handler/confluence_handler.py +++ b/mindsdb/integrations/handlers/confluence_handler/confluence_handler.py @@ -58,28 +58,44 @@ def connect(self) -> ConfluenceAPIClient: ValueError: If the required connection parameters are not provided. Returns: - atlassian.confluence.Confluence: A connection object to the Confluence API. + ConfluenceAPIClient: A connection object to the Confluence API. """ if self.is_connected is True: return self.connection - if not all( - key in self.connection_data and self.connection_data.get(key) - for key in ["api_base", "username", "password"] - ): - raise ValueError( - "Required parameters (api_base, username, password) must be provided and should not be empty." - ) + api_base = self.connection_data.get("api_base") + username = self.connection_data.get("username") + password = self.connection_data.get("password") + token = self.connection_data.get("token") + auth_method = self.connection_data.get("auth_method") + + if not api_base: + raise ValueError("Required parameter 'api_base' must be provided and should not be empty.") + + if token or auth_method == "bearer": + if not token: + raise ValueError("Required parameter 'token' must be provided for bearer authentication.") - self.connection = ConfluenceAPIClient( - url=self.connection_data.get("api_base"), - username=self.connection_data.get("username"), - password=self.connection_data.get("password"), - ) + self.connection = ConfluenceAPIClient( + url=api_base, + token=token, + auth_method="bearer", + ) + else: + if not username or not password: + raise ValueError( + "Required parameters for basic auth (api_base, username, password) must be provided and should not be empty." + ) + + self.connection = ConfluenceAPIClient( + url=api_base, + username=username, + password=password, + ) self.is_connected = True return self.connection - + def check_connection(self) -> StatusResponse: """ Checks the status of the connection to the Confluence API. diff --git a/mindsdb/integrations/handlers/confluence_handler/connection_args.py b/mindsdb/integrations/handlers/confluence_handler/connection_args.py index 52734cda9bf..08787925d5b 100644 --- a/mindsdb/integrations/handlers/confluence_handler/connection_args.py +++ b/mindsdb/integrations/handlers/confluence_handler/connection_args.py @@ -1,8 +1,6 @@ from collections import OrderedDict - from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE - connection_args = OrderedDict( api_base={ "type": ARG_TYPE.URL, @@ -12,21 +10,34 @@ }, username={ "type": ARG_TYPE.STR, - "description": "The username for the Confluence account.", + "description": "The username for basic authentication.", "label": "Username", - "required": True + "required": False }, password={ "type": ARG_TYPE.STR, - "description": "The API token for the Confluence account.", + "description": "The password or API token for basic authentication.", "label": "Password", - "required": True, + "required": False, "secret": True + }, + token={ + "type": ARG_TYPE.STR, + "description": "The personal access token for bearer authentication.", + "label": "Token", + "required": False, + "secret": True + }, + auth_method={ + "type": ARG_TYPE.STR, + "description": "Authentication method to use. Supported values: 'basic', 'bearer'.", + "label": "Auth Method", + "required": False } ) connection_args_example = OrderedDict( api_base="https://marios.atlassian.net/", - username="your_username", - password="access_token" -) + token="your_personal_access_token", + auth_method="bearer" +) \ No newline at end of file From 520414a9502a46da72bc5c1b2ca5a3acc3f9feb3 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Mon, 30 Mar 2026 16:20:38 +0300 Subject: [PATCH 114/125] Fix formatting (#12344) --- .../confluence_handler/confluence_handler.py | 2 +- .../confluence_handler/connection_args.py | 18 +++-- .../handlers/hana_handler/hana_handler.py | 68 +++++++------------ 3 files changed, 35 insertions(+), 53 deletions(-) diff --git a/mindsdb/integrations/handlers/confluence_handler/confluence_handler.py b/mindsdb/integrations/handlers/confluence_handler/confluence_handler.py index b78bfb139ac..2708a511764 100644 --- a/mindsdb/integrations/handlers/confluence_handler/confluence_handler.py +++ b/mindsdb/integrations/handlers/confluence_handler/confluence_handler.py @@ -95,7 +95,7 @@ def connect(self) -> ConfluenceAPIClient: self.is_connected = True return self.connection - + def check_connection(self) -> StatusResponse: """ Checks the status of the connection to the Confluence API. diff --git a/mindsdb/integrations/handlers/confluence_handler/connection_args.py b/mindsdb/integrations/handlers/confluence_handler/connection_args.py index 08787925d5b..27ab917846c 100644 --- a/mindsdb/integrations/handlers/confluence_handler/connection_args.py +++ b/mindsdb/integrations/handlers/confluence_handler/connection_args.py @@ -6,38 +6,36 @@ "type": ARG_TYPE.URL, "description": "The base URL of the Confluence instance/server.", "label": "Base URL", - "required": True + "required": True, }, username={ "type": ARG_TYPE.STR, "description": "The username for basic authentication.", "label": "Username", - "required": False + "required": False, }, password={ "type": ARG_TYPE.STR, "description": "The password or API token for basic authentication.", "label": "Password", "required": False, - "secret": True + "secret": True, }, token={ "type": ARG_TYPE.STR, "description": "The personal access token for bearer authentication.", "label": "Token", "required": False, - "secret": True + "secret": True, }, auth_method={ "type": ARG_TYPE.STR, "description": "Authentication method to use. Supported values: 'basic', 'bearer'.", "label": "Auth Method", - "required": False - } + "required": False, + }, ) connection_args_example = OrderedDict( - api_base="https://marios.atlassian.net/", - token="your_personal_access_token", - auth_method="bearer" -) \ No newline at end of file + api_base="https://marios.atlassian.net/", token="your_personal_access_token", auth_method="bearer" +) diff --git a/mindsdb/integrations/handlers/hana_handler/hana_handler.py b/mindsdb/integrations/handlers/hana_handler/hana_handler.py index 0b1053fc980..7899bbd5e33 100644 --- a/mindsdb/integrations/handlers/hana_handler/hana_handler.py +++ b/mindsdb/integrations/handlers/hana_handler/hana_handler.py @@ -11,7 +11,7 @@ from mindsdb.integrations.libs.response import ( HandlerStatusResponse as StatusResponse, HandlerResponse as Response, - RESPONSE_TYPE + RESPONSE_TYPE, ) from mindsdb.utilities import log @@ -24,7 +24,7 @@ class HanaHandler(DatabaseHandler): This handler handles the connection and execution of SQL statements on SAP HANA. """ - name = 'hana' + name = "hana" def __init__(self, name: Text, connection_data: Dict, **kwargs: Any) -> None: """ @@ -37,7 +37,7 @@ def __init__(self, name: Text, connection_data: Dict, **kwargs: Any) -> None: """ super().__init__(name) self.connection_data = connection_data - self.address = self.connection_data.get('address') + self.address = self.connection_data.get("address") self.kwargs = kwargs self.connection = None @@ -65,37 +65,35 @@ def connect(self) -> dbapi.Connection: return self.connection # Mandatory connection parameters. - if not all(key in self.connection_data for key in ['address', 'port', 'user', 'password']): - raise ValueError('Required parameters (address, port, user, password) must be provided.') + if not all(key in self.connection_data for key in ["address", "port", "user", "password"]): + raise ValueError("Required parameters (address, port, user, password) must be provided.") config = { - 'address': self.connection_data['address'], - 'port': self.connection_data['port'], - 'user': self.connection_data['user'], - 'password': self.connection_data['password'], + "address": self.connection_data["address"], + "port": self.connection_data["port"], + "user": self.connection_data["user"], + "password": self.connection_data["password"], } # Optional connection parameters. - if 'database' in self.connection_data: - config['databaseName'] = self.connection_data['database'] + if "database" in self.connection_data: + config["databaseName"] = self.connection_data["database"] - if 'schema' in self.connection_data: - config['currentSchema'] = self.connection_data['schema'] + if "schema" in self.connection_data: + config["currentSchema"] = self.connection_data["schema"] - if 'encrypt' in self.connection_data: - config['encrypt'] = self.connection_data['encrypt'] + if "encrypt" in self.connection_data: + config["encrypt"] = self.connection_data["encrypt"] try: - self.connection = dbapi.connect( - **config - ) + self.connection = dbapi.connect(**config) self.is_connected = True return self.connection except Error as known_error: - logger.error(f'Error connecting to SAP HANA, {known_error}!') + logger.error(f"Error connecting to SAP HANA, {known_error}!") raise except Exception as unknown_error: - logger.error(f'Unknown error connecting to SAP HANA, {unknown_error}!') + logger.error(f"Unknown error connecting to SAP HANA, {unknown_error}!") raise def disconnect(self) -> None: @@ -119,13 +117,13 @@ def check_connection(self) -> StatusResponse: try: connection = self.connect() with connection.cursor() as cur: - cur.execute('SELECT 1 FROM SYS.DUMMY') + cur.execute("SELECT 1 FROM SYS.DUMMY") response.success = True except (Error, ProgrammingError, ValueError) as known_error: - logger.error(f'Connection check to SAP HANA failed, {known_error}!') + logger.error(f"Connection check to SAP HANA failed, {known_error}!") response.error_message = str(known_error) except Exception as unknown_error: - logger.error(f'Connection check to SAP HANA failed due to an unknown error, {unknown_error}!') + logger.error(f"Connection check to SAP HANA failed due to an unknown error, {unknown_error}!") response.error_message = str(unknown_error) if response.success is True and need_to_close: @@ -155,29 +153,15 @@ def native_query(self, query: Text) -> Response: response = Response(RESPONSE_TYPE.OK) else: result = cur.fetchall() - response = Response( - RESPONSE_TYPE.TABLE, - DataFrame( - result, - columns=[x[0] for x in cur.description] - ) - ) + response = Response(RESPONSE_TYPE.TABLE, DataFrame(result, columns=[x[0] for x in cur.description])) connection.commit() except ProgrammingError as programming_error: - logger.error(f'Error running query: {query} on {self.address}!') - response = Response( - RESPONSE_TYPE.ERROR, - error_code=0, - error_message=str(programming_error) - ) + logger.error(f"Error running query: {query} on {self.address}!") + response = Response(RESPONSE_TYPE.ERROR, error_code=0, error_message=str(programming_error)) connection.rollback() except Exception as unknown_error: - logger.error(f'Unknown error running query: {query} on {self.address}!') - response = Response( - RESPONSE_TYPE.ERROR, - error_code=0, - error_message=str(unknown_error) - ) + logger.error(f"Unknown error running query: {query} on {self.address}!") + response = Response(RESPONSE_TYPE.ERROR, error_code=0, error_message=str(unknown_error)) connection.rollback() if need_to_close is True: From 6dca0b800a016e3b1accc896add20f30dc03c88e Mon Sep 17 00:00:00 2001 From: Ian Unsworth <86010258+ianu82@users.noreply.github.com> Date: Mon, 30 Mar 2026 14:46:01 +0100 Subject: [PATCH 115/125] Fix TaskThread crash on missing task record (#12262) --- mindsdb/interfaces/tasks/task_thread.py | 3 +++ .../unit/interfaces/tasks/test_task_thread.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 tests/unit/interfaces/tasks/test_task_thread.py diff --git a/mindsdb/interfaces/tasks/task_thread.py b/mindsdb/interfaces/tasks/task_thread.py index f753a59928a..8b9eb7ca9e5 100644 --- a/mindsdb/interfaces/tasks/task_thread.py +++ b/mindsdb/interfaces/tasks/task_thread.py @@ -23,6 +23,9 @@ def run(self): # create context and session task_record = db.Tasks.query.get(self.task_id) + if task_record is None: + logger.error(f"Task record not found: {self.task_id}") + return ctx.set_default() ctx.company_id = task_record.company_id diff --git a/tests/unit/interfaces/tasks/test_task_thread.py b/tests/unit/interfaces/tasks/test_task_thread.py new file mode 100644 index 00000000000..ee1218bf2a6 --- /dev/null +++ b/tests/unit/interfaces/tasks/test_task_thread.py @@ -0,0 +1,18 @@ +from unittest.mock import patch + +from mindsdb.interfaces.tasks.task_thread import TaskThread + + +def test_run_returns_when_task_record_is_missing(): + thread = TaskThread(task_id=123) + + with ( + patch("mindsdb.interfaces.tasks.task_thread.db.Tasks.query.get", return_value=None) as get_mock, + patch("mindsdb.interfaces.tasks.task_thread.ctx.set_default") as set_default_mock, + patch("mindsdb.interfaces.tasks.task_thread.db.session.commit") as commit_mock, + ): + thread.run() + + get_mock.assert_called_once_with(123) + set_default_mock.assert_not_called() + commit_mock.assert_not_called() From 5986c7f4dec812b958e44fcdd133e860de578d71 Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Mon, 30 Mar 2026 19:17:27 +0300 Subject: [PATCH 116/125] del useless test --- .../unit/interfaces/tasks/test_task_thread.py | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 tests/unit/interfaces/tasks/test_task_thread.py diff --git a/tests/unit/interfaces/tasks/test_task_thread.py b/tests/unit/interfaces/tasks/test_task_thread.py deleted file mode 100644 index ee1218bf2a6..00000000000 --- a/tests/unit/interfaces/tasks/test_task_thread.py +++ /dev/null @@ -1,18 +0,0 @@ -from unittest.mock import patch - -from mindsdb.interfaces.tasks.task_thread import TaskThread - - -def test_run_returns_when_task_record_is_missing(): - thread = TaskThread(task_id=123) - - with ( - patch("mindsdb.interfaces.tasks.task_thread.db.Tasks.query.get", return_value=None) as get_mock, - patch("mindsdb.interfaces.tasks.task_thread.ctx.set_default") as set_default_mock, - patch("mindsdb.interfaces.tasks.task_thread.db.session.commit") as commit_mock, - ): - thread.run() - - get_mock.assert_called_once_with(123) - set_default_mock.assert_not_called() - commit_mock.assert_not_called() From de757a96d1b08f3c731d981faf526f265ec143c0 Mon Sep 17 00:00:00 2001 From: Hamish Fagg Date: Tue, 31 Mar 2026 08:33:11 +1300 Subject: [PATCH 117/125] Add healthcheck to container (#11981) Co-authored-by: Lucas Koontz Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docker/mindsdb.Dockerfile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/mindsdb.Dockerfile b/docker/mindsdb.Dockerfile index 1da9a8250ee..54e86d5e14b 100644 --- a/docker/mindsdb.Dockerfile +++ b/docker/mindsdb.Dockerfile @@ -93,6 +93,8 @@ ENV PATH=/venv/bin:$PATH EXPOSE 47334/tcp EXPOSE 47335/tcp +HEALTHCHECK --interval=30s --timeout=10s --retries=5 --start-period=60s CMD curl -fsS "http://localhost:47334/api/status" + # Pre-load web GUI RUN python -m mindsdb --config=/root/mindsdb_config.json --update-gui From c6c8a05d6f63fdb65587d40c960269129b7cfe43 Mon Sep 17 00:00:00 2001 From: Andrey Date: Wed, 1 Apr 2026 11:36:47 +0300 Subject: [PATCH 118/125] Fix static errors (#12349) Co-authored-by: Sriram-B-Srivatsa <144884365+Sriram-B-Srivatsa@users.noreply.github.com> --- .../handlers/chromadb_handler/__init__.py | 2 + .../chromadb_handler/chromadb_handler.py | 48 +++++++---- .../handlers/chromadb_handler/settings.py | 13 +-- .../tests/test_chromadb_handler.py | 83 +++++++++++++++++++ .../handlers/ollama_handler/ollama_handler.py | 16 +++- 5 files changed, 134 insertions(+), 28 deletions(-) create mode 100644 mindsdb/integrations/handlers/chromadb_handler/tests/test_chromadb_handler.py diff --git a/mindsdb/integrations/handlers/chromadb_handler/__init__.py b/mindsdb/integrations/handlers/chromadb_handler/__init__.py index 9c5a069c83f..05d30cf7c7b 100644 --- a/mindsdb/integrations/handlers/chromadb_handler/__init__.py +++ b/mindsdb/integrations/handlers/chromadb_handler/__init__.py @@ -3,8 +3,10 @@ from .__about__ import __description__ as description from .__about__ import __version__ as version from .connection_args import connection_args, connection_args_example + try: from .chromadb_handler import ChromaDBHandler as Handler + import_error = None except Exception as e: Handler = None diff --git a/mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py b/mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py index 61a5b439d12..32d0e566b00 100644 --- a/mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py +++ b/mindsdb/integrations/handlers/chromadb_handler/chromadb_handler.py @@ -215,17 +215,22 @@ def select( include = ["metadatas", "documents", "embeddings"] - # check if embedding vector filter is present - vector_filter = ( - [] - if conditions is None - else [condition for condition in conditions if condition.column == TableField.EMBEDDINGS.value] - ) + # Identify Search Intent + vector_filter = None + content_filter = None - if len(vector_filter) > 0: - vector_filter = vector_filter[0] - else: - vector_filter = None + if conditions is not None: + # Embeddings + v_filters = [c for c in conditions if c.column == TableField.EMBEDDINGS.value] + if v_filters: + vector_filter = v_filters[0] + + # Semantic Search + c_filters = [c for c in conditions if c.column == TableField.CONTENT.value] + if c_filters: + content_filter = c_filters[0] + + # ID Filtering ids_include = [] ids_exclude = [] @@ -242,14 +247,26 @@ def select( elif condition.op == FilterOperator.NOT_IN: ids_exclude.extend(condition.value) - if vector_filter is not None: - # similarity search + # Trigger search if Vector OR Content is present + if vector_filter is not None or content_filter is not None: + # Similarity search query_payload = { "where": filters, - "query_embeddings": vector_filter.value if vector_filter is not None else None, "include": include + ["distances"], } + # Handle Vector Search + if vector_filter: + query_payload["query_embeddings"] = vector_filter.value + + # Handle Text Search + if content_filter: + val = content_filter.value + if isinstance(val, list): + query_payload["query_texts"] = val + else: + query_payload["query_texts"] = [val] + if limit is not None: if len(ids_include) == 0 and len(ids_exclude) == 0: query_payload["n_results"] = limit @@ -265,7 +282,7 @@ def select( embeddings = result["embeddings"][0] else: - # general get query + # general get query (Exact Match) result = collection.get( ids=ids_include or None, where=filters, @@ -279,7 +296,6 @@ def select( embeddings = result["embeddings"] distances = None - # project based on columns payload = { TableField.ID.value: ids, TableField.CONTENT.value: documents, @@ -290,7 +306,7 @@ def select( if columns is not None: payload = {column: payload[column] for column in columns if column != TableField.DISTANCE.value} - # always include distance + # Include distance distance_filter = None distance_col = TableField.DISTANCE.value if distances is not None: diff --git a/mindsdb/integrations/handlers/chromadb_handler/settings.py b/mindsdb/integrations/handlers/chromadb_handler/settings.py index 2b669ed75a8..279c404384e 100644 --- a/mindsdb/integrations/handlers/chromadb_handler/settings.py +++ b/mindsdb/integrations/handlers/chromadb_handler/settings.py @@ -14,7 +14,7 @@ class ChromaHandlerConfig(BaseModel): host: str = None port: str = None password: str = None - distance: str = 'cosine' + distance: str = "cosine" class Config: extra = "forbid" @@ -27,13 +27,9 @@ def check_param_typos(cls, values: Any) -> Any: expected_params = cls.model_fields.keys() for key in values.keys(): if key not in expected_params: - close_matches = difflib.get_close_matches( - key, expected_params, cutoff=0.4 - ) + close_matches = difflib.get_close_matches(key, expected_params, cutoff=0.4) if close_matches: - raise ValueError( - f"Unexpected parameter '{key}'. Did you mean '{close_matches[0]}'?" - ) + raise ValueError(f"Unexpected parameter '{key}'. Did you mean '{close_matches[0]}'?") else: raise ValueError(f"Unexpected parameter '{key}'.") return values @@ -56,8 +52,7 @@ def check_config(cls, values: Any) -> Any: if persist_directory and (host or port): raise ValueError( - f"For {vector_store} handler - if persistence_folder is provided, " - f"host, port should not be provided." + f"For {vector_store} handler - if persistence_folder is provided, host, port should not be provided." ) return values diff --git a/mindsdb/integrations/handlers/chromadb_handler/tests/test_chromadb_handler.py b/mindsdb/integrations/handlers/chromadb_handler/tests/test_chromadb_handler.py new file mode 100644 index 00000000000..d3e5d330d16 --- /dev/null +++ b/mindsdb/integrations/handlers/chromadb_handler/tests/test_chromadb_handler.py @@ -0,0 +1,83 @@ +import unittest +from unittest.mock import Mock, patch +import pandas as pd +from mindsdb.integrations.handlers.chromadb_handler.chromadb_handler import ( + ChromaDBHandler, + TableField, +) + + +class MockCondition: + def __init__(self, column, op, value): + self.column = column + self.op = op + self.value = value + + +class TestChromaHandler(unittest.TestCase): + def setUp(self): + self.handler = ChromaDBHandler(name="test_chroma", connection_data={}, handler_storage=Mock()) + + # INSERT + @patch("mindsdb.integrations.handlers.chromadb_handler.chromadb_handler.ChromaDBHandler.connect") + def test_insert_calls_upsert(self, mock_connect): + mock_client = Mock() + mock_collection = Mock() + mock_client.get_or_create_collection.return_value = mock_collection + self.handler._client = mock_client + self.handler.is_connected = True + + df = pd.DataFrame( + { + TableField.CONTENT.value: ["Cat Photo"], + TableField.EMBEDDINGS.value: [[0.9, 0.1, 0.1]], + TableField.ID.value: ["img_1"], + TableField.METADATA.value: [{"author": "Sriram"}], + } + ) + self.handler.insert("my_gallery", df) + + call_args = mock_collection.upsert.call_args[1] + self.assertEqual(call_args["embeddings"], [[0.9, 0.1, 0.1]]) + + # SELECT + @patch("mindsdb.integrations.handlers.chromadb_handler.chromadb_handler.ChromaDBHandler.disconnect") + @patch("mindsdb.integrations.handlers.chromadb_handler.chromadb_handler.ChromaDBHandler.connect") + def test_select_semantic_search(self, mock_connect, mock_disconnect): + # Mock System + mock_client = Mock() + mock_collection = Mock() + mock_client.get_collection.return_value = mock_collection + + self.handler._client = mock_client + self.handler.is_connected = True + + # Mock Return Data + mock_result = { + "ids": [["id1"]], + "documents": [["Dog"]], + "metadatas": [[{}]], + "embeddings": [[[0.1, 0.2]]], + "distances": [[0.5]], + } + mock_collection.query.return_value = mock_result + mock_collection.get.return_value = mock_result + + conditions = [MockCondition(column=TableField.CONTENT.value, op="=", value="Dog")] + + self.handler.select("my_gallery", conditions=conditions) + + # Verification + if not mock_collection.query.called: + self.fail("CRITICAL: The handler used .get() (Exact Match) instead of .query() (Semantic Search)!") + + call_args = mock_collection.query.call_args[1] + + if "query_texts" not in call_args: + self.fail("CRITICAL: The handler called .query() but forgot 'query_texts'!") + + self.assertEqual(call_args["query_texts"], ["Dog"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/mindsdb/integrations/handlers/ollama_handler/ollama_handler.py b/mindsdb/integrations/handlers/ollama_handler/ollama_handler.py index 74923d640a8..639345933fa 100644 --- a/mindsdb/integrations/handlers/ollama_handler/ollama_handler.py +++ b/mindsdb/integrations/handlers/ollama_handler/ollama_handler.py @@ -100,9 +100,19 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame pred_args = args.get("predict_params", {}) args = self.model_storage.json_get("args") model_name, target_col = args["model_name"], args["target"] - prompt_template = pred_args.get( - "prompt_template", args.get("prompt_template", "Answer the following question: {{{{text}}}}") - ) + + # Auto-detect column if template is missing + # If user provided a specific template + user_template = pred_args.get("prompt_template", args.get("prompt_template")) + + # OR If no template and 'text' column is missing, then auto-detect + if user_template is None and "text" not in df.columns and len(df.columns) == 1: + col_name = df.columns[0] + # Create a template dynamically + prompt_template = "Answer the following question: {{{{" + col_name + "}}}}" + else: + # Fallback: Use user template OR default to 'text' (Old behavior) + prompt_template = user_template if user_template else "Answer the following question: {{{{text}}}}" # prepare prompts prompts, empty_prompt_ids = get_completed_prompts(prompt_template, df) From 1153bdffb0057251ed5e7b51bbd402a8ae5076ff Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Wed, 1 Apr 2026 11:38:05 +0300 Subject: [PATCH 119/125] Fix subquery references (#12337) --- mindsdb/api/executor/utilities/sql.py | 22 +++++++- tests/unit/executor/test_executor.py | 73 ++++++++++++++++++++++++++- tests/unit/executor/test_files.py | 4 +- 3 files changed, 94 insertions(+), 5 deletions(-) diff --git a/mindsdb/api/executor/utilities/sql.py b/mindsdb/api/executor/utilities/sql.py index 48091e5fff1..f02a9e02d67 100644 --- a/mindsdb/api/executor/utilities/sql.py +++ b/mindsdb/api/executor/utilities/sql.py @@ -245,12 +245,30 @@ def query_dfs(dataframes, query_ast, session=None): else: user_functions = None + # region collect table aliases. Strip schema/db prefix from column identifiers, but keep table aliases. + # Examples: + # files.col = 1 -> col = 1 (schema prefix stripped) + # files.a1.col = 1 -> a1.col = 1 (schema prefix stripped, alias kept) + # a1.col = a2.col -> a1.col = a2.col (aliases untouched, no schema prefix) + # "Custom SQL Query".col -> col (replaced subquery alias stripped) + known_aliases = set() + + def collect_aliases(node, is_table, **kwargs): + if not is_table or not isinstance(node, Identifier): + return + known_aliases.add(node.parts[-1].lower()) + if node.alias is not None: + known_aliases.add(node.alias.parts[-1].lower()) + + query_traversal(query_ast, collect_aliases) + # endregion + def adapt_query(node, is_table, **kwargs): if is_table: return if isinstance(node, Identifier): - if len(node.parts) > 1: - node.parts = [node.parts[-1]] + if len(node.parts) > 1 and node.parts[0].lower() not in known_aliases: + node.parts = node.parts[1:] return node if isinstance(node, Function): fnc = mysql_to_duckdb_fnc(node) diff --git a/tests/unit/executor/test_executor.py b/tests/unit/executor/test_executor.py index c901e7bab55..89a4acdfda5 100644 --- a/tests/unit/executor/test_executor.py +++ b/tests/unit/executor/test_executor.py @@ -11,7 +11,9 @@ from mindsdb.utilities.render.sqlalchemy_render import SqlalchemyRender -from mindsdb.api.executor.utilities.sql import query_df +from mindsdb_sql_parser import parse_sql + +from mindsdb.api.executor.utilities.sql import query_df, query_dfs # How to run: # env PYTHONPATH=./ pytest tests/unit/test_executor.py @@ -1618,6 +1620,75 @@ def test_query_df_functions(self): result = query_df(df, query)["result"][0] assert isinstance(result, dt.time) + def test_not_exists_correlated_subquery(self): + a = pd.DataFrame( + [ + {"tab_num": 1, "shop": 1}, + {"tab_num": 1, "shop": 2}, + {"tab_num": 1, "shop": 3}, + {"tab_num": 2, "shop": 1}, + {"tab_num": 2, "shop": 2}, + {"tab_num": 3, "shop": 1}, + ] + ) + b = pd.DataFrame([{"shop": 1}, {"shop": 2}, {"shop": 3}]) + + result = query_dfs( + {"A": a, "B": b}, + parse_sql( + """ + SELECT DISTINCT a1.tab_num + FROM A a1 + WHERE NOT EXISTS ( + SELECT * FROM B b + WHERE NOT EXISTS ( + SELECT * FROM A a2 + WHERE a2.tab_num = a1.tab_num AND a2.shop = b.shop + ) + ) + """, + dialect="mindsdb", + ), + ) + + # Only tab_num=1 covers all shops {1, 2, 3} + assert list(result["tab_num"]) == [1] + + def test_exists_correlated_subquery(self): + # EXISTS version: find tab_num values missing at least one shop. + # tab_num=2 misses shop=3, tab_num=3 misses shops 2 and 3. + a = pd.DataFrame( + [ + {"tab_num": 1, "shop": 1}, + {"tab_num": 1, "shop": 2}, + {"tab_num": 1, "shop": 3}, + {"tab_num": 2, "shop": 1}, + {"tab_num": 2, "shop": 2}, + {"tab_num": 3, "shop": 1}, + ] + ) + b = pd.DataFrame([{"shop": 1}, {"shop": 2}, {"shop": 3}]) + + result = query_dfs( + {"A": a, "B": b}, + parse_sql( + """ + SELECT DISTINCT a1.tab_num + FROM A a1 + WHERE EXISTS ( + SELECT * FROM B b + WHERE NOT EXISTS ( + SELECT * FROM A a2 + WHERE a2.tab_num = a1.tab_num AND a2.shop = b.shop + ) + ) + """, + dialect="mindsdb", + ), + ) + + assert sorted(result["tab_num"].tolist()) == [2, 3] + class TestIfExistsIfNotExists(BaseExecutorMockPredictor): def setup_method(self, method): diff --git a/tests/unit/executor/test_files.py b/tests/unit/executor/test_files.py index 0181da273fa..cdbee61fbdb 100644 --- a/tests/unit/executor/test_files.py +++ b/tests/unit/executor/test_files.py @@ -152,8 +152,8 @@ def test_multi_table_relational_division(self): """ ) - assert len(result) == 3 - assert sorted(result["tab_num"].tolist()) == [1, 2, 3] + assert len(result) == 2 + assert sorted(result["tab_num"].tolist()) == [1, 2] def test_multi_table_join_with_aliases(self): """Test JOIN with aliases and database prefixes""" From 4ae5e51d8c0d6dacfdf38e417733ca553367d4c9 Mon Sep 17 00:00:00 2001 From: Faridun Mirzoev Date: Wed, 1 Apr 2026 13:13:49 -0700 Subject: [PATCH 120/125] feat: add AG2 multi-agent handler Add a new ML handler for AG2 (formerly AutoGen), an open-source multi-agent framework, enabling users to create and query multi-agent teams via SQL. Supports single-agent and GroupChat modes with configurable agents, speaker selection, and max rounds. --- .../handlers/ag2_handler/README.md | 99 +++++++ .../handlers/ag2_handler/__about__.py | 9 + .../handlers/ag2_handler/__init__.py | 32 +++ .../handlers/ag2_handler/ag2_handler.py | 260 ++++++++++++++++++ .../handlers/ag2_handler/creation_args.py | 31 +++ .../handlers/ag2_handler/icon.svg | 1 + .../handlers/ag2_handler/model_using_args.py | 1 + .../handlers/ag2_handler/requirements.txt | 1 + .../handlers/ag2_handler/tests/__init__.py | 0 .../ag2_handler/tests/test_ag2_handler.py | 230 ++++++++++++++++ 10 files changed, 664 insertions(+) create mode 100644 mindsdb/integrations/handlers/ag2_handler/README.md create mode 100644 mindsdb/integrations/handlers/ag2_handler/__about__.py create mode 100644 mindsdb/integrations/handlers/ag2_handler/__init__.py create mode 100644 mindsdb/integrations/handlers/ag2_handler/ag2_handler.py create mode 100644 mindsdb/integrations/handlers/ag2_handler/creation_args.py create mode 100644 mindsdb/integrations/handlers/ag2_handler/icon.svg create mode 100644 mindsdb/integrations/handlers/ag2_handler/model_using_args.py create mode 100644 mindsdb/integrations/handlers/ag2_handler/requirements.txt create mode 100644 mindsdb/integrations/handlers/ag2_handler/tests/__init__.py create mode 100644 mindsdb/integrations/handlers/ag2_handler/tests/test_ag2_handler.py diff --git a/mindsdb/integrations/handlers/ag2_handler/README.md b/mindsdb/integrations/handlers/ag2_handler/README.md new file mode 100644 index 00000000000..aa2fe180d50 --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/README.md @@ -0,0 +1,99 @@ +# AG2 Handler + +This handler integrates [AG2](https://ag2.ai), an open-source multi-agent framework, with MindsDB. It enables creating and querying multi-agent teams via SQL. + +AG2 (formerly AutoGen) has 500K+ monthly PyPI downloads, 4,300+ GitHub stars, and 400+ contributors. + +## Setup + +### Install dependencies + +```bash +pip install "ag2[openai]>=0.11.4,<1.0" +``` + +### Create an ML engine + +```sql +CREATE ML_ENGINE ag2_engine +FROM ag2 +USING openai_api_key = 'your-key-here'; +``` + +## Usage + +### Create a multi-agent model + +```sql +CREATE MODEL research_team +PREDICT answer +USING + engine = 'ag2_engine', + agents = '[ + {"name": "Researcher", "system_message": "You research topics and provide key facts with sources."}, + {"name": "Writer", "system_message": "You write clear, engaging summaries from research findings."}, + {"name": "Critic", "system_message": "You review content for accuracy. Say TERMINATE when approved."} + ]', + max_rounds = 8, + speaker_selection = 'auto'; +``` + +### Query the agents + +```sql +SELECT answer +FROM research_team +WHERE question = 'What are the main benefits of retrieval-augmented generation?'; +``` + +### Batch queries + +```sql +SELECT t.question, m.answer +FROM my_questions AS t +JOIN research_team AS m; +``` + +## Configuration + +### Engine arguments + +| Parameter | Required | Default | Description | +|-----------|----------|---------|-------------| +| `openai_api_key` | Yes | — | API key for the agents' LLM | +| `model` | No | `gpt-4o-mini` | LLM model name | +| `api_type` | No | `openai` | API type (openai, anthropic, etc.) | +| `api_base` | No | — | Custom API base URL | + +### Model arguments + +| Parameter | Required | Default | Description | +|-----------|----------|---------|-------------| +| `agents` | No | Single assistant | JSON list of agent definitions | +| `max_rounds` | No | `8` | Max GroupChat rounds | +| `speaker_selection` | No | `auto` | Speaker selection: auto, round_robin, random | +| `mode` | No | `groupchat` | Mode: single or groupchat | + +### Agent definition format + +```json +[ + { + "name": "AgentName", + "system_message": "Agent's role and instructions." + } +] +``` + +## Modes + +- **single**: One assistant agent handles the query directly +- **groupchat**: Multiple agents collaborate via GroupChat with automatic speaker selection + +## Describe + +```sql +DESCRIBE MODEL research_team; +DESCRIBE MODEL research_team ATTRIBUTE args; +DESCRIBE MODEL research_team ATTRIBUTE agents; +``` diff --git a/mindsdb/integrations/handlers/ag2_handler/__about__.py b/mindsdb/integrations/handlers/ag2_handler/__about__.py new file mode 100644 index 00000000000..34ad701e213 --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/__about__.py @@ -0,0 +1,9 @@ +__title__ = "MindsDB AG2 handler" +__package_name__ = "mindsdb_ag2_handler" +__version__ = "0.0.1" +__description__ = "MindsDB handler for AG2 multi-agent framework" +__author__ = "Faridun Mirzoev" +__github__ = "https://github.com/mindsdb/mindsdb" +__pypi__ = "https://pypi.org/project/mindsdb/" +__license__ = "MIT" +__copyright__ = "Copyright 2024- mindsdb" diff --git a/mindsdb/integrations/handlers/ag2_handler/__init__.py b/mindsdb/integrations/handlers/ag2_handler/__init__.py new file mode 100644 index 00000000000..0791ce71b60 --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/__init__.py @@ -0,0 +1,32 @@ +from mindsdb.integrations.libs.const import HANDLER_TYPE + +from .__about__ import __version__ as version, __description__ as description +from .creation_args import creation_args +from .model_using_args import model_using_args + +try: + from .ag2_handler import AG2Handler as Handler + + import_error = None +except Exception as e: + Handler = None + import_error = e + +title = "AG2" +name = "ag2" +type = HANDLER_TYPE.ML +icon_path = "icon.svg" +permanent = False + +__all__ = [ + "Handler", + "version", + "name", + "type", + "title", + "description", + "import_error", + "icon_path", + "creation_args", + "model_using_args", +] diff --git a/mindsdb/integrations/handlers/ag2_handler/ag2_handler.py b/mindsdb/integrations/handlers/ag2_handler/ag2_handler.py new file mode 100644 index 00000000000..b67b41936a1 --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/ag2_handler.py @@ -0,0 +1,260 @@ +"""AG2 multi-agent handler for MindsDB. + +Enables creating and querying AG2 multi-agent GroupChats via SQL. + +Usage: + -- Create engine + CREATE ML_ENGINE ag2_engine + FROM ag2 + USING openai_api_key = 'sk-...'; + + -- Create model (agent team) + CREATE MODEL my_agent_team + PREDICT answer + USING + engine = 'ag2_engine', + agents = '[ + {"name": "Researcher", "system_message": "You research topics thoroughly."}, + {"name": "Writer", "system_message": "You write clear summaries."}, + {"name": "Critic", "system_message": "You review for accuracy. Say TERMINATE when done."} + ]', + max_rounds = 8; + + -- Query the agent team + SELECT answer + FROM my_agent_team + WHERE question = 'Explain how transformers work'; +""" + +import json +import os +from typing import Any, Dict, Optional + +import pandas as pd + +from mindsdb.integrations.libs.base import BaseMLEngine +from mindsdb.utilities import log + +logger = log.getLogger(__name__) + + +class AG2Handler(BaseMLEngine): + """Handler for AG2 multi-agent framework.""" + + name = "ag2" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.generative = True + + def create_engine(self, connection_args: Dict) -> None: + """Validate engine connection args (API key).""" + api_key = connection_args.get("openai_api_key") or os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError("openai_api_key is required. Pass it in USING clause or set OPENAI_API_KEY env var.") + + try: + from autogen import LLMConfig + + model = connection_args.get("model", "gpt-4o-mini") + api_type = connection_args.get("api_type", "openai") + + config = {"model": model, "api_key": api_key, "api_type": api_type} + if connection_args.get("api_base"): + config["base_url"] = connection_args["api_base"] + + LLMConfig(config) + except ImportError: + raise ImportError('AG2 is not installed. Run: pip install "ag2[openai]>=0.11.4,<1.0"') + except Exception as e: + raise ValueError(f"Failed to validate AG2 configuration: {e}") + + @staticmethod + def create_validation(target: str, args: Optional[Dict] = None, **kwargs: Any) -> None: + """Validate model creation args.""" + using_args = args.get("using", {}) + + agents_json = using_args.get("agents") + if agents_json: + try: + agents = json.loads(agents_json) if isinstance(agents_json, str) else agents_json + if not isinstance(agents, list) or len(agents) == 0: + raise ValueError("'agents' must be a non-empty JSON list.") + for agent in agents: + if "name" not in agent: + raise ValueError("Each agent must have a 'name' field.") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid 'agents' JSON: {e}") + + mode = using_args.get("mode", "groupchat") + if mode not in ("single", "groupchat"): + raise ValueError(f"Invalid mode '{mode}'. Must be 'single' or 'groupchat'.") + + selection = using_args.get("speaker_selection", "auto") + if selection not in ("auto", "round_robin", "random"): + raise ValueError(f"Invalid speaker_selection '{selection}'. Must be 'auto', 'round_robin', or 'random'.") + + def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None: + """Store model configuration.""" + using_args = args.get("using", {}) + using_args["target"] = target + self.model_storage.json_set("args", using_args) + + def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame: + """Run AG2 agents for each row in the input DataFrame. + + Expects a 'question' column. Returns a DataFrame with the target column. + """ + from autogen import LLMConfig + + stored_args = self.model_storage.json_get("args") + predict_args = args.get("predict_params", {}) if args else {} + merged_args = {**stored_args, **predict_args} + + # Build LLM config from engine args + engine_args = self.engine_storage.get_connection_args() + model = merged_args.get("model", engine_args.get("model", "gpt-4o-mini")) + api_type = merged_args.get("api_type", engine_args.get("api_type", "openai")) + + api_key = engine_args.get("openai_api_key") or os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError("openai_api_key not found. Pass it in USING clause or set OPENAI_API_KEY env var.") + + config = { + "model": model, + "api_key": api_key, + "api_type": api_type, + } + if engine_args.get("api_base"): + config["base_url"] = engine_args["api_base"] + + llm_config = LLMConfig(config) + + # Parse agent definitions + agents_json = merged_args.get("agents") + if agents_json: + agent_defs = json.loads(agents_json) if isinstance(agents_json, str) else agents_json + else: + agent_defs = [ + { + "name": "Assistant", + "system_message": ( + "You are a helpful AI assistant. Provide clear, comprehensive " + "answers. Reply TERMINATE when the task is complete." + ), + }, + ] + + mode = merged_args.get("mode", "groupchat" if len(agent_defs) > 1 else "single") + max_rounds = int(merged_args.get("max_rounds", 8)) + speaker_selection = merged_args.get("speaker_selection", "auto") + + # Determine question column + question_col = "question" + if question_col not in df.columns: + question_col = df.columns[0] + + target = merged_args.get("target", "answer") + + results = [] + for _, row in df.iterrows(): + question = str(row[question_col]) + + try: + answer = self._run_agents( + llm_config=llm_config, + agent_defs=agent_defs, + question=question, + mode=mode, + max_rounds=max_rounds, + speaker_selection=speaker_selection, + ) + results.append({target: answer}) + except Exception as e: + logger.error(f"AG2 prediction error: {e}") + results.append({target: f"Error: {e}"}) + + return pd.DataFrame(results) + + def _run_agents( + self, + llm_config, + agent_defs: list, + question: str, + mode: str, + max_rounds: int, + speaker_selection: str, + ) -> str: + """Execute AG2 agent conversation and return the final answer.""" + from autogen import AssistantAgent, GroupChat, GroupChatManager, UserProxyAgent + + agents = [] + for agent_def in agent_defs: + agent = AssistantAgent( + name=agent_def["name"], + system_message=agent_def.get( + "system_message", + f"You are {agent_def['name']}. Be helpful and concise.", + ), + llm_config=llm_config, + ) + agents.append(agent) + + user_proxy = UserProxyAgent( + name="User", + human_input_mode="NEVER", + max_consecutive_auto_reply=0, + code_execution_config=False, + ) + + if mode == "single": + user_proxy.run(agents[0], message=question).process() + messages = agents[0].chat_messages.get(user_proxy, []) + else: + group_chat = GroupChat( + agents=[user_proxy] + agents, + messages=[], + max_round=max_rounds, + speaker_selection_method=speaker_selection, + ) + manager = GroupChatManager( + groupchat=group_chat, + llm_config=llm_config, + ) + user_proxy.run(manager, message=question).process() + messages = group_chat.messages + + # Extract last non-user, non-empty message as the answer + answer = "" + for msg in reversed(messages): + content = msg.get("content", "").strip() + name = msg.get("name", "") + if content and name != "User": + answer = content.replace("TERMINATE", "").strip() + if answer: + break + + return answer or "No answer generated." + + def describe(self, attribute: Optional[str] = None) -> pd.DataFrame: + """Describe the model configuration.""" + stored_args = self.model_storage.json_get("args") + + if attribute == "args": + return pd.DataFrame([stored_args]) + elif attribute == "agents": + agents_json = stored_args.get("agents", "[]") + agents = json.loads(agents_json) if isinstance(agents_json, str) else agents_json + return pd.DataFrame(agents) if agents else pd.DataFrame() + else: + agents_raw = stored_args.get("agents", "[]") + agents = json.loads(agents_raw) if isinstance(agents_raw, str) else agents_raw + info = { + "name": "AG2 Multi-Agent Handler", + "version": "0.0.1", + "mode": stored_args.get("mode", "groupchat"), + "max_rounds": stored_args.get("max_rounds", 8), + "speaker_selection": stored_args.get("speaker_selection", "auto"), + "num_agents": len(agents) if isinstance(agents, list) else 0, + } + return pd.DataFrame([info]) diff --git a/mindsdb/integrations/handlers/ag2_handler/creation_args.py b/mindsdb/integrations/handlers/ag2_handler/creation_args.py new file mode 100644 index 00000000000..a7a14d7a4c2 --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/creation_args.py @@ -0,0 +1,31 @@ +from collections import OrderedDict + +from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE + +creation_args = OrderedDict( + openai_api_key={ + "type": ARG_TYPE.STR, + "description": "OpenAI API key for the agents LLM backend. Falls back to OPENAI_API_KEY env var.", + "required": False, + "label": "OpenAI API key", + "secret": True, + }, + model={ + "type": ARG_TYPE.STR, + "description": "LLM model name (default: gpt-4o-mini).", + "required": False, + "label": "Model name", + }, + api_type={ + "type": ARG_TYPE.STR, + "description": "LLM API type: openai, anthropic, bedrock, etc. (default: openai).", + "required": False, + "label": "API type", + }, + api_base={ + "type": ARG_TYPE.STR, + "description": "Custom API base URL for OpenAI-compatible endpoints.", + "required": False, + "label": "API base URL", + }, +) diff --git a/mindsdb/integrations/handlers/ag2_handler/icon.svg b/mindsdb/integrations/handlers/ag2_handler/icon.svg new file mode 100644 index 00000000000..a2cbe1d7425 --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/icon.svg @@ -0,0 +1 @@ + diff --git a/mindsdb/integrations/handlers/ag2_handler/model_using_args.py b/mindsdb/integrations/handlers/ag2_handler/model_using_args.py new file mode 100644 index 00000000000..24da7bcae0c --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/model_using_args.py @@ -0,0 +1 @@ +model_using_args = {"openai_api_key": {"secret": True}} diff --git a/mindsdb/integrations/handlers/ag2_handler/requirements.txt b/mindsdb/integrations/handlers/ag2_handler/requirements.txt new file mode 100644 index 00000000000..39b5451ec17 --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/requirements.txt @@ -0,0 +1 @@ +ag2[openai]>=0.11.4,<1.0 diff --git a/mindsdb/integrations/handlers/ag2_handler/tests/__init__.py b/mindsdb/integrations/handlers/ag2_handler/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/mindsdb/integrations/handlers/ag2_handler/tests/test_ag2_handler.py b/mindsdb/integrations/handlers/ag2_handler/tests/test_ag2_handler.py new file mode 100644 index 00000000000..9fa0e9e8c23 --- /dev/null +++ b/mindsdb/integrations/handlers/ag2_handler/tests/test_ag2_handler.py @@ -0,0 +1,230 @@ +"""Tests for the AG2 handler.""" + +import json +import unittest +from unittest.mock import MagicMock, patch + +import pandas as pd + +from mindsdb.integrations.handlers.ag2_handler.ag2_handler import AG2Handler + + +class TestAG2HandlerValidation(unittest.TestCase): + """Test AG2Handler validation methods.""" + + def test_create_validation_valid_agents(self): + args = { + "using": { + "agents": json.dumps( + [ + {"name": "Agent1", "system_message": "You are agent 1."}, + {"name": "Agent2", "system_message": "You are agent 2."}, + ] + ), + } + } + # Should not raise + AG2Handler.create_validation("answer", args) + + def test_create_validation_no_agents(self): + args = {"using": {}} + # Should not raise — agents are optional + AG2Handler.create_validation("answer", args) + + def test_create_validation_invalid_agents_json(self): + args = {"using": {"agents": "not-valid-json"}} + with self.assertRaises(ValueError): + AG2Handler.create_validation("answer", args) + + def test_create_validation_empty_agents_list(self): + args = {"using": {"agents": "[]"}} + with self.assertRaises(ValueError): + AG2Handler.create_validation("answer", args) + + def test_create_validation_missing_agent_name(self): + args = { + "using": { + "agents": json.dumps([{"system_message": "No name here."}]), + } + } + with self.assertRaises(ValueError): + AG2Handler.create_validation("answer", args) + + def test_create_validation_invalid_mode(self): + args = {"using": {"mode": "invalid"}} + with self.assertRaises(ValueError): + AG2Handler.create_validation("answer", args) + + def test_create_validation_valid_modes(self): + for mode in ("single", "groupchat"): + args = {"using": {"mode": mode}} + AG2Handler.create_validation("answer", args) + + def test_create_validation_invalid_speaker_selection(self): + args = {"using": {"speaker_selection": "invalid"}} + with self.assertRaises(ValueError): + AG2Handler.create_validation("answer", args) + + def test_create_validation_valid_speaker_selections(self): + for sel in ("auto", "round_robin", "random"): + args = {"using": {"speaker_selection": sel}} + AG2Handler.create_validation("answer", args) + + +class TestAG2HandlerCreate(unittest.TestCase): + """Test AG2Handler create method.""" + + def _make_handler(self): + handler = AG2Handler.__new__(AG2Handler) + handler.model_storage = MagicMock() + handler.engine_storage = MagicMock() + handler.engine_storage.get_connection_args.return_value = { + "openai_api_key": "test-key", + "model": "gpt-4o-mini", + "api_type": "openai", + } + return handler + + def test_create_stores_args(self): + handler = self._make_handler() + args = { + "using": { + "agents": json.dumps([{"name": "Agent1"}]), + "max_rounds": 5, + } + } + handler.create("answer", args=args) + stored = handler.model_storage.json_set.call_args[0] + self.assertEqual(stored[0], "args") + self.assertEqual(stored[1]["target"], "answer") + self.assertEqual(stored[1]["max_rounds"], 5) + + def test_create_stores_target(self): + handler = self._make_handler() + handler.create("my_output", args={"using": {}}) + stored = handler.model_storage.json_set.call_args[0][1] + self.assertEqual(stored["target"], "my_output") + + +class TestAG2HandlerPredict(unittest.TestCase): + """Test AG2Handler predict method.""" + + def _make_handler(self): + handler = AG2Handler.__new__(AG2Handler) + handler.model_storage = MagicMock() + handler.engine_storage = MagicMock() + handler.engine_storage.get_connection_args.return_value = { + "openai_api_key": "test-key", + "model": "gpt-4o-mini", + "api_type": "openai", + } + return handler + + @patch.object(AG2Handler, "_run_agents") + def test_predict_calls_agents_per_row(self, mock_run): + mock_run.return_value = "Test answer" + handler = self._make_handler() + handler.model_storage.json_get.return_value = { + "agents": json.dumps([{"name": "Agent1"}]), + "max_rounds": 8, + "target": "answer", + } + + df = pd.DataFrame({"question": ["Q1", "Q2"]}) + result = handler.predict(df, args={}) + + self.assertEqual(len(result), 2) + self.assertEqual(result["answer"][0], "Test answer") + self.assertEqual(result["answer"][1], "Test answer") + self.assertEqual(mock_run.call_count, 2) + + @patch.object(AG2Handler, "_run_agents") + def test_predict_uses_target_column(self, mock_run): + mock_run.return_value = "Result" + handler = self._make_handler() + handler.model_storage.json_get.return_value = { + "target": "my_output", + } + + df = pd.DataFrame({"question": ["Q1"]}) + result = handler.predict(df, args={}) + + self.assertIn("my_output", result.columns) + self.assertEqual(result["my_output"][0], "Result") + + @patch.object(AG2Handler, "_run_agents") + def test_predict_falls_back_to_first_column(self, mock_run): + mock_run.return_value = "Answer" + handler = self._make_handler() + handler.model_storage.json_get.return_value = { + "target": "answer", + } + + df = pd.DataFrame({"prompt": ["Hello"]}) + handler.predict(df, args={}) + + # Should use first column when 'question' is not present + call_kwargs = mock_run.call_args + self.assertEqual(call_kwargs[1]["question"], "Hello") + + @patch.object(AG2Handler, "_run_agents") + def test_predict_handles_errors(self, mock_run): + mock_run.side_effect = RuntimeError("LLM failed") + handler = self._make_handler() + handler.model_storage.json_get.return_value = { + "target": "answer", + } + + df = pd.DataFrame({"question": ["Q1"]}) + result = handler.predict(df, args={}) + + self.assertIn("Error:", result["answer"][0]) + + +class TestAG2HandlerDescribe(unittest.TestCase): + """Test AG2Handler describe method.""" + + def _make_handler(self): + handler = AG2Handler.__new__(AG2Handler) + handler.model_storage = MagicMock() + handler.engine_storage = MagicMock() + return handler + + def test_describe_default(self): + handler = self._make_handler() + handler.model_storage.json_get.return_value = { + "mode": "groupchat", + "max_rounds": 8, + "speaker_selection": "auto", + "agents": json.dumps([{"name": "A"}, {"name": "B"}]), + } + result = handler.describe() + self.assertEqual(result["num_agents"][0], 2) + self.assertEqual(result["mode"][0], "groupchat") + + def test_describe_args(self): + handler = self._make_handler() + stored = {"mode": "single", "max_rounds": 5} + handler.model_storage.json_get.return_value = stored + result = handler.describe(attribute="args") + self.assertEqual(result["mode"][0], "single") + + def test_describe_agents(self): + handler = self._make_handler() + agents = [{"name": "Researcher"}, {"name": "Writer"}] + handler.model_storage.json_get.return_value = { + "agents": json.dumps(agents), + } + result = handler.describe(attribute="agents") + self.assertEqual(len(result), 2) + self.assertEqual(result["name"][0], "Researcher") + + def test_describe_no_agents(self): + handler = self._make_handler() + handler.model_storage.json_get.return_value = {} + result = handler.describe(attribute="agents") + self.assertTrue(result.empty) + + +if __name__ == "__main__": + unittest.main() From 24887d4c17863285f9d51222b2a500b95c5200cb Mon Sep 17 00:00:00 2001 From: Max Stepanov Date: Fri, 3 Apr 2026 15:50:14 +0300 Subject: [PATCH 121/125] add ag2 to check_requiremetns --- tests/scripts/check_requirements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index 8aa0f0f8bf8..a5c59684896 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -258,6 +258,7 @@ def get_requirements_with_DEP002(path): "python-dotenv": ["dotenv"], "pyjwt": ["jwt"], "sklearn": ["scikit-learn"], + "ag2": ["autogen"], } # We use this to exit with a non-zero status code if any check fails From f25c43bb2fd5e9f7350b2d3972dc219f4011b442 Mon Sep 17 00:00:00 2001 From: Andrey Date: Mon, 6 Apr 2026 14:41:08 +0300 Subject: [PATCH 122/125] FAISS: mixed search optimization (#12336) --- .../duckdb_faiss_table.py | 170 ++++++++++++++++-- .../duckdb_faiss_handler/faiss_index.py | 15 +- 2 files changed, 165 insertions(+), 20 deletions(-) diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py index 526fd2b3ff7..3ee59e93c01 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/duckdb_faiss_table.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import List +import math import pandas as pd import orjson @@ -32,6 +33,14 @@ class DuckDBFaissTable: + META_BATCH_SIZE = 10_000 + VECTOR_MARGIN_K = 5 + VECTOR_GROWTH_MULTIPLIER = 5 + VECTOR_MAX_RATE = 0.25 + VECTOR_MAX_LIMIT = 1_000_000 + VECTOR_MAX_ITERATIONS = 3 + DEFAULT_LIMIT = 100 + def __init__(self, table_name: str, table_dir: Path, handler): self.table_name = table_name self.handler = handler @@ -63,6 +72,10 @@ def close(self) -> None: self.faiss_index.close() self.connection.close() + @staticmethod + def _empty_result() -> pd.DataFrame: + return pd.DataFrame([], columns=["id", "content", "metadata", "distance"]) + def _create_kw_index(self): with self.connection.cursor() as cur: cur.execute("PRAGMA create_fts_index('meta_data', 'id', 'content')") @@ -136,30 +149,153 @@ def select( # If only content in filter: query faiss and attach to metadata return self._select_with_vector(vector_filter=vector_filter, limit=limit) + return self.mixed_search(vector_filter=vector_filter, meta_filters=meta_filters, limit=limit) + + def mixed_search(self, vector_filter, meta_filters, limit): + """ + 1. Measure selectivity of META_FILTERS: + Get predicted count of record after applying META_FILTERS using some of methods + Selectivity = count / total records + + 2. selectivity * total_recors > LIMIT / selectivity: + Use Vector-first search + Else: + Use Metadata-first search + """ + + if limit is None: + limit = self.DEFAULT_LIMIT + + total = self.faiss_index.get_size() + if total == 0 or limit == 0: + # no reason to do vector search + return self._empty_result() + + matched_count = self.get_metadata_search_count(meta_filters) + selectivity = matched_count / total + + # compare forecast count of affected records for vector and metadata search and choose what will take less + # do search even if selectivity is 0 because it might be approximate value in the future + if selectivity > 0 and selectivity * total > limit / selectivity: + df = self.vector_first_search(vector_filter, meta_filters, limit, selectivity) + else: + df = self.metadata_first_search(vector_filter, meta_filters, limit) + + return df[:limit] + + def get_metadata_search_count(self, meta_filters): """ - If metadata + content: - Query faiss, use limit = 1000 - Query duckdb with `id in (...)` - If count of results is less than input LIMIT value - Repeat the search with increased limit value - Limit value for step = 1000 * 5^i (1000, 2000, 25000, 125000 …) + Get count of records from duckdb with meta_filters """ - df = pd.DataFrame() + where_clause = self._translate_filters(meta_filters) + count_query = Select( + targets=[Function("count", args=[Star()], alias=Identifier("cnt"))], + from_table=Identifier("meta_data"), + where=where_clause, + ) + + with self.connection.cursor() as cur: + sql = self.handler.renderer.get_string(count_query, with_failback=True) + cur.execute(sql) + df = cur.fetchdf() + + return int(df["cnt"].iloc[0]) - total_size = self.get_total_size() + def vector_first_search(self, vector_filter, meta_filters, limit, selectivity): + """ + + Calculate required top results from faiss: it is predicted count of records, that required to be scanned + + Top_results = LIMIT / selectivity * VECTOR_MARGIN_K + + Circle: + Search Top_results vectors in faiss + Get ids + query duckdb with META_FILTERS and list of ids + If count of found records < LIMIT: + Increase Top_results = Top_results * VECTOR_GROWTH_MULTIPLIER to make next search iteration + If Top_results > total * VECTOR_MAX_RATE + or Top_results > VECTOR_MAX_LIMIT + or number of iteration >VECTOR_MAX_ITERATIONS: + Something went wrong, maybe META_FILTERS records has greater distance than average record + Break vector-first search and switch to metadata-first + If count of found records >= LIMIT: + Break and return results + """ - for i in range(10): - batch_size = 1000 * 5**i + total = self.faiss_index.get_size() - # TODO implement reverse search: - # if batch_size > 25% of db: search metadata first and then in faiss by list of ids + top_results = math.ceil(limit / selectivity * self.VECTOR_MARGIN_K) - df = self._select_with_vector(vector_filter=vector_filter, meta_filters=meta_filters, limit=batch_size) - if batch_size >= total_size or len(df) >= limit: + for i in range(self.VECTOR_MAX_ITERATIONS): + df = self._select_with_vector(vector_filter=vector_filter, meta_filters=meta_filters, limit=top_results) + if len(df) >= limit: + # found required size of data + return df + + top_results = top_results * self.VECTOR_GROWTH_MULTIPLIER + + if top_results > total * self.VECTOR_MAX_RATE or top_results > self.VECTOR_MAX_LIMIT: + # give up with vector_first search break - return df[:limit] + # failback to metadata-first search + return self.metadata_first_search(vector_filter, meta_filters, limit) + + def metadata_first_search(self, vector_filter, meta_filters, limit): + """ + Metadata-first search + + Query list of all ids from duckdb table using META_FILTERS + + Split into batches by META_BATCH. + Per batch: + Get batch of ids + Use ID selector to search in FAISS only by batch of ids + use LIMIT + Combine results in single list alongside with distances + After all batches + get top LIMIT vectors with min distances + Get their ids and find records in duckdb table for them + """ + + embedding = vector_filter.value + if isinstance(embedding, str): + embedding = orjson.loads(embedding) + + where_clause = self._translate_filters(meta_filters) + ids_query = Select( + targets=[Identifier("faiss_id")], + from_table=Identifier("meta_data"), + where=where_clause, + ) + + with self.connection.cursor() as cur: + sql = self.handler.renderer.get_string(ids_query, with_failback=True) + meta_df = cur.execute(sql).fetchdf() + + if meta_df.empty: + return self._empty_result() + + faiss_ids = meta_df["faiss_id"].tolist() + results = [] + for start in range(0, len(faiss_ids), self.META_BATCH_SIZE): + batch_ids = faiss_ids[start : start + self.META_BATCH_SIZE] + + distances, faiss_ids_found = self.faiss_index.search(embedding, limit, allowed_ids=batch_ids) + results.extend(zip(distances, faiss_ids_found)) + + results.sort(key=lambda x: x[0]) + + results = results[:limit] + if len(results) == 0: + raise RuntimeError("Something went wrong, faiss database didn't return results") + distances, faiss_ids = zip(*results) + + meta_df = self._select_from_metadata(faiss_ids=faiss_ids, meta_filters=meta_filters) + vector_df = pd.DataFrame({"faiss_id": faiss_ids, "distance": distances}) + return vector_df.merge(meta_df, on="faiss_id").drop("faiss_id", axis=1).sort_values(by="distance") def keyword_select( self, @@ -241,7 +377,7 @@ def _select_with_vector(self, vector_filter: FilterCondition, meta_filters=None, if isinstance(embedding, str): embedding = orjson.loads(embedding) - distances, faiss_ids = self.faiss_index.search(embedding, limit or 100) + distances, faiss_ids = self.faiss_index.search(embedding, limit or self.DEFAULT_LIMIT) # Fetch full data from DuckDB if len(faiss_ids) > 0: @@ -250,7 +386,7 @@ def _select_with_vector(self, vector_filter: FilterCondition, meta_filters=None, vector_df = pd.DataFrame({"faiss_id": faiss_ids, "distance": distances}) return vector_df.merge(meta_df, on="faiss_id").drop("faiss_id", axis=1).sort_values(by="distance") - return pd.DataFrame([], columns=["id", "content", "metadata", "distance"]) + return self._empty_result() def _select_from_metadata(self, faiss_ids=None, meta_filters=None, limit=None): query = Select( diff --git a/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py b/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py index 66ab6a64a22..45b05451808 100644 --- a/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py +++ b/mindsdb/integrations/handlers/duckdb_faiss_handler/faiss_index.py @@ -1,5 +1,5 @@ import os -from typing import Iterable, List, Callable +from typing import Iterable, List, Callable, Optional import numpy as np import psutil from pathlib import Path @@ -266,7 +266,7 @@ def search( self, query: Iterable[float], limit: int = 10, - # allowed_ids: Optional[Sequence[int]] = None, + allowed_ids: Optional[Iterable[int]] = None, ): if self.index is None: return [], [] @@ -276,7 +276,16 @@ def search( if self._normalize_vectors: queries = _normalize_rows(queries) - ds, ids = self.index.search(queries, limit) + params = None + if allowed_ids is not None: + allowed_ids_array = np.asarray(list(allowed_ids), dtype=np.int64) + ids_selector = faiss.IDSelectorArray( + len(allowed_ids_array), + faiss.swig_ptr(allowed_ids_array), + ) + params = faiss.IVFSearchParameters(sel=ids_selector) + + ds, ids = self.index.search(queries, limit, params=params) list_id = [i for i in ids[0] if i != -1] list_distances = [1 - d for d in ds[0][: len(list_id)]] From 2920f952631654d9ecf999fd618e9756275fb55d Mon Sep 17 00:00:00 2001 From: Hamish Fagg Date: Wed, 8 Apr 2026 11:02:33 +1200 Subject: [PATCH 123/125] Fix security issues in deps (#12334) Co-authored-by: Lucas Koontz --- .github/workflows/build_deploy_dev.yml | 12 +++++ docker/mindsdb.Dockerfile | 6 +-- mindsdb/api/a2a/README.md | 2 +- .../databricks_handler/requirements.txt | 2 +- .../huggingface_handler/requirements.txt | 4 +- .../huggingface_handler/requirements_cpu.txt | 4 +- .../snowflake_handler/requirements.txt | 4 +- mindsdb/utilities/langfuse.py | 54 +++++++++++-------- requirements/requirements-agents.txt | 7 +-- requirements/requirements-langfuse.txt | 2 +- requirements/requirements-opentelemetry.txt | 12 ++--- requirements/requirements.txt | 14 ++--- tests/scripts/check_requirements.py | 1 + 13 files changed, 74 insertions(+), 50 deletions(-) diff --git a/.github/workflows/build_deploy_dev.yml b/.github/workflows/build_deploy_dev.yml index 78701c7b6da..cc9d25f7edd 100644 --- a/.github/workflows/build_deploy_dev.yml +++ b/.github/workflows/build_deploy_dev.yml @@ -77,6 +77,18 @@ jobs: platforms: linux/amd64 push-cache: false + scan-keycloak: + runs-on: mdb-dev + needs: [ build ] + name: Scan cloud-cpu image + steps: + - uses: actions/checkout@v4 + - uses: mindsdb/github-actions/snyk-docker-scan@main + with: + image: 168681354662.dkr.ecr.us-east-1.amazonaws.com/mindsdb:${{ github.event.pull_request.head.sha }}-cloud-cpu + snyk-token: ${{ secrets.SNYK_TOKEN }} + dockerfile: docker/mindsdb.Dockerfile + # Push cache layers to docker registry # This is separate to the build step so we can do other stuff in parallel build-cache: diff --git a/docker/mindsdb.Dockerfile b/docker/mindsdb.Dockerfile index 0b50ff8cfad..6050970eecb 100644 --- a/docker/mindsdb.Dockerfile +++ b/docker/mindsdb.Dockerfile @@ -1,7 +1,7 @@ # This stage's objective is to gather ONLY requirements.txt files and anything else needed to install deps. # This stage will be run almost every build, but it is fast and the resulting layer hash will be the same unless a deps file changes. # We do it this way because we can't copy all requirements files with a glob pattern in docker while maintaining the folder structure. -FROM python:3.10 AS deps +FROM python:3.10.20 AS deps WORKDIR /mindsdb # Copy everything to begin with @@ -19,7 +19,7 @@ COPY mindsdb/__about__.py mindsdb/ # Use the stage from above to install our deps with as much caching as possible -FROM python:3.10 AS build +FROM python:3.10.20 AS build WORKDIR /mindsdb # Configure apt to retain downloaded packages so we can store them in a cache mount @@ -54,7 +54,7 @@ COPY --from=deps /mindsdb . # - and finally declare `/mindsdb` as the target dir. ENV UV_LINK_MODE=copy \ UV_PYTHON_DOWNLOADS=never \ - UV_PYTHON=python3.10 \ + UV_PYTHON=python3.10.20 \ UV_PROJECT_ENVIRONMENT=/mindsdb \ VIRTUAL_ENV=/venv \ PATH=/venv/bin:$PATH diff --git a/mindsdb/api/a2a/README.md b/mindsdb/api/a2a/README.md index cddb2ccf8dd..787b2d8c409 100644 --- a/mindsdb/api/a2a/README.md +++ b/mindsdb/api/a2a/README.md @@ -14,7 +14,7 @@ The A2A API runs as part of the MindsDB HTTP API, allowing you to: ## Prerequisites - MindsDB running -- Python 3.10 or higher +- Python 3.10.20 or higher ## Running A2A API diff --git a/mindsdb/integrations/handlers/databricks_handler/requirements.txt b/mindsdb/integrations/handlers/databricks_handler/requirements.txt index 212e52860fc..0137133cc54 100644 --- a/mindsdb/integrations/handlers/databricks_handler/requirements.txt +++ b/mindsdb/integrations/handlers/databricks_handler/requirements.txt @@ -1 +1 @@ -databricks-sql-connector >= 3.7.1, < 4.0.0 +databricks-sql-connector==4.2.3 diff --git a/mindsdb/integrations/handlers/huggingface_handler/requirements.txt b/mindsdb/integrations/handlers/huggingface_handler/requirements.txt index f4291850dcf..eae77291d1f 100644 --- a/mindsdb/integrations/handlers/huggingface_handler/requirements.txt +++ b/mindsdb/integrations/handlers/huggingface_handler/requirements.txt @@ -2,6 +2,6 @@ datasets==2.16.1 evaluate==0.4.3 nltk==3.9.3 -huggingface-hub==0.29.3 +huggingface-hub==1.9.1 torch==2.8.0 -transformers >= 4.42.4 +transformers==5.5.0 diff --git a/mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt b/mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt index b60dc5172ae..b509a2942f4 100644 --- a/mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt +++ b/mindsdb/integrations/handlers/huggingface_handler/requirements_cpu.txt @@ -2,6 +2,6 @@ datasets==2.16.1 evaluate==0.4.3 nltk==3.9.3 -huggingface-hub==0.29.3 +huggingface-hub==1.9.1 torch==2.8.0+cpu -transformers >= 4.42.4 \ No newline at end of file +transformers==5.5.0 diff --git a/mindsdb/integrations/handlers/snowflake_handler/requirements.txt b/mindsdb/integrations/handlers/snowflake_handler/requirements.txt index 706f9cd675f..b267c6e302d 100644 --- a/mindsdb/integrations/handlers/snowflake_handler/requirements.txt +++ b/mindsdb/integrations/handlers/snowflake_handler/requirements.txt @@ -1,2 +1,2 @@ -snowflake-connector-python[pandas]==3.15.0 -snowflake-sqlalchemy==1.7.0 +snowflake-connector-python[pandas]==4.4.0 +snowflake-sqlalchemy==1.9.0 diff --git a/mindsdb/utilities/langfuse.py b/mindsdb/utilities/langfuse.py index def4ec98c7e..92320c48d5e 100644 --- a/mindsdb/utilities/langfuse.py +++ b/mindsdb/utilities/langfuse.py @@ -5,8 +5,8 @@ from mindsdb.utilities import log if TYPE_CHECKING: - from langfuse.callback import CallbackHandler - from langfuse.client import StatefulSpanClient + from langfuse._client.span import LangfuseSpan + from langfuse.langchain import CallbackHandler logger = log.getLogger(__name__) @@ -111,6 +111,7 @@ def __init__( public_key=public_key, secret_key=secret_key, host=host, + environment=environment, release=release, debug=debug, timeout=timeout, @@ -145,13 +146,14 @@ def setup_trace( self.set_tags(tags) try: - self.trace = self.client.trace( - name=name, input=input, metadata=self.metadata, tags=self.tags, user_id=user_id, session_id=session_id - ) + # SDK v3+: root observation is a span; trace attributes are set via update_trace. + self.trace = self.client.start_span(name=name, input=input, metadata=self.metadata) + self.trace.update_trace(tags=self.tags, user_id=user_id, session_id=session_id) except Exception: - logger.exception(f"Something went wrong while processing Langfuse trace {self.trace.id}:") + logger.exception("Something went wrong while creating Langfuse trace") + return - logger.info(f"Langfuse trace configured with ID: {self.trace.id}") + logger.info(f"Langfuse trace configured with ID: {self.trace.trace_id}") def get_trace_id(self) -> typing.Optional[str]: """ @@ -166,9 +168,9 @@ def get_trace_id(self) -> typing.Optional[str]: logger.debug("Langfuse trace is not setup.") return "" - return self.trace.id + return self.trace.trace_id - def start_span(self, name: str, input: typing.Optional[typing.Any] = None) -> typing.Optional["StatefulSpanClient"]: + def start_span(self, name: str, input: typing.Optional[typing.Any] = None) -> typing.Optional["LangfuseSpan"]: """ Create span. If Langfuse is disabled, nothing will be done. @@ -181,9 +183,9 @@ def start_span(self, name: str, input: typing.Optional[typing.Any] = None) -> ty logger.debug("Langfuse is disabled.") return None - return self.trace.span(name=name, input=input) + return self.trace.start_span(name=name, input=input) - def end_span_stream(self, span: typing.Optional["StatefulSpanClient"] = None) -> None: + def end_span_stream(self, span: typing.Optional["LangfuseSpan"] = None) -> None: """ End span. If Langfuse is disabled, nothing will happen. Args: @@ -195,10 +197,10 @@ def end_span_stream(self, span: typing.Optional["StatefulSpanClient"] = None) -> return span.end() - self.trace.update() + self.client.flush() def end_span( - self, span: typing.Optional["StatefulSpanClient"] = None, output: typing.Optional[typing.Any] = None + self, span: typing.Optional["LangfuseSpan"] = None, output: typing.Optional[typing.Any] = None ) -> None: """ End trace. If Langfuse is disabled, nothing will be done. @@ -216,8 +218,10 @@ def end_span( logger.debug("Langfuse span is not created.") return - span.end(output=output) - self.trace.update(output=output) + if output is not None: + span.update(output=output) + span.end() + self.trace.update_trace(output=output) metadata = self.metadata or {} @@ -225,9 +229,9 @@ def end_span( # Ensure all batched traces are sent before fetching. self.client.flush() metadata["tool_usage"] = self._get_tool_usage() - self.trace.update(metadata=metadata) + self.trace.update_trace(metadata=metadata) except Exception: - logger.exception(f"Something went wrong while processing Langfuse trace {self.trace.id}:") + logger.exception(f"Something went wrong while processing Langfuse trace {self.trace.trace_id}:") def get_langchain_handler(self) -> typing.Optional["CallbackHandler"]: """ @@ -238,7 +242,13 @@ def get_langchain_handler(self) -> typing.Optional["CallbackHandler"]: logger.debug("Langfuse is disabled.") return None - return self.trace.get_langchain_handler() + try: + from langfuse.langchain import CallbackHandler + except ImportError: + logger.debug("langfuse.langchain CallbackHandler is not available (install langchain extra if needed).") + return None + + return CallbackHandler(public_key=self.public_key) def set_metadata(self, custom_metadata: dict = None) -> None: """ @@ -267,8 +277,8 @@ def _get_tool_usage(self) -> typing.Dict: tool_usage = {} try: - fetched_trace = self.client.get_trace(self.trace.id) - steps = [s.name for s in fetched_trace.observations] + fetched_trace = self.client.api.trace.get(self.trace.trace_id) + steps = [s.name for s in fetched_trace.observations if s.name] for step in steps: if "AgentAction" in step: tool_name = step.split("-")[1] @@ -276,8 +286,8 @@ def _get_tool_usage(self) -> typing.Dict: tool_usage[tool_name] = 0 tool_usage[tool_name] += 1 except TraceNotFoundError: - logger.warning(f"Langfuse trace {self.trace.id} not found") + logger.warning(f"Langfuse trace {self.trace.trace_id} not found") except Exception: - logger.exception(f"Something went wrong while processing Langfuse trace {self.trace.id}:") + logger.exception(f"Something went wrong while processing Langfuse trace {self.trace.trace_id}:") return tool_usage diff --git a/requirements/requirements-agents.txt b/requirements/requirements-agents.txt index dbf6acc1096..e96657bb724 100644 --- a/requirements/requirements-agents.txt +++ b/requirements/requirements-agents.txt @@ -1,7 +1,7 @@ -openai<3.0.0,>=2.9.0 +openai<3.0.0,>=2.11.0 # When using agents, some LLMs may require the 'transformers' library (like Ollama): -transformers >= 4.42.4 +transformers==5.5.0 # Required for KB mindsdb-evaluator == 0.0.21 @@ -10,4 +10,5 @@ mcp~=1.26.0 # Required for MCP server # A2A requirements httpx==0.28.1 jwcrypto==1.5.6 -typing-extensions==4.14.1 +# fastmcp (via pydantic-ai) requires typing-extensions>=4.15.0 (py-key-value-aio chain) +typing-extensions>=4.15.0,<5 diff --git a/requirements/requirements-langfuse.txt b/requirements/requirements-langfuse.txt index fffecd7da86..7cd73e32d75 100644 --- a/requirements/requirements-langfuse.txt +++ b/requirements/requirements-langfuse.txt @@ -1 +1 @@ -langfuse==2.53.3 # Latest as of November 4, 2024 \ No newline at end of file +langfuse==3.2.5 \ No newline at end of file diff --git a/requirements/requirements-opentelemetry.txt b/requirements/requirements-opentelemetry.txt index eae7c0601c4..0b262f9b35a 100644 --- a/requirements/requirements-opentelemetry.txt +++ b/requirements/requirements-opentelemetry.txt @@ -1,6 +1,6 @@ -opentelemetry-api==1.27.0 -opentelemetry-sdk==1.27.0 -opentelemetry-exporter-otlp==1.27.0 -opentelemetry-instrumentation-requests==0.48b0 -opentelemetry-instrumentation-flask==0.48b0 -opentelemetry-distro==0.48b0 \ No newline at end of file +opentelemetry-api==1.39.1 +opentelemetry-sdk==1.39.1 +opentelemetry-exporter-otlp==1.39.1 +opentelemetry-instrumentation-requests==0.60b1 +opentelemetry-instrumentation-flask==0.60b1 +opentelemetry-distro==0.60b1 \ No newline at end of file diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 55c81279c0f..759f05a0bcd 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,15 +2,15 @@ packaging flask == 3.1.3 werkzeug == 3.1.6 flask-restx >= 1.3.0, < 2.0.0 -pandas == 2.2.3 +pandas==2.3.1 python-multipart == 0.0.22 -cryptography>=35.0 +cryptography>=46.0.5 psycopg[binary] psutil~=7.0 sqlalchemy >= 2.0.0, < 3.0.0 psycopg2-binary # This is required for using sqlalchemy with postgres alembic >= 1.3.3 -redis >=5.0.0, < 6.0.0 +redis==6.4.0 walrus==0.9.3 flask-compress >= 1.0.0 appdirs >= 1.0.0 @@ -37,17 +37,17 @@ sse-starlette==2.3.3 pydantic_core>=2.33.2 pyjwt==2.12.0 # files reading -pymupdf==1.25.2 +pymupdf==1.27.2 filetype charset-normalizer openpyxl # used by pandas to read txt and xlsx files xlrd>=2.0.1 # used by pandas to read legacy .xls files -aipdf==0.0.7.0 +aipdf==0.0.7.2 pyarrow<=19.0.0 # used by pandas to read feather files in Files handler orjson==3.11.6 -mind-castle >= 0.4.9 -pydantic-ai>=0.0.14 # Required for Pydantic AI agents +mind-castle==0.5.0 +pydantic-ai==1.77.0 # Required for Pydantic AI agents bs4 # for rag HTMLDocumentLoader urllib3>=2.6.3 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/tests/scripts/check_requirements.py b/tests/scripts/check_requirements.py index a5c59684896..f3bcc6382dc 100644 --- a/tests/scripts/check_requirements.py +++ b/tests/scripts/check_requirements.py @@ -110,6 +110,7 @@ def get_requirements_with_DEP002(path): "numba", # required in a few files for the hierarchicalforecast. Otherwise, uv may install an old version. "urllib3", # pinned by Snyk to avoid a vulnerability "faiss-cpu", + "pyopenssl", ], } From 41bf4d2b1cbc2cfa0709d2a60c7e7fb8ea0d438e Mon Sep 17 00:00:00 2001 From: andrew Date: Tue, 14 Apr 2026 19:02:08 +0300 Subject: [PATCH 124/125] test ci error --- mindsdb/integrations/handlers/access_handler/access_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mindsdb/integrations/handlers/access_handler/access_handler.py b/mindsdb/integrations/handlers/access_handler/access_handler.py index c212c42e8be..6ab07c91564 100644 --- a/mindsdb/integrations/handlers/access_handler/access_handler.py +++ b/mindsdb/integrations/handlers/access_handler/access_handler.py @@ -46,6 +46,7 @@ def __init__(self, name: str, connection_data: Optional[dict], **kwargs): self.parser = parse_sql self.dialect = "access" self.connection_data = connection_data + self.kwargs = kwargs self.connection = None From a8eb50b8529563afe8e60455acd11554c46b418d Mon Sep 17 00:00:00 2001 From: andrew Date: Tue, 14 Apr 2026 19:16:11 +0300 Subject: [PATCH 125/125] more commit --- mindsdb/integrations/handlers/access_handler/access_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mindsdb/integrations/handlers/access_handler/access_handler.py b/mindsdb/integrations/handlers/access_handler/access_handler.py index 6ab07c91564..4faa097a9ed 100644 --- a/mindsdb/integrations/handlers/access_handler/access_handler.py +++ b/mindsdb/integrations/handlers/access_handler/access_handler.py @@ -45,6 +45,7 @@ def __init__(self, name: str, connection_data: Optional[dict], **kwargs): super().__init__(name) self.parser = parse_sql self.dialect = "access" + self.connection_data = connection_data self.kwargs = kwargs