diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 44291626d0..219451de93 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -9,11 +9,13 @@ from sqlglot.parsers.databricks import DatabricksParser from sqlglot.tokens import TokenType from sqlglot.optimizer.annotate_types import TypeAnnotator +from sqlglot.typing.databricks import EXPRESSION_METADATA class Databricks(Spark): SAFE_DIVISION = False COPY_PARAMS_ARE_CSV = False + EXPRESSION_METADATA = EXPRESSION_METADATA.copy() COERCES_TO = defaultdict(set, deepcopy(TypeAnnotator.COERCES_TO)) for text_type in exp.DataType.TEXT_TYPES: diff --git a/sqlglot/typing/databricks.py b/sqlglot/typing/databricks.py new file mode 100644 index 0000000000..89b2505cd5 --- /dev/null +++ b/sqlglot/typing/databricks.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from sqlglot import exp +from sqlglot.typing.spark import EXPRESSION_METADATA + +EXPRESSION_METADATA = { + **EXPRESSION_METADATA, + **{ + exp_type: {"returns": exp.DType.INT} + for exp_type in { + exp.RegexpCount, + } + }, + **{ + exp_type: {"annotator": lambda self, e: self._annotate_by_args(e, "this", array=True)} + for exp_type in { + exp.RegexpExtractAll, + } + }, +} diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index 7dea85bb21..8bf0f8e3a8 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -970,7 +970,27 @@ BIGINT; # dialect: spark2, spark, databricks tbl.double_col DIV tbl.double_col; -BIGINT; +BIGINT; + +# dialect: databricks +tbl.str_col REGEXP 'pattern'; +BOOLEAN; + +# dialect: databricks +tbl.str_col REGEXP tbl.str_col; +BOOLEAN; + +# dialect: databricks +REGEXP_COUNT(tbl.str_col, 'l'); +INT; + +# dialect: databricks +REGEXP_COUNT(tbl.str_col, tbl.str_col); +INT; + +# dialect: databricks +REGEXP_EXTRACT_ALL(tbl.str_col, 'pattern'); +ARRAY; # dialect: hive tbl.bigint DIV tbl.bigint;