diff --git a/pyproject.toml b/pyproject.toml index e56c47c3d1..32cca55b71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "databricks-switch-plugin~=0.1.7", # Temporary, until Switch is migrated to be a transpiler (LSP) plugin. "requests>=2.28.1,<3", # Matches databricks-sdk (and 'types-requests' below), to avoid conflicts. "pandas~=2.3.1", # Required for new configure assessment + "libcst>=1.4.0,<2", ] [project.urls] diff --git a/src/databricks/labs/lakebridge/cli.py b/src/databricks/labs/lakebridge/cli.py index 872e7e6163..d30cc682df 100644 --- a/src/databricks/labs/lakebridge/cli.py +++ b/src/databricks/labs/lakebridge/cli.py @@ -38,12 +38,18 @@ from databricks.labs.lakebridge.reconcile.recon_config import RECONCILE_OPERATION_NAME, AGG_RECONCILE_OPERATION_NAME from databricks.labs.lakebridge.transpiler.describe import TranspilersDescription from databricks.labs.lakebridge.transpiler.execute import transpile as do_transpile +from databricks.labs.lakebridge.transpiler.glue.glue_engine import GlueEngine from databricks.labs.lakebridge.transpiler.lsp.lsp_engine import LSPEngine from databricks.labs.lakebridge.transpiler.repository import TranspilerRepository from databricks.labs.lakebridge.transpiler.sqlglot.sqlglot_engine import SqlglotEngine from databricks.labs.lakebridge.transpiler.switch_runner import SwitchRunner from databricks.labs.lakebridge.transpiler.transpile_engine import TranspileEngine +# Built-in engine sentinels — do not require a config file on disk +_BUILTIN_ENGINES: dict[str, type[TranspileEngine]] = { + "glue": GlueEngine, +} + from databricks.labs.lakebridge.transpiler.transpile_status import ErrorSeverity from databricks.labs.switch.lsp import get_switch_dialects @@ -243,6 +249,9 @@ def __init__( @staticmethod def _validate_transpiler_config_path(transpiler_config_path: str, msg: str) -> None: """Validate the transpiler config path: it must be a valid path that exists.""" + # Built-in engine sentinels don't require a file on disk. + if transpiler_config_path in _BUILTIN_ENGINES: + return # Note: the content is not validated here, but during loading of the engine. if not Path(transpiler_config_path).exists(): raise_validation_exception(msg) @@ -508,8 +517,11 @@ def _check_lsp_engine(self) -> TranspileEngine: transpiler_config_path, f"Error: Invalid value for '--transpiler-config-path': '{str(transpiler_config_path)}', file does not exist.", ) - path = Path(transpiler_config_path) - engine = LSPEngine.from_config_path(path) + if transpiler_config_path in _BUILTIN_ENGINES: + engine = _BUILTIN_ENGINES[transpiler_config_path]() + else: + path = Path(transpiler_config_path) + engine = LSPEngine.from_config_path(path) else: engine = None del transpiler_config_path diff --git a/src/databricks/labs/lakebridge/transpiler/glue/__init__.py b/src/databricks/labs/lakebridge/transpiler/glue/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/databricks/labs/lakebridge/transpiler/glue/glue_engine.py b/src/databricks/labs/lakebridge/transpiler/glue/glue_engine.py new file mode 100644 index 0000000000..b9cd76e033 --- /dev/null +++ b/src/databricks/labs/lakebridge/transpiler/glue/glue_engine.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import ast as _ast +import logging +from collections.abc import Mapping, Sequence +from pathlib import Path + +from databricks.labs.lakebridge.config import TranspileConfig, TranspileResult +from databricks.labs.lakebridge.transpiler.transpile_engine import TranspileEngine +from databricks.labs.lakebridge.transpiler.transpile_status import ( + ErrorKind, + ErrorSeverity, + TranspileError, +) +from databricks.labs.lakebridge.transpiler.glue.glue_transformer import GlueTransformer + +logger = logging.getLogger(__name__) + +_DEFAULT_ARGS_STYLE = "argparse" + + +def _extract_options(config: TranspileConfig) -> tuple[str | None, str]: + """Extract catalog and args_style from transpiler_options mapping.""" + opts = config.transpiler_options + if not isinstance(opts, Mapping): + return None, _DEFAULT_ARGS_STYLE + catalog = opts.get("catalog") or None + args_style = str(opts.get("args-style", _DEFAULT_ARGS_STYLE)) + if args_style not in ("argparse", "dbutils"): + logger.warning("Unknown args-style %r, falling back to 'argparse'.", args_style) + args_style = _DEFAULT_ARGS_STYLE + return catalog, args_style + + +class GlueEngine(TranspileEngine): + """Transpiles AWS Glue PySpark scripts to Databricks PySpark.""" + + def __init__(self) -> None: + self._catalog: str | None = None + self._args_style: str = _DEFAULT_ARGS_STYLE + + @property + def transpiler_name(self) -> str: + return "glue" + + @property + def supported_dialects(self) -> Sequence[str]: + return ["glue"] + + def is_supported_file(self, file: Path) -> bool: + return file.suffix.lower() == ".py" + + async def initialize(self, config: TranspileConfig) -> None: + self._catalog, self._args_style = _extract_options(config) + + async def shutdown(self) -> None: + pass + + async def transpile( + self, + source_dialect: str, + target_dialect: str, + source_code: str, + file_path: Path, + ) -> TranspileResult: + try: + transformer = GlueTransformer( + file_path, + catalog=self._catalog, + args_style=self._args_style, + ) + transpiled_code, warnings = transformer.transform(source_code) + + try: + _ast.parse(transpiled_code) + except SyntaxError as syn_err: + warnings.append(f"Generated code contains a syntax error: {syn_err}") + + errors = [ + TranspileError( + code="GLUE_WARNING", + kind=ErrorKind.GENERATION, + severity=ErrorSeverity.WARNING, + path=file_path, + message=msg, + ) + for msg in warnings + ] + return TranspileResult(transpiled_code, 1, errors) + + except SyntaxError as err: + error = TranspileError( + code="SYNTAX_ERROR", + kind=ErrorKind.PARSING, + severity=ErrorSeverity.ERROR, + path=file_path, + message=f"Python syntax error in source: {err}", + ) + return TranspileResult(source_code, 0, [error]) + + except Exception as err: # pylint: disable=broad-exception-caught + logger.exception("Unexpected error transpiling %s", file_path) + error = TranspileError( + code="GLUE_TRANSPILE_ERROR", + kind=ErrorKind.GENERATION, + severity=ErrorSeverity.ERROR, + path=file_path, + message=f"Unexpected transpilation error: {err}", + ) + return TranspileResult(source_code, 0, [error]) diff --git a/src/databricks/labs/lakebridge/transpiler/glue/glue_transformer.py b/src/databricks/labs/lakebridge/transpiler/glue/glue_transformer.py new file mode 100644 index 0000000000..8800c2170f --- /dev/null +++ b/src/databricks/labs/lakebridge/transpiler/glue/glue_transformer.py @@ -0,0 +1,554 @@ +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import Sequence, Union + +import libcst as cst +import libcst.matchers as m + +logger = logging.getLogger(__name__) + +_AWSGLUE_MODULES = frozenset( + ["awsglue.context", "awsglue.transforms", "awsglue.utils", "awsglue.dynamicframe", "awsglue.job"] +) + +_UNSUPPORTED_TRANSFORMS = frozenset([ + "DropNullFields", + "FillMissingValues", + "Filter", + "FindIncrementalMatches", + "FlatMap", + "Map", + "Relationalize", + "RenameField", + "ResolveChoice", + "SelectFields", + "SelectFromCollection", + "SplitFields", + "SplitRows", + "Unbox", +]) + +_GLUE_TYPE_MAP: dict[str, str] = { + "string": "string", + "char": "string", + "varchar": "string", + "int": "int", + "integer": "int", + "long": "bigint", + "short": "smallint", + "byte": "tinyint", + "double": "double", + "float": "float", + "boolean": "boolean", + "bool": "boolean", + "binary": "binary", + "date": "date", + "timestamp": "timestamp", + "decimal": "decimal", + "array": "array", + "map": "map", + "struct": "struct", +} + +_DECIMAL_RE = re.compile(r"^decimal\(\d+,\s*\d+\)$", re.IGNORECASE) + +_ROLE_SPARK_CONTEXT = "SparkContext" +_ROLE_GLUE_CONTEXT = "GlueContext" +_ROLE_SPARK_SESSION = "SparkSession" +_ROLE_JOB = "Job" + + +def _map_glue_type(glue_type: str) -> str: + lower = glue_type.lower().strip() + if _DECIMAL_RE.match(lower): + return lower + return _GLUE_TYPE_MAP.get(lower, lower) + + +def _attr_chain(node: cst.BaseExpression) -> list[str]: + parts: list[str] = [] + cur = node + while isinstance(cur, cst.Attribute): + parts.append(cur.attr.value) + cur = cur.value + if isinstance(cur, cst.Name): + parts.append(cur.value) + parts.reverse() + return parts + + +def _kwarg_str(args: Sequence[cst.Arg], name: str) -> str | None: + for arg in args: + if arg.keyword and arg.keyword.value == name: + if m.matches(arg.value, m.SimpleString()): + assert isinstance(arg.value, cst.SimpleString) + return arg.value.evaluated_value # type: ignore[return-value] + return None + + +def _kwarg_node(args: Sequence[cst.Arg], name: str) -> cst.BaseExpression | None: + for arg in args: + if arg.keyword and arg.keyword.value == name: + return arg.value + return None + + +def _kwarg_str_from_dict(dict_node: cst.BaseExpression | None, key: str) -> str | None: + if not isinstance(dict_node, cst.Dict): + return None + for el in dict_node.elements: + if isinstance(el, cst.DictElement) and m.matches(el.key, m.SimpleString()): + assert isinstance(el.key, cst.SimpleString) + if el.key.evaluated_value == key and isinstance(el.value, cst.SimpleString): + return el.value.evaluated_value # type: ignore[return-value] + return None + + +def _first_positional(args: Sequence[cst.Arg]) -> cst.BaseExpression | None: + for arg in args: + if arg.keyword is None: + return arg.value + return None + + +def _parse_stmt(code: str) -> cst.SimpleStatementLine: + module = cst.parse_module(code.strip() + "\n") + assert isinstance(module.body[0], cst.SimpleStatementLine) + return module.body[0] + + +def _parse_expr(code: str) -> cst.BaseExpression: + module = cst.parse_module(code.strip()) + stmt = module.body[0] + assert isinstance(stmt, cst.SimpleStatementLine) + assert isinstance(stmt.body[0], cst.Expr) + return stmt.body[0].value + + +class _BindingCollector(cst.CSTVisitor): + """Read-only first pass: collect variable→role assignments before transformation.""" + + def __init__(self) -> None: + self.var_roles: dict[str, str] = {} + + def visit_Assign(self, node: cst.Assign) -> None: + if not node.targets: + return + target = node.targets[0].target + if not isinstance(target, cst.Name): + return + value = node.value + for role in (_ROLE_SPARK_CONTEXT, _ROLE_GLUE_CONTEXT, _ROLE_JOB): + if m.matches(value, m.Call(func=m.Name(role))): + self.var_roles[target.value] = role + return + if m.matches(value, m.Attribute(attr=m.Name("spark_session"))): + self.var_roles[target.value] = _ROLE_SPARK_SESSION + + +class _GlueVisitor(cst.CSTTransformer): + def __init__( + self, + var_roles: dict[str, str], + catalog: str | None = None, + args_style: str = "argparse", + ) -> None: + super().__init__() + self.warnings: list[str] = [] + + self._var_roles = var_roles + self._catalog = catalog + self._args_style = args_style + + self._glue_context_var: str | None = next( + (v for v, r in var_roles.items() if r == _ROLE_GLUE_CONTEXT), None + ) + self._spark_session_var: str | None = next( + (v for v, r in var_roles.items() if r == _ROLE_SPARK_SESSION), None + ) + self._job_var: str | None = next( + (v for v, r in var_roles.items() if r == _ROLE_JOB), None + ) + + self._needs_spark_session_import = False + self._needs_argparse_import = False + self._needs_col_import = False + + self._argparse_params: list[str] | None = None + + def leave_ImportFrom( + self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> cst.ImportFrom | cst.RemovalSentinel: + if not isinstance(updated_node.module, (cst.Attribute, cst.Name)): + return updated_node + module_str = ".".join(_attr_chain(updated_node.module)) + if module_str not in _AWSGLUE_MODULES: + return updated_node + if module_str == "awsglue.context": + self._needs_spark_session_import = True + if module_str == "awsglue.utils": + self._needs_argparse_import = self._args_style == "argparse" + return cst.RemovalSentinel.REMOVE + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + new_imports: list[cst.SimpleStatementLine] = [] + if self._needs_spark_session_import: + new_imports.append(_parse_stmt("from pyspark.sql import SparkSession")) + if self._needs_argparse_import: + new_imports.append(_parse_stmt("import argparse")) + if self._needs_col_import: + new_imports.append(_parse_stmt("from pyspark.sql.functions import col")) + if not new_imports: + return updated_node + + # Insert after the last existing import, not at the top of the module. + last_import_idx = -1 + for i, stmt in enumerate(updated_node.body): + if isinstance(stmt, cst.SimpleStatementLine): + if any(isinstance(s, (cst.Import, cst.ImportFrom)) for s in stmt.body): + last_import_idx = i + + insert_at = last_import_idx + 1 + body = list(updated_node.body) + for offset, imp in enumerate(new_imports): + body.insert(insert_at + offset, imp) + return updated_node.with_changes(body=body) + + def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.BaseSmallStatement: + if not updated_node.targets: + return updated_node + target = updated_node.targets[0].target + if not isinstance(target, cst.Name): + return updated_node + var_name = target.value + value = updated_node.value + role = self._var_roles.get(var_name) + + if role == _ROLE_SPARK_CONTEXT: + return cst.RemovalSentinel.REMOVE + + if role == _ROLE_GLUE_CONTEXT: + return cst.RemovalSentinel.REMOVE + + if role == _ROLE_SPARK_SESSION: + new_value = _parse_expr("SparkSession.builder.getOrCreate()") + return updated_node.with_changes(value=new_value) + + if role == _ROLE_JOB: + return cst.RemovalSentinel.REMOVE + + if m.matches(value, m.Call(func=m.Name("getResolvedOptions"))): + assert isinstance(value, cst.Call) + return self._rewrite_get_resolved_options(updated_node, var_name, value) + + return updated_node + + def _rewrite_get_resolved_options( + self, original: cst.Assign, var_name: str, call: cst.Call + ) -> cst.BaseSmallStatement: + args = call.args + param_list_node = args[1].value if len(args) >= 2 else None + if param_list_node is None: + self.warnings.append("Could not find parameter list in getResolvedOptions call.") + return original + + params: list[str] = [] + if isinstance(param_list_node, cst.List): + for el in param_list_node.elements: + if isinstance(el.value, cst.SimpleString): + s = el.value.evaluated_value # type: ignore[assignment] + if s != "JOB_NAME": + params.append(str(s)) + else: + self.warnings.append( + "getResolvedOptions called with a non-literal parameter list; manual conversion required." + ) + return original + + if self._args_style == "dbutils": + self._argparse_params = params + # Widget registration is deferred to leave_SimpleStatementLine so the + # dbutils.widgets.text() calls are emitted before the dict comprehension. + new_value = _parse_expr(f'{{k: dbutils.widgets.get(k) for k in {repr(params)}}}') + return original.with_changes(value=new_value) + + self._argparse_params = params + new_value = _parse_expr("vars(_parser.parse_args())") + return original.with_changes(value=new_value) + + def leave_SimpleStatementLine( + self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine + ) -> Union[cst.SimpleStatementLine, cst.RemovalSentinel]: + if not updated_node.body: + return cst.RemovalSentinel.REMOVE + if isinstance(updated_node.body[0], cst.RemovalSentinel): + return cst.RemovalSentinel.REMOVE + + # job.init() / job.commit() must be removed here rather than in leave_Assign + # because they are expression statements, not assignments. + if self._job_var and len(updated_node.body) == 1: + stmt = updated_node.body[0] + if isinstance(stmt, cst.Expr) and isinstance(stmt.value, cst.Call): + call = stmt.value + if isinstance(call.func, cst.Attribute): + chain = _attr_chain(call.func) + if len(chain) == 2 and chain[0] == self._job_var and chain[1] in ("init", "commit"): + return cst.RemovalSentinel.REMOVE + + if self._argparse_params is not None and isinstance(updated_node.body[0], cst.Assign): + params = self._argparse_params + self._argparse_params = None + if self._args_style == "dbutils": + prepend: list[cst.SimpleStatementLine] = [] + for p in params: + prepend.append(_parse_stmt(f'dbutils.widgets.text({repr(p)}, "")')) + return cst.FlattenSentinel([*prepend, updated_node]) + else: + prepend = [_parse_stmt("_parser = argparse.ArgumentParser()")] + for p in params: + prepend.append(_parse_stmt(f'_parser.add_argument("--{p}")')) + return cst.FlattenSentinel([*prepend, updated_node]) + + return updated_node + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.BaseExpression: + func = updated_node.func + if not isinstance(func, cst.Attribute): + return updated_node + + chain = _attr_chain(func) + + if ( + self._glue_context_var + and len(chain) == 3 + and chain[0] == self._glue_context_var + and chain[1] == "create_dynamic_frame" + and chain[2] == "from_catalog" + ): + return self._rewrite_from_catalog(updated_node) + + if ( + self._glue_context_var + and len(chain) == 3 + and chain[0] == self._glue_context_var + and chain[1] == "create_dynamic_frame" + and chain[2] == "from_options" + ): + return self._rewrite_from_options_read(updated_node) + + if ( + self._glue_context_var + and len(chain) == 3 + and chain[0] == self._glue_context_var + and chain[1] == "write_dynamic_frame" + and chain[2] == "from_options" + ): + return self._rewrite_write_dynamic_frame(updated_node) + + if len(chain) == 2 and chain[0] == "ApplyMapping" and chain[1] == "apply": + return self._rewrite_apply_mapping(updated_node) + + if chain[0] in _UNSUPPORTED_TRANSFORMS and len(chain) >= 2: + self.warnings.append(f"Unsupported Glue transform '{chain[0]}' — manual conversion required.") + return updated_node + + if len(chain) >= 2 and chain[-1] == "toDF": + frame_arg = _first_positional(updated_node.args) + if frame_arg is not None: + return frame_arg + if isinstance(func, cst.Attribute): + return func.value + + if len(chain) >= 3 and chain[-2] == "DynamicFrame" and chain[-1] == "fromDF": + df_arg = _first_positional(updated_node.args) + if df_arg is not None: + return df_arg + + return updated_node + + def _table_ref(self, database: str, table_name: str) -> str: + if self._catalog: + return repr(f"{self._catalog}.{database}.{table_name}") + return repr(f"{database}.{table_name}") + + def _rewrite_from_catalog(self, node: cst.Call) -> cst.BaseExpression: + database = _kwarg_str(node.args, "database") + table_name = _kwarg_str(node.args, "table_name") + push_down_predicate = _kwarg_str(node.args, "push_down_predicate") + + if database is None or table_name is None: + self.warnings.append( + "create_dynamic_frame.from_catalog with non-literal database/table_name — manual conversion required." + ) + return node + + spark_var = self._spark_session_var or "spark" + table_ref = self._table_ref(database, table_name) + + if push_down_predicate: + return _parse_expr(f"{spark_var}.read.table({table_ref}).where({repr(push_down_predicate)})") + return _parse_expr(f"{spark_var}.read.table({table_ref})") + + def _rewrite_from_options_read(self, node: cst.Call) -> cst.BaseExpression: + connection_type = _kwarg_str(node.args, "connection_type") + fmt = _kwarg_str(node.args, "format") + spark_var = self._spark_session_var or "spark" + conn_opts_node = _kwarg_node(node.args, "connection_options") + + if connection_type == "s3": + path = _kwarg_str_from_dict(conn_opts_node, "path") + read_format = fmt or "parquet" + if path: + return _parse_expr(f"{spark_var}.read.format({repr(read_format)}).load({repr(self._normalize_s3_path(path))})") + return _parse_expr(f"{spark_var}.read.format({repr(read_format)}).load(...)") + + if connection_type == "jdbc": + url = _kwarg_str_from_dict(conn_opts_node, "url") + dbtable = _kwarg_str_from_dict(conn_opts_node, "dbtable") + if url and dbtable: + return _parse_expr( + f"{spark_var}.read.format('jdbc')" + f".option('url', {repr(url)})" + f".option('dbtable', {repr(dbtable)})" + f".load()" + ) + self.warnings.append("JDBC read with dynamic options — manual conversion required.") + return node + + self.warnings.append( + f"create_dynamic_frame.from_options with connection_type={repr(connection_type)} — manual conversion required." + ) + return node + + def _rewrite_write_dynamic_frame(self, node: cst.Call) -> cst.BaseExpression: + connection_type = _kwarg_str(node.args, "connection_type") + fmt = _kwarg_str(node.args, "format") + frame_node = _kwarg_node(node.args, "frame") + conn_opts_node = _kwarg_node(node.args, "connection_options") + + df_expr = "df" + if frame_node is not None: + df_expr = cst.parse_module("").code_for_node(frame_node) + + if connection_type == "s3": + path = _kwarg_str_from_dict(conn_opts_node, "path") + partition_keys: list[str] = [] + if isinstance(conn_opts_node, cst.Dict): + for el in conn_opts_node.elements: + if isinstance(el, cst.DictElement) and m.matches(el.key, m.SimpleString()): + assert isinstance(el.key, cst.SimpleString) + if el.key.evaluated_value == "partitionKeys" and isinstance(el.value, cst.List): + for elem in el.value.elements: + if isinstance(elem.value, cst.SimpleString): + pk = elem.value.evaluated_value + if pk: + partition_keys.append(str(pk)) + + write_format = fmt or "parquet" + chain = f"{df_expr}.write.format({repr(write_format)})" + if partition_keys: + keys_str = ", ".join(repr(k) for k in partition_keys) + chain += f".partitionBy({keys_str})" + if path: + chain += f".save({repr(self._normalize_s3_path(path))})" + else: + chain += ".save(...)" + return _parse_expr(chain) + + if connection_type == "jdbc": + url = _kwarg_str_from_dict(conn_opts_node, "url") + dbtable = _kwarg_str_from_dict(conn_opts_node, "dbtable") + if url and dbtable: + return _parse_expr( + f"{df_expr}.write.format('jdbc')" + f".option('url', {repr(url)})" + f".option('dbtable', {repr(dbtable)})" + f".save()" + ) + self.warnings.append("JDBC write with dynamic options — manual conversion required.") + return node + + self.warnings.append( + f"write_dynamic_frame.from_options with connection_type={repr(connection_type)} — manual conversion required." + ) + return node + + def _rewrite_apply_mapping(self, node: cst.Call) -> cst.BaseExpression: + frame_node = _kwarg_node(node.args, "frame") + mappings_node = _kwarg_node(node.args, "mappings") + + if frame_node is None: + frame_node = _first_positional(node.args) + if mappings_node is None and len(node.args) >= 2: + mappings_node = node.args[1].value + + if frame_node is None: + self.warnings.append("ApplyMapping.apply: could not determine frame — manual conversion required.") + return node + + df_code = cst.parse_module("").code_for_node(frame_node) + + if not isinstance(mappings_node, cst.List): + self.warnings.append("ApplyMapping.apply: non-literal mappings list — manual conversion required.") + return node + + chain = df_code + used_col = False + for elem in mappings_node.elements: + if not isinstance(elem.value, (cst.Tuple, cst.List)): + self.warnings.append("ApplyMapping.apply: non-literal mapping tuple — skipping.") + continue + items = [e.value for e in elem.value.elements] + if len(items) < 4: + continue + src_name_node, _src_type_node, dst_name_node, dst_type_node = items[:4] + if not all(isinstance(n, cst.SimpleString) for n in [src_name_node, dst_name_node, dst_type_node]): + self.warnings.append("ApplyMapping.apply: non-string mapping fields — skipping.") + continue + src_name = str(src_name_node.evaluated_value) # type: ignore[union-attr] + dst_name = str(dst_name_node.evaluated_value) # type: ignore[union-attr] + dst_type = str(dst_type_node.evaluated_value) # type: ignore[union-attr] + + if src_name != dst_name: + chain += f".withColumnRenamed({repr(src_name)}, {repr(dst_name)})" + + spark_type = _map_glue_type(dst_type) + chain += f".withColumn({repr(dst_name)}, col({repr(dst_name)}).cast({repr(spark_type)}))" + used_col = True + + if used_col: + self._needs_col_import = True + + return _parse_expr(chain) + + @staticmethod + def _normalize_s3_path(path: str) -> str: + for prefix in ("s3a://", "s3n://"): + if path.startswith(prefix): + return "s3://" + path[len(prefix):] + return path + + +class GlueTransformer: + def __init__( + self, + file_path: Path | None = None, + catalog: str | None = None, + args_style: str = "argparse", + ) -> None: + self._file_path = file_path + self._catalog = catalog + self._args_style = args_style + + def transform(self, source_code: str) -> tuple[str, list[str]]: + tree = cst.parse_module(source_code) + + collector = _BindingCollector() + tree.visit(collector) + + visitor = _GlueVisitor(collector.var_roles, catalog=self._catalog, args_style=self._args_style) + modified = tree.visit(visitor) + return modified.code, visitor.warnings diff --git a/tests/resources/functional/glue/args/test_dbutils_widgets.py b/tests/resources/functional/glue/args/test_dbutils_widgets.py new file mode 100644 index 0000000000..5aef7a548a --- /dev/null +++ b/tests/resources/functional/glue/args/test_dbutils_widgets.py @@ -0,0 +1,7 @@ +import sys +from awsglue.utils import getResolvedOptions + +args = getResolvedOptions(sys.argv, ["JOB_NAME", "source_db", "output_path"]) + +source = args["source_db"] +output = args["output_path"] diff --git a/tests/resources/functional/glue/args/test_dbutils_widgets_expected.py b/tests/resources/functional/glue/args/test_dbutils_widgets_expected.py new file mode 100644 index 0000000000..8232f5de5a --- /dev/null +++ b/tests/resources/functional/glue/args/test_dbutils_widgets_expected.py @@ -0,0 +1,8 @@ +import sys +dbutils.widgets.text('source_db', "") +dbutils.widgets.text('output_path', "") + +args = {k: dbutils.widgets.get(k) for k in ['source_db', 'output_path']} + +source = args["source_db"] +output = args["output_path"] diff --git a/tests/resources/functional/glue/args/test_getresolvedoptions.py b/tests/resources/functional/glue/args/test_getresolvedoptions.py new file mode 100644 index 0000000000..f1a0c9622a --- /dev/null +++ b/tests/resources/functional/glue/args/test_getresolvedoptions.py @@ -0,0 +1,7 @@ +import sys +from awsglue.utils import getResolvedOptions + +args = getResolvedOptions(sys.argv, ["JOB_NAME", "source_bucket", "target_table"]) + +source = args["source_bucket"] +target = args["target_table"] diff --git a/tests/resources/functional/glue/args/test_getresolvedoptions_expected.py b/tests/resources/functional/glue/args/test_getresolvedoptions_expected.py new file mode 100644 index 0000000000..675fa40d5a --- /dev/null +++ b/tests/resources/functional/glue/args/test_getresolvedoptions_expected.py @@ -0,0 +1,10 @@ +import sys +import argparse +_parser = argparse.ArgumentParser() +_parser.add_argument("--source_bucket") +_parser.add_argument("--target_table") + +args = vars(_parser.parse_args()) + +source = args["source_bucket"] +target = args["target_table"] diff --git a/tests/resources/functional/glue/boilerplate/test_job_lifecycle.py b/tests/resources/functional/glue/boilerplate/test_job_lifecycle.py new file mode 100644 index 0000000000..1ce2818e8e --- /dev/null +++ b/tests/resources/functional/glue/boilerplate/test_job_lifecycle.py @@ -0,0 +1,14 @@ +from awsglue.context import GlueContext +from awsglue.job import Job +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +job = Job(glueContext) +job.init("my-glue-job", {}) + +result_df = spark.read.table("processed.data") + +job.commit() diff --git a/tests/resources/functional/glue/boilerplate/test_job_lifecycle_expected.py b/tests/resources/functional/glue/boilerplate/test_job_lifecycle_expected.py new file mode 100644 index 0000000000..4b68794051 --- /dev/null +++ b/tests/resources/functional/glue/boilerplate/test_job_lifecycle_expected.py @@ -0,0 +1,5 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +result_df = spark.read.table("processed.data") diff --git a/tests/resources/functional/glue/context/test_glue_context_setup.py b/tests/resources/functional/glue/context/test_glue_context_setup.py new file mode 100644 index 0000000000..ebfb7bb586 --- /dev/null +++ b/tests/resources/functional/glue/context/test_glue_context_setup.py @@ -0,0 +1,6 @@ +from awsglue.context import GlueContext +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session diff --git a/tests/resources/functional/glue/context/test_glue_context_setup_expected.py b/tests/resources/functional/glue/context/test_glue_context_setup_expected.py new file mode 100644 index 0000000000..0e6cd3a4ba --- /dev/null +++ b/tests/resources/functional/glue/context/test_glue_context_setup_expected.py @@ -0,0 +1,3 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() diff --git a/tests/resources/functional/glue/e2e/test_full_glue_job.py b/tests/resources/functional/glue/e2e/test_full_glue_job.py new file mode 100644 index 0000000000..92cee251bc --- /dev/null +++ b/tests/resources/functional/glue/e2e/test_full_glue_job.py @@ -0,0 +1,42 @@ +"""Full realistic AWS Glue ETL job: read from catalog, apply mapping, write to S3.""" +import sys +from awsglue.context import GlueContext +from awsglue.transforms import * +from awsglue.utils import getResolvedOptions +from awsglue.job import Job +from pyspark.context import SparkContext + +args = getResolvedOptions(sys.argv, ["JOB_NAME", "output_path"]) + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +job = Job(glueContext) +job.init(args["JOB_NAME"], args) + +# Read source data +orders_df = glueContext.create_dynamic_frame.from_catalog( + database="source_db", + table_name="orders", +) + +# Apply column mapping and type casting +mapped_df = ApplyMapping.apply( + frame=orders_df, + mappings=[ + ("order_id", "string", "order_id", "long"), + ("customer_id", "string", "customer_id", "long"), + ("amount", "string", "amount", "double"), + ], +) + +# Write to S3 +glueContext.write_dynamic_frame.from_options( + frame=mapped_df, + connection_type="s3", + connection_options={"path": "s3://output-bucket/orders/", "partitionKeys": ["order_id"]}, + format="parquet", +) + +job.commit() diff --git a/tests/resources/functional/glue/e2e/test_full_glue_job_expected.py b/tests/resources/functional/glue/e2e/test_full_glue_job_expected.py new file mode 100644 index 0000000000..3b559f5230 --- /dev/null +++ b/tests/resources/functional/glue/e2e/test_full_glue_job_expected.py @@ -0,0 +1,20 @@ +"""Full realistic AWS Glue ETL job: read from catalog, apply mapping, write to S3.""" +import sys +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +import argparse +from pyspark.sql.functions import col +_parser = argparse.ArgumentParser() +_parser.add_argument("--output_path") + +args = vars(_parser.parse_args()) +spark = SparkSession.builder.getOrCreate() + +# Read source data +orders_df = spark.read.table('source_db.orders') + +# Apply column mapping and type casting +mapped_df = orders_df.withColumn('order_id', col('order_id').cast('bigint')).withColumn('customer_id', col('customer_id').cast('bigint')).withColumn('amount', col('amount').cast('double')) + +# Write to S3 +mapped_df.write.format('parquet').partitionBy('order_id').save('s3://output-bucket/orders/') diff --git a/tests/resources/functional/glue/imports/test_basic_imports.py b/tests/resources/functional/glue/imports/test_basic_imports.py new file mode 100644 index 0000000000..ee6511fee8 --- /dev/null +++ b/tests/resources/functional/glue/imports/test_basic_imports.py @@ -0,0 +1,7 @@ +import sys +from awsglue.context import GlueContext +from awsglue.transforms import * +from awsglue.utils import getResolvedOptions +from awsglue.dynamicframe import DynamicFrame +from awsglue.job import Job +from pyspark.context import SparkContext diff --git a/tests/resources/functional/glue/imports/test_basic_imports_expected.py b/tests/resources/functional/glue/imports/test_basic_imports_expected.py new file mode 100644 index 0000000000..aef355655b --- /dev/null +++ b/tests/resources/functional/glue/imports/test_basic_imports_expected.py @@ -0,0 +1,4 @@ +import sys +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +import argparse diff --git a/tests/resources/functional/glue/reads/test_from_catalog.py b/tests/resources/functional/glue/reads/test_from_catalog.py new file mode 100644 index 0000000000..525802918f --- /dev/null +++ b/tests/resources/functional/glue/reads/test_from_catalog.py @@ -0,0 +1,8 @@ +from awsglue.context import GlueContext +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +customers_df = glueContext.create_dynamic_frame.from_catalog(database="sales_db", table_name="customers") diff --git a/tests/resources/functional/glue/reads/test_from_catalog_expected.py b/tests/resources/functional/glue/reads/test_from_catalog_expected.py new file mode 100644 index 0000000000..9f1fbf7d5f --- /dev/null +++ b/tests/resources/functional/glue/reads/test_from_catalog_expected.py @@ -0,0 +1,5 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +customers_df = spark.read.table('sales_db.customers') diff --git a/tests/resources/functional/glue/reads/test_from_options_jdbc.py b/tests/resources/functional/glue/reads/test_from_options_jdbc.py new file mode 100644 index 0000000000..8feb4636e8 --- /dev/null +++ b/tests/resources/functional/glue/reads/test_from_options_jdbc.py @@ -0,0 +1,14 @@ +from awsglue.context import GlueContext +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +jdbc_df = glueContext.create_dynamic_frame.from_options( + connection_type="jdbc", + connection_options={ + "url": "jdbc:postgresql://host:5432/mydb", + "dbtable": "public.transactions", + }, +) diff --git a/tests/resources/functional/glue/reads/test_from_options_jdbc_expected.py b/tests/resources/functional/glue/reads/test_from_options_jdbc_expected.py new file mode 100644 index 0000000000..3237938ca5 --- /dev/null +++ b/tests/resources/functional/glue/reads/test_from_options_jdbc_expected.py @@ -0,0 +1,5 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +jdbc_df = spark.read.format('jdbc').option('url', 'jdbc:postgresql://host:5432/mydb').option('dbtable', 'public.transactions').load() diff --git a/tests/resources/functional/glue/reads/test_from_options_s3.py b/tests/resources/functional/glue/reads/test_from_options_s3.py new file mode 100644 index 0000000000..00e28e783f --- /dev/null +++ b/tests/resources/functional/glue/reads/test_from_options_s3.py @@ -0,0 +1,12 @@ +from awsglue.context import GlueContext +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +raw_df = glueContext.create_dynamic_frame.from_options( + connection_type="s3", + connection_options={"path": "s3://my-bucket/raw/events/"}, + format="parquet", +) diff --git a/tests/resources/functional/glue/reads/test_from_options_s3_expected.py b/tests/resources/functional/glue/reads/test_from_options_s3_expected.py new file mode 100644 index 0000000000..719724a4bf --- /dev/null +++ b/tests/resources/functional/glue/reads/test_from_options_s3_expected.py @@ -0,0 +1,5 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +raw_df = spark.read.format('parquet').load('s3://my-bucket/raw/events/') diff --git a/tests/resources/functional/glue/transforms/test_apply_mapping_cast.py b/tests/resources/functional/glue/transforms/test_apply_mapping_cast.py new file mode 100644 index 0000000000..9f3b176f4c --- /dev/null +++ b/tests/resources/functional/glue/transforms/test_apply_mapping_cast.py @@ -0,0 +1,17 @@ +from awsglue.context import GlueContext +from awsglue.transforms import * +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +source_df = glueContext.create_dynamic_frame.from_catalog(database="raw", table_name="orders") + +mapped_df = ApplyMapping.apply( + frame=source_df, + mappings=[ + ("id", "string", "order_id", "long"), + ("amount", "string", "total_amount", "double"), + ], +) diff --git a/tests/resources/functional/glue/transforms/test_apply_mapping_cast_expected.py b/tests/resources/functional/glue/transforms/test_apply_mapping_cast_expected.py new file mode 100644 index 0000000000..23acda65ea --- /dev/null +++ b/tests/resources/functional/glue/transforms/test_apply_mapping_cast_expected.py @@ -0,0 +1,8 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +from pyspark.sql.functions import col +spark = SparkSession.builder.getOrCreate() + +source_df = spark.read.table('raw.orders') + +mapped_df = source_df.withColumnRenamed('id', 'order_id').withColumn('order_id', col('order_id').cast('bigint')).withColumnRenamed('amount', 'total_amount').withColumn('total_amount', col('total_amount').cast('double')) diff --git a/tests/resources/functional/glue/transforms/test_apply_mapping_simple.py b/tests/resources/functional/glue/transforms/test_apply_mapping_simple.py new file mode 100644 index 0000000000..3983cdc28b --- /dev/null +++ b/tests/resources/functional/glue/transforms/test_apply_mapping_simple.py @@ -0,0 +1,17 @@ +from awsglue.context import GlueContext +from awsglue.transforms import * +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +source_df = glueContext.create_dynamic_frame.from_catalog(database="raw", table_name="orders") + +mapped_df = ApplyMapping.apply( + frame=source_df, + mappings=[ + ("order_id", "string", "order_id", "string"), + ("customer_name", "string", "customer_name", "string"), + ], +) diff --git a/tests/resources/functional/glue/transforms/test_apply_mapping_simple_expected.py b/tests/resources/functional/glue/transforms/test_apply_mapping_simple_expected.py new file mode 100644 index 0000000000..68a1c50c3c --- /dev/null +++ b/tests/resources/functional/glue/transforms/test_apply_mapping_simple_expected.py @@ -0,0 +1,8 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +from pyspark.sql.functions import col +spark = SparkSession.builder.getOrCreate() + +source_df = spark.read.table('raw.orders') + +mapped_df = source_df.withColumn('order_id', col('order_id').cast('string')).withColumn('customer_name', col('customer_name').cast('string')) diff --git a/tests/resources/functional/glue/transforms/test_apply_mapping_types.py b/tests/resources/functional/glue/transforms/test_apply_mapping_types.py new file mode 100644 index 0000000000..aa6ed104e6 --- /dev/null +++ b/tests/resources/functional/glue/transforms/test_apply_mapping_types.py @@ -0,0 +1,21 @@ +from awsglue.context import GlueContext +from awsglue.transforms import * +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +src = glueContext.create_dynamic_frame.from_catalog(database="raw", table_name="events") + +mapped = ApplyMapping.apply( + frame=src, + mappings=[ + ("id", "byte", "id", "byte"), + ("counter", "short", "counter", "short"), + ("total", "long", "total", "long"), + ("active", "bool", "active", "boolean"), + ("label", "char", "label", "string"), + ("amount", "decimal(10,2)", "amount", "decimal(10,2)"), + ], +) diff --git a/tests/resources/functional/glue/transforms/test_apply_mapping_types_expected.py b/tests/resources/functional/glue/transforms/test_apply_mapping_types_expected.py new file mode 100644 index 0000000000..b21071743c --- /dev/null +++ b/tests/resources/functional/glue/transforms/test_apply_mapping_types_expected.py @@ -0,0 +1,8 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +from pyspark.sql.functions import col +spark = SparkSession.builder.getOrCreate() + +src = spark.read.table('raw.events') + +mapped = src.withColumn('id', col('id').cast('tinyint')).withColumn('counter', col('counter').cast('smallint')).withColumn('total', col('total').cast('bigint')).withColumn('active', col('active').cast('boolean')).withColumn('label', col('label').cast('string')).withColumn('amount', col('amount').cast('decimal(10,2)')) diff --git a/tests/resources/functional/glue/writes/test_write_jdbc.py b/tests/resources/functional/glue/writes/test_write_jdbc.py new file mode 100644 index 0000000000..dfffb1f3db --- /dev/null +++ b/tests/resources/functional/glue/writes/test_write_jdbc.py @@ -0,0 +1,17 @@ +from awsglue.context import GlueContext +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +result_df = spark.read.table("processed.orders") + +glueContext.write_dynamic_frame.from_options( + frame=result_df, + connection_type="jdbc", + connection_options={ + "url": "jdbc:postgresql://host:5432/mydb", + "dbtable": "public.orders_out", + }, +) diff --git a/tests/resources/functional/glue/writes/test_write_jdbc_expected.py b/tests/resources/functional/glue/writes/test_write_jdbc_expected.py new file mode 100644 index 0000000000..876f08efff --- /dev/null +++ b/tests/resources/functional/glue/writes/test_write_jdbc_expected.py @@ -0,0 +1,7 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +result_df = spark.read.table("processed.orders") + +result_df.write.format('jdbc').option('url', 'jdbc:postgresql://host:5432/mydb').option('dbtable', 'public.orders_out').save() diff --git a/tests/resources/functional/glue/writes/test_write_partitioned.py b/tests/resources/functional/glue/writes/test_write_partitioned.py new file mode 100644 index 0000000000..2725e66c0c --- /dev/null +++ b/tests/resources/functional/glue/writes/test_write_partitioned.py @@ -0,0 +1,15 @@ +from awsglue.context import GlueContext +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +events_df = spark.read.table("processed.events") + +glueContext.write_dynamic_frame.from_options( + frame=events_df, + connection_type="s3", + connection_options={"path": "s3://my-bucket/output/events/", "partitionKeys": ["year", "month"]}, + format="parquet", +) diff --git a/tests/resources/functional/glue/writes/test_write_partitioned_expected.py b/tests/resources/functional/glue/writes/test_write_partitioned_expected.py new file mode 100644 index 0000000000..29193187a8 --- /dev/null +++ b/tests/resources/functional/glue/writes/test_write_partitioned_expected.py @@ -0,0 +1,7 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +events_df = spark.read.table("processed.events") + +events_df.write.format('parquet').partitionBy('year', 'month').save('s3://my-bucket/output/events/') diff --git a/tests/resources/functional/glue/writes/test_write_s3_parquet.py b/tests/resources/functional/glue/writes/test_write_s3_parquet.py new file mode 100644 index 0000000000..d78ffdc96d --- /dev/null +++ b/tests/resources/functional/glue/writes/test_write_s3_parquet.py @@ -0,0 +1,15 @@ +from awsglue.context import GlueContext +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +output_df = spark.read.table("processed.orders") + +glueContext.write_dynamic_frame.from_options( + frame=output_df, + connection_type="s3", + connection_options={"path": "s3://my-bucket/output/orders/"}, + format="parquet", +) diff --git a/tests/resources/functional/glue/writes/test_write_s3_parquet_expected.py b/tests/resources/functional/glue/writes/test_write_s3_parquet_expected.py new file mode 100644 index 0000000000..8b5ae65dae --- /dev/null +++ b/tests/resources/functional/glue/writes/test_write_s3_parquet_expected.py @@ -0,0 +1,7 @@ +from pyspark.context import SparkContext +from pyspark.sql import SparkSession +spark = SparkSession.builder.getOrCreate() + +output_df = spark.read.table("processed.orders") + +output_df.write.format('parquet').save('s3://my-bucket/output/orders/') diff --git a/tests/unit/transpiler/test_glue_engine.py b/tests/unit/transpiler/test_glue_engine.py new file mode 100644 index 0000000000..5ae10cc56a --- /dev/null +++ b/tests/unit/transpiler/test_glue_engine.py @@ -0,0 +1,222 @@ +"""Unit tests for GlueEngine and the parametrized fixture-based transpilation.""" +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from databricks.labs.lakebridge.transpiler.glue.glue_engine import GlueEngine +from databricks.labs.lakebridge.transpiler.transpile_status import ErrorSeverity + +_GLUE_FIXTURES = Path(__file__).parent.parent.parent / "resources" / "functional" / "glue" + +# Fixtures that require non-default engine options are excluded from the default +# parametrized suite and tested separately below. +_SKIP_IN_DEFAULT_SUITE = frozenset(["test_dbutils_widgets"]) + + +def _load_fixture_pairs(root: Path, skip: frozenset[str] = frozenset()) -> list[tuple[Path, Path]]: + pairs: list[tuple[Path, Path]] = [] + for input_file in sorted(root.rglob("test_*.py")): + if input_file.stem.endswith("_expected"): + continue + if input_file.stem in skip: + continue + expected_file = input_file.with_stem(input_file.stem + "_expected") + if expected_file.exists(): + pairs.append((input_file, expected_file)) + return pairs + + +_ALL_FIXTURES = _load_fixture_pairs(_GLUE_FIXTURES, skip=_SKIP_IN_DEFAULT_SUITE) +_FIXTURE_IDS = [f"{f.parent.name}/{f.stem}" for f, _ in _ALL_FIXTURES] + + +@pytest.fixture +def engine() -> GlueEngine: + return GlueEngine() + + +# ────────────────────────────────────────────────────────────────────────────── +# Engine contract tests +# ────────────────────────────────────────────────────────────────────────────── + + +def test_transpiler_name(engine: GlueEngine): + assert engine.transpiler_name == "glue" + + +def test_supported_dialects(engine: GlueEngine): + assert "glue" in engine.supported_dialects + + +def test_is_supported_file_py(engine: GlueEngine): + assert engine.is_supported_file(Path("job.py")) + + +def test_is_supported_file_py_uppercase(engine: GlueEngine): + assert engine.is_supported_file(Path("job.PY")) + + +def test_is_not_supported_file_sql(engine: GlueEngine): + assert not engine.is_supported_file(Path("query.sql")) + + +def test_is_not_supported_file_ipynb(engine: GlueEngine): + assert not engine.is_supported_file(Path("notebook.ipynb")) + + +def test_transpile_returns_transpile_result(engine: GlueEngine): + from databricks.labs.lakebridge.config import TranspileResult + + result = asyncio.run(engine.transpile("glue", "databricks", "x = 1\n", Path("x.py"))) + assert isinstance(result, TranspileResult) + + +def test_transpile_success_count_is_one(engine: GlueEngine): + result = asyncio.run(engine.transpile("glue", "databricks", "x = 1\n", Path("x.py"))) + assert result.success_count == 1 + + +def test_transpile_syntax_error_returns_zero_success(engine: GlueEngine): + result = asyncio.run(engine.transpile("glue", "databricks", "def bad(:\n pass", Path("bad.py"))) + assert result.success_count == 0 + assert result.error_list + assert result.error_list[0].severity.value == "ERROR" + + +def test_transpile_warnings_appear_in_error_list(engine: GlueEngine): + source = """\ +from awsglue.context import GlueContext +from awsglue.transforms import * +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +df = glueContext.create_dynamic_frame.from_catalog(database="db", table_name="t") +out = ResolveChoice.apply(frame=df, choice="make_cols") +""" + result = asyncio.run(engine.transpile("glue", "databricks", source, Path("job.py"))) + assert result.success_count == 1 + assert any(e.severity == ErrorSeverity.WARNING for e in result.error_list) + + +# ────────────────────────────────────────────────────────────────────────────── +# Engine options: catalog and args_style +# ────────────────────────────────────────────────────────────────────────────── + + +def test_engine_catalog_option_applied(): + """Unity Catalog prefix is prepended when catalog option is configured.""" + from databricks.labs.lakebridge.config import TranspileConfig + + engine = GlueEngine() + config = TranspileConfig( + transpiler_config_path="glue", + source_dialect="glue", + output_folder="/tmp", + transpiler_options={"catalog": "my_catalog"}, + ) + asyncio.run(engine.initialize(config)) + + source = """\ +from awsglue.context import GlueContext +from pyspark.context import SparkContext + +sc = SparkContext() +glueContext = GlueContext(sc) +spark = glueContext.spark_session + +df = glueContext.create_dynamic_frame.from_catalog(database="mydb", table_name="orders") +""" + result = asyncio.run(engine.transpile("glue", "databricks", source, Path("job.py"))) + assert "my_catalog.mydb.orders" in result.transpiled_code + + +def test_engine_dbutils_args_style_applied(): + """dbutils.widgets blocks are generated when args-style=dbutils is configured.""" + from databricks.labs.lakebridge.config import TranspileConfig + + engine = GlueEngine() + config = TranspileConfig( + transpiler_config_path="glue", + source_dialect="glue", + output_folder="/tmp", + transpiler_options={"args-style": "dbutils"}, + ) + asyncio.run(engine.initialize(config)) + + source = """\ +import sys +from awsglue.utils import getResolvedOptions + +args = getResolvedOptions(sys.argv, ["JOB_NAME", "source_db"]) + +val = args["source_db"] +""" + result = asyncio.run(engine.transpile("glue", "databricks", source, Path("job.py"))) + assert "dbutils.widgets.text" in result.transpiled_code + assert "argparse" not in result.transpiled_code + + +def test_engine_invalid_output_adds_warning(monkeypatch): + """If the transformer produces syntactically invalid Python, a warning is added.""" + from databricks.labs.lakebridge.transpiler.glue import glue_engine as ge_mod + + original_transformer = ge_mod.GlueTransformer + + class _BrokenTransformer: + def __init__(self, *args, **kwargs): + pass + + def transform(self, source_code: str) -> tuple[str, list[str]]: + return "def broken(:\n pass\n", [] + + monkeypatch.setattr(ge_mod, "GlueTransformer", _BrokenTransformer) + + engine = GlueEngine() + result = asyncio.run(engine.transpile("glue", "databricks", "x = 1\n", Path("x.py"))) + assert any("syntax error" in w.message.lower() for w in result.error_list) + + +# ────────────────────────────────────────────────────────────────────────────── +# Fixture test: dbutils style (requires custom engine) +# ────────────────────────────────────────────────────────────────────────────── + + +def test_dbutils_widgets_fixture(): + """Fixture test for args/test_dbutils_widgets.py using args_style=dbutils engine.""" + from databricks.labs.lakebridge.config import TranspileConfig + + engine = GlueEngine() + config = TranspileConfig( + transpiler_config_path="glue", + source_dialect="glue", + output_folder="/tmp", + transpiler_options={"args-style": "dbutils"}, + ) + asyncio.run(engine.initialize(config)) + + input_file = _GLUE_FIXTURES / "args" / "test_dbutils_widgets.py" + expected_file = _GLUE_FIXTURES / "args" / "test_dbutils_widgets_expected.py" + + source = input_file.read_text() + expected = expected_file.read_text() + result = asyncio.run(engine.transpile("glue", "databricks", source, input_file)) + assert result.transpiled_code.strip() == expected.strip() + + +# ────────────────────────────────────────────────────────────────────────────── +# Parametrized fixture tests (default engine, argparse mode) +# ────────────────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize("input_file,expected_file", _ALL_FIXTURES, ids=_FIXTURE_IDS) +def test_glue_fixture(engine: GlueEngine, input_file: Path, expected_file: Path): + source = input_file.read_text() + expected = expected_file.read_text() + result = asyncio.run(engine.transpile("glue", "databricks", source, input_file)) + assert result.transpiled_code.strip() == expected.strip() diff --git a/tests/unit/transpiler/test_glue_transformer.py b/tests/unit/transpiler/test_glue_transformer.py new file mode 100644 index 0000000000..7910cab87f --- /dev/null +++ b/tests/unit/transpiler/test_glue_transformer.py @@ -0,0 +1,655 @@ +"""Unit tests for GlueTransformer — inline string-based, no fixture files.""" +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest + +from databricks.labs.lakebridge.transpiler.glue.glue_transformer import ( + GlueTransformer, + _map_glue_type, +) + + +def _transform(source: str, catalog: str | None = None, args_style: str = "argparse") -> tuple[str, list[str]]: + t = GlueTransformer(Path("test_job.py"), catalog=catalog, args_style=args_style) + return t.transform(textwrap.dedent(source)) + + +# ────────────────────────────────────────────────────────────────────────────── +# Import rewriting +# ────────────────────────────────────────────────────────────────────────────── + + +def test_awsglue_context_import_replaced(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + x = 1 + """ + ) + assert "from pyspark.sql import SparkSession" in code + assert "awsglue" not in code + + +def test_awsglue_transforms_import_removed(): + code, _ = _transform( + """\ + from awsglue.transforms import * + x = 1 + """ + ) + assert "awsglue" not in code + + +def test_awsglue_utils_import_replaced_with_argparse(): + code, _ = _transform( + """\ + from awsglue.utils import getResolvedOptions + x = 1 + """ + ) + assert "import argparse" in code + assert "awsglue" not in code + + +def test_awsglue_dynamicframe_import_removed(): + code, _ = _transform( + """\ + from awsglue.dynamicframe import DynamicFrame + x = 1 + """ + ) + assert "awsglue" not in code + + +def test_awsglue_job_import_removed(): + code, _ = _transform( + """\ + from awsglue.job import Job + x = 1 + """ + ) + assert "awsglue" not in code + + +def test_non_awsglue_imports_preserved(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + import pandas as pd + from pyspark.sql import functions as F + x = 1 + """ + ) + assert "import pandas as pd" in code + assert "from pyspark.sql import functions as F" in code + + +def test_new_imports_inserted_after_existing_imports(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + import os + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + """ + ) + lines = [l for l in code.splitlines() if l.strip()] + import_lines = [i for i, l in enumerate(lines) if l.startswith(("import ", "from "))] + non_import_lines = [i for i, l in enumerate(lines) if not l.startswith(("import ", "from "))] + # All imports should appear before non-import statements + assert max(import_lines) < min(non_import_lines) + + +# ────────────────────────────────────────────────────────────────────────────── +# GlueContext setup collapsed +# ────────────────────────────────────────────────────────────────────────────── + + +def test_glue_context_setup_collapsed(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + """ + ) + assert "SparkSession.builder.getOrCreate()" in code + assert "SparkContext()" not in code + assert "GlueContext(" not in code + assert "spark_session" not in code + + +# ────────────────────────────────────────────────────────────────────────────── +# DynamicFrame reads +# ────────────────────────────────────────────────────────────────────────────── + + +def test_from_catalog_simple(): + code, warnings = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + df = glueContext.create_dynamic_frame.from_catalog(database="mydb", table_name="mytable") + """ + ) + # repr() produces single quotes for string literals + assert "spark.read.table('mydb.mytable')" in code + assert not warnings + + +def test_from_catalog_with_unity_catalog(): + code, warnings = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + df = glueContext.create_dynamic_frame.from_catalog(database="mydb", table_name="mytable") + """, + catalog="my_catalog", + ) + assert "spark.read.table('my_catalog.mydb.mytable')" in code + assert not warnings + + +def test_from_options_s3_read(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + df = glueContext.create_dynamic_frame.from_options( + connection_type="s3", + connection_options={"path": "s3://bucket/prefix/"}, + format="parquet", + ) + """ + ) + assert "spark.read.format('parquet').load('s3://bucket/prefix/')" in code + + +def test_from_options_s3a_path_normalized(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + df = glueContext.create_dynamic_frame.from_options( + connection_type="s3", + connection_options={"path": "s3a://bucket/prefix/"}, + format="parquet", + ) + """ + ) + assert "s3a://" not in code + assert "s3://bucket/prefix/" in code + + +def test_from_options_s3n_path_normalized(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + df = glueContext.create_dynamic_frame.from_options( + connection_type="s3", + connection_options={"path": "s3n://bucket/prefix/"}, + format="parquet", + ) + """ + ) + assert "s3n://" not in code + assert "s3://bucket/prefix/" in code + + +# JDBC read +def test_from_options_jdbc_read(): + code, warnings = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + df = glueContext.create_dynamic_frame.from_options( + connection_type="jdbc", + connection_options={ + "url": "jdbc:postgresql://host:5432/db", + "dbtable": "public.orders", + }, + ) + """ + ) + assert "spark.read.format('jdbc')" in code + assert ".option('url', 'jdbc:postgresql://host:5432/db')" in code + assert ".option('dbtable', 'public.orders')" in code + assert ".load()" in code + assert not warnings + + +def test_from_options_unknown_type_warns(): + _, warnings = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + df = glueContext.create_dynamic_frame.from_options( + connection_type="dynamodb", + connection_options={"tableName": "mytable"}, + ) + """ + ) + assert any("dynamodb" in w for w in warnings) + + +# ────────────────────────────────────────────────────────────────────────────── +# ApplyMapping +# ────────────────────────────────────────────────────────────────────────────── + + +def test_apply_mapping_injects_col_import(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from awsglue.transforms import * + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + src = glueContext.create_dynamic_frame.from_catalog(database="db", table_name="t") + out = ApplyMapping.apply(frame=src, mappings=[("col_a", "string", "col_a", "string")]) + """ + ) + assert "from pyspark.sql.functions import col" in code + + +def test_apply_mapping_same_name_no_rename(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from awsglue.transforms import * + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + src = glueContext.create_dynamic_frame.from_catalog(database="db", table_name="t") + out = ApplyMapping.apply(frame=src, mappings=[("col_a", "string", "col_a", "string")]) + """ + ) + # Same name → no withColumnRenamed, only cast + assert "withColumnRenamed" not in code + assert ".withColumn('col_a', col('col_a').cast('string'))" in code + + +def test_apply_mapping_rename_and_cast(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from awsglue.transforms import * + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + src = glueContext.create_dynamic_frame.from_catalog(database="db", table_name="t") + out = ApplyMapping.apply(frame=src, mappings=[("old_id", "string", "new_id", "long")]) + """ + ) + assert ".withColumnRenamed('old_id', 'new_id')" in code + assert ".cast('bigint')" in code + + +def test_apply_mapping_glue_type_conversions(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from awsglue.transforms import * + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + src = glueContext.create_dynamic_frame.from_catalog(database="db", table_name="t") + out = ApplyMapping.apply(frame=src, mappings=[ + ("a", "byte", "a", "byte"), + ("b", "short", "b", "short"), + ("c", "bool", "c", "bool"), + ("d", "char", "d", "char"), + ]) + """ + ) + assert ".cast('tinyint')" in code # byte → tinyint + assert ".cast('smallint')" in code # short → smallint + assert ".cast('boolean')" in code # bool → boolean + assert ".cast('string')" in code # char → string + + +def test_apply_mapping_decimal_preserved(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from awsglue.transforms import * + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + src = glueContext.create_dynamic_frame.from_catalog(database="db", table_name="t") + out = ApplyMapping.apply(frame=src, mappings=[ + ("amount", "decimal(10,2)", "amount", "decimal(10,2)"), + ]) + """ + ) + assert ".cast('decimal(10,2)')" in code + + +# ────────────────────────────────────────────────────────────────────────────── +# Writes +# ────────────────────────────────────────────────────────────────────────────── + + +def test_write_s3_parquet(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + output_df = spark.read.table("db.t") + + glueContext.write_dynamic_frame.from_options( + frame=output_df, + connection_type="s3", + connection_options={"path": "s3://bucket/out/"}, + format="parquet", + ) + """ + ) + assert "output_df.write.format('parquet').save('s3://bucket/out/')" in code + + +def test_write_partitioned(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + out_df = spark.read.table("db.t") + + glueContext.write_dynamic_frame.from_options( + frame=out_df, + connection_type="s3", + connection_options={"path": "s3://bucket/out/", "partitionKeys": ["year", "month"]}, + format="parquet", + ) + """ + ) + assert ".partitionBy('year', 'month')" in code + assert ".save('s3://bucket/out/')" in code + + +def test_write_jdbc(): + code, warnings = _transform( + """\ + from awsglue.context import GlueContext + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + result_df = spark.read.table("db.t") + + glueContext.write_dynamic_frame.from_options( + frame=result_df, + connection_type="jdbc", + connection_options={ + "url": "jdbc:postgresql://host:5432/db", + "dbtable": "public.orders_out", + }, + ) + """ + ) + assert "result_df.write.format('jdbc')" in code + assert ".option('url', 'jdbc:postgresql://host:5432/db')" in code + assert ".option('dbtable', 'public.orders_out')" in code + assert ".save()" in code + assert not warnings + + +# ────────────────────────────────────────────────────────────────────────────── +# Job boilerplate +# ────────────────────────────────────────────────────────────────────────────── + + +def test_job_boilerplate_removed(): + code, _ = _transform( + """\ + from awsglue.context import GlueContext + from awsglue.job import Job + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + job = Job(glueContext) + job.init("my-job", {}) + + result = 1 + + job.commit() + """ + ) + assert "Job(" not in code + assert "job.init(" not in code + assert "job.commit()" not in code + assert "result = 1" in code + + +# ────────────────────────────────────────────────────────────────────────────── +# getResolvedOptions → argparse +# ────────────────────────────────────────────────────────────────────────────── + + +def test_args_static_list(): + code, warnings = _transform( + """\ + import sys + from awsglue.utils import getResolvedOptions + + args = getResolvedOptions(sys.argv, ["JOB_NAME", "param1", "param2"]) + + val = args["param1"] + """ + ) + assert "_parser = argparse.ArgumentParser()" in code + assert '_parser.add_argument("--param1")' in code + assert '_parser.add_argument("--param2")' in code + # JOB_NAME should NOT become an argparse argument + assert "JOB_NAME" not in code + assert "args = vars(_parser.parse_args())" in code + assert not warnings + + +def test_args_dynamic_list_warns(): + _, warnings = _transform( + """\ + import sys + from awsglue.utils import getResolvedOptions + + param_list = ["param1", "param2"] + args = getResolvedOptions(sys.argv, param_list) + """ + ) + assert any("non-literal" in w for w in warnings) + + +def test_args_dbutils_style(): + code, warnings = _transform( + """\ + import sys + from awsglue.utils import getResolvedOptions + + args = getResolvedOptions(sys.argv, ["JOB_NAME", "source_db", "output_path"]) + + source = args["source_db"] + """, + args_style="dbutils", + ) + assert "dbutils.widgets.text('source_db', \"\")" in code + assert "dbutils.widgets.text('output_path', \"\")" in code + assert "dbutils.widgets.get" in code + assert "argparse" not in code + assert "JOB_NAME" not in code + assert not warnings + + +# ────────────────────────────────────────────────────────────────────────────── +# Unsupported transforms +# ────────────────────────────────────────────────────────────────────────────── + + +def test_unsupported_transform_warns(): + _, warnings = _transform( + """\ + from awsglue.context import GlueContext + from awsglue.transforms import * + from pyspark.context import SparkContext + + sc = SparkContext() + glueContext = GlueContext(sc) + spark = glueContext.spark_session + + df = glueContext.create_dynamic_frame.from_catalog(database="db", table_name="t") + out = ResolveChoice.apply(frame=df, choice="make_cols") + """ + ) + assert any("ResolveChoice" in w for w in warnings) + + +# ────────────────────────────────────────────────────────────────────────────── +# Comment & whitespace preservation +# ────────────────────────────────────────────────────────────────────────────── + + +def test_comments_preserved(): + code, _ = _transform( + """\ + # This is a top-level comment + import sys # inline comment + + # Section comment + x = 1 + """ + ) + assert "# This is a top-level comment" in code + assert "# inline comment" in code + assert "# Section comment" in code + + +def test_blank_lines_preserved(): + source = textwrap.dedent( + """\ + import sys + + x = 1 + + y = 2 + """ + ) + code, _ = _transform(source) + assert "\n\n" in code + + +# ────────────────────────────────────────────────────────────────────────────── +# Error handling +# ────────────────────────────────────────────────────────────────────────────── + + +def test_syntax_error_raises(): + t = GlueTransformer(Path("bad.py")) + with pytest.raises(Exception): + t.transform("def broken(:\n pass") + + +# ────────────────────────────────────────────────────────────────────────────── +# _map_glue_type +# ────────────────────────────────────────────────────────────────────────────── + + +@pytest.mark.parametrize( + "glue_type,expected", + [ + ("long", "bigint"), + ("short", "smallint"), + ("byte", "tinyint"), + ("bool", "boolean"), + ("char", "string"), + ("varchar", "string"), + ("int", "int"), + ("integer", "int"), + ("double", "double"), + ("float", "float"), + ("string", "string"), + ("binary", "binary"), + ("date", "date"), + ("timestamp", "timestamp"), + # decimal with precision/scale preserved as-is + ("decimal(10,2)", "decimal(10,2)"), + ("DECIMAL(18, 4)", "decimal(18, 4)"), + # unknown type passed through + ("custom_type", "custom_type"), + ], +) +def test_map_glue_type(glue_type: str, expected: str): + assert _map_glue_type(glue_type) == expected