diff --git a/redash/query_runner/trino.py b/redash/query_runner/trino.py index 6bb6c5bccc..94304d65b4 100644 --- a/redash/query_runner/trino.py +++ b/redash/query_runner/trino.py @@ -123,6 +123,8 @@ def get_schema(self, get_stats=False): else: catalogs = self._get_catalogs() + schema_filter = self.configuration.get("schema") + schema = {} for catalog in catalogs: query = f""" @@ -130,6 +132,10 @@ def get_schema(self, get_stats=False): FROM {catalog}.information_schema.columns WHERE table_schema NOT IN ('pg_catalog', 'information_schema') """ + if schema_filter: + safe_schema_filter = schema_filter.replace("'", "''") + query += f" AND table_schema = '{safe_schema_filter}'" + results, error = self.run_query(query, None) if error is not None: diff --git a/tests/query_runner/test_trino.py b/tests/query_runner/test_trino.py index b5fad8e6ea..2ef14ff0ed 100644 --- a/tests/query_runner/test_trino.py +++ b/tests/query_runner/test_trino.py @@ -29,6 +29,46 @@ def test_get_schema_catalog_set(self, mock_run_query, mock__get_catalogs): runner = Trino({"catalog": TestTrino.catalog_name}) self._assert_schema_catalog(mock_run_query, mock__get_catalogs, runner) + @patch.object(Trino, "run_query") + def test_get_schema_with_schema_filter(self, mock_run_query): + runner = Trino({"catalog": TestTrino.catalog_name, "schema": TestTrino.schema_name}) + mock_run_query.return_value = ( + { + "rows": [ + { + "table_schema": TestTrino.schema_name, + "table_name": TestTrino.table_name, + "column_name": TestTrino.column_name, + "data_type": TestTrino.column_type, + } + ] + }, + None, + ) + runner.get_schema() + query_arg = mock_run_query.call_args[0][0] + self.assertIn(f"AND table_schema = '{TestTrino.schema_name}'", query_arg) + + @patch.object(Trino, "run_query") + def test_get_schema_without_schema_filter(self, mock_run_query): + runner = Trino({"catalog": TestTrino.catalog_name}) + mock_run_query.return_value = ( + { + "rows": [ + { + "table_schema": TestTrino.schema_name, + "table_name": TestTrino.table_name, + "column_name": TestTrino.column_name, + "data_type": TestTrino.column_type, + } + ] + }, + None, + ) + runner.get_schema() + query_arg = mock_run_query.call_args[0][0] + self.assertNotIn("AND table_schema =", query_arg) + def _assert_schema_catalog(self, mock_run_query, mock__get_catalogs, runner): mock_run_query.return_value = ( {