diff --git a/CHANGELOG.md b/CHANGELOG.md index 595fc076ec..2780f7f653 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ #### Bug Fixes - Fixed a bug where stage paths and file format names that contain single quotes were not consistently escaped when generating SQL, which could produce malformed statements. This affects `INFER_SCHEMA` (used by `DataFrameReader.csv`/`json`/`parquet`/`orc`/`avro`) and `COPY FILES` (used by `FileOperation.copy_files`). +- Fixed a bug where single quotes and backslashes in stage/file paths were not correctly escaped when generating `COPY INTO` / `PUT` / `GET` SQL, which could produce malformed statements. This affects `DataFrame.write.csv`/`copy_into_location` and the Snowpark-pandas `DataFrame.to_csv` stage path. - Fixed a bug where column names containing quote characters returned by an external database were not correctly escaped when generating the `SELECT` query for `DataFrameReader.dbapi`, which could produce malformed SQL. Embedded quote characters in identifiers are now doubled (backticks for Databricks/MySQL, double quotes for Oracle/PostgreSQL/SQL Server). - Fixed a bug where the destination passed to `DataFrameWriter.copy_into_location` (and `csv`/`json`/`parquet`/`save`) was embedded into the generated `COPY INTO` statement without quoting, which could produce malformed SQL for locations containing single quotes. The location is now consistently quoted and escaped, and a string that merely starts and ends with a single quote but contains unescaped interior quotes is no longer treated as an already-quoted literal; it is fully escaped so it stays a single SQL string literal. - Fixed a bug where UDF default argument values reconstructed from a source file in `register_from_file` were evaluated with `eval()`; they are now evaluated only against the documented set of supported default-value types, and unsupported expressions are ignored. diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index d04ee6a926..dc2fd7ecd7 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -481,7 +481,17 @@ def normalize_path(path: str, is_local: bool) -> str: return path if is_local and OPERATING_SYSTEM == "Windows": path = path.replace("\\", "/") - path = path.strip().replace("'", "\\'") + # Escape the backslash before the single quote so the path stays a single + # Snowflake string literal; the reverse order would let an escaped quote + # close the literal early and produce invalid SQL. Constants keep the + # replacements readable (no Python escape double-counting). + BACKSLASH = "\\" + SINGLE_QUOTE = "'" + path = ( + path.strip() + .replace(BACKSLASH, BACKSLASH * 2) # \ -> \\ + .replace(SINGLE_QUOTE, BACKSLASH + SINGLE_QUOTE) # ' -> \' + ) if not any(path.startswith(prefix) for prefix in prefixes): path = f"{prefixes[0]}{path}" return f"'{path}'" diff --git a/tests/integ/modin/io/test_to_csv.py b/tests/integ/modin/io/test_to_csv.py index e2027445cb..34c2befc7c 100644 --- a/tests/integ/modin/io/test_to_csv.py +++ b/tests/integ/modin/io/test_to_csv.py @@ -13,7 +13,7 @@ from numpy.testing import assert_equal import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.utils.sql_counter import sql_count_checker +from tests.integ.utils.sql_counter import sql_count_checker, SqlCounter from tests.utils import Utils temp_dir = tempfile.TemporaryDirectory() @@ -293,3 +293,57 @@ def test_timedeltaindex_to_csv_dataframe_local(): pd.DataFrame(native_df).to_csv(snow_path) assert_file_equal(snow_path, native_path, is_compressed=False) + + +def test_to_csv_stage_path_escapes_special_characters(sf_stage, session): + """Snowpark-pandas ``to_csv`` to a stage path must escape special characters + in the path. + + ``DataFrame.to_csv(path_or_buf="@stage/...")`` routes server-side into + ``snowpark_df.write.csv(location=...)`` -> ``normalize_path`` -> + ``COPY INTO ``. A path containing a backslash immediately followed + by a single quote must stay inside the stage-location string literal so the + generated ``COPY INTO`` is valid and the path is treated as literal data. + """ + snow_df = pd.DataFrame({"A": ["one", "two", "three"], "B": [1, 2, 3]}) + # None index name is not supported when writing to a Snowflake stage. + snow_df.index.set_names(["X"], inplace=True) + + # (a) Stage path whose directory name contains a single quote. The quote is + # escaped as literal data, so the write succeeds and the file lands under + # that exact name. ``to_csv`` to a stage emits one query (the COPY INTO); + # downloading it back confirms the path was treated as a literal file + # name and not parsed as SQL. + quote_name = "o'clock/mods.csv" + quote_path = f"@{sf_stage}/{quote_name}" + with SqlCounter(query_count=1): + snow_df.to_csv(quote_path, index=False) + listed = [row[0] for row in session.sql(f"LIST '@{sf_stage}'").collect()] + assert any(name.endswith(quote_name) for name in listed), listed + + download_dir = tempfile.mkdtemp() + session.file.get(quote_path, download_dir) + downloaded = [ + f + for f in os.listdir(download_dir) + if os.path.isfile(os.path.join(download_dir, f)) + ] + assert len(downloaded) == 1, downloaded + with open(os.path.join(download_dir, downloaded[0])) as fh: + content = fh.read() + data_rows = [line for line in content.splitlines() if line.strip()] + # Header ("A,B") + the DataFrame's own 3 data rows == 4 lines. + assert len(data_rows) == 4, content + assert content == "A,B\none,1\ntwo,2\nthree,3\n", content + + # (b) A file name mixing a backslash, a single quote, parentheses, a comma + # and a trailing ``--`` must produce valid SQL: before the fix the + # unescaped backslash/quote closed the location literal early. The write + # must succeed with a single COPY INTO query -- if the path were parsed as + # SQL the statement would error instead. Stage storage does not preserve a + # literal backslash as a path character, so we assert the write succeeds + # rather than reading back the exact name. + special_name = "report\\' , (note) -- draft" + special_path = f"@{sf_stage}/{special_name}" + with SqlCounter(query_count=1): + snow_df.to_csv(special_path, index=False) diff --git a/tests/integ/scala/test_dataframe_writer_suite.py b/tests/integ/scala/test_dataframe_writer_suite.py index 835a4f3589..190af4aaa2 100644 --- a/tests/integ/scala/test_dataframe_writer_suite.py +++ b/tests/integ/scala/test_dataframe_writer_suite.py @@ -989,6 +989,45 @@ def test_writer_csv(session, temp_stage, caplog): Utils.check_answer(data7, df) +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="COPY INTO is not supported in Local Testing", +) +def test_writer_csv_stage_path_escapes_special_characters(session, temp_stage): + """``DataFrame.write.csv`` routes the destination through ``normalize_path``, + which must escape both backslashes and single quotes so that a path + containing a backslash immediately followed by a single quote stays inside + the stage-location string literal in the generated ``COPY INTO`` and the + SQL is always valid. + + Each write below uses a path with characters that, before the fix, would + close the location string literal early and produce invalid SQL (a + backslash, a single quote, a ``\\'`` combination, parentheses, a comma and a + trailing ``--``). The writes must now succeed with the DataFrame's own rows + unloaded, which proves the path is escaped as literal data and not parsed as + SQL. Note: a literal backslash is not preserved as a directory separator by + stage storage, so we assert the write succeeds rather than a read-back + round-trip. + """ + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + + special_paths = [ + # Directory name containing a backslash. + f"{temp_stage}/back\\slash_dir/data.csv", + # Directory name containing a single quote. + f"{temp_stage}/o'clock/data.csv", + # Directory name containing a backslash immediately followed by a quote. + f"{temp_stage}/mix\\'both/data.csv", + # File name mixing a backslash-quote, parentheses, a comma and a + # trailing ``--`` -- all must be treated as literal path characters. + f"@{temp_stage}/out\\' , (note) -- draft", + ] + for path in special_paths: + result = df.write.csv(path, single=True) + # The DataFrame's own rows are unloaded; the path is not parsed as SQL. + assert result[0].rows_unloaded == 2, path + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="BUG: SNOW-1235716 should raise not implemented error not AttributeError: 'MockExecutionPlan' object has no attribute 'replace_repeated_subquery_with_cte', FEAT: parquet support", diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py index 02ca8412e0..d226039a7a 100644 --- a/tests/unit/scala/test_utils_suite.py +++ b/tests/unit/scala/test_utils_suite.py @@ -167,7 +167,9 @@ def test_normalize_file(is_local): assert normalize_path(name2, is_local) == f"'{symbol}sta\\'ge'" name3 = "s ta\\'ge " assert normalize_path(name3, is_local) == ( - f"'{symbol}s ta/\\'ge'" if is_local and IS_WINDOWS else f"'{symbol}s ta\\\\'ge'" + f"'{symbol}s ta/\\'ge'" + if is_local and IS_WINDOWS + else f"'{symbol}s ta\\\\\\'ge'" ) diff --git a/tests/unit/test_internal_utils.py b/tests/unit/test_internal_utils.py index fcad4a21c4..f09cfc63d3 100644 --- a/tests/unit/test_internal_utils.py +++ b/tests/unit/test_internal_utils.py @@ -76,6 +76,78 @@ def test_normalize_path(path: str, is_local: bool, expected: str) -> None: assert expected == actual +def _decode_snowflake_literal(literal: str) -> str: + """Simulate Snowflake's decoding of a single-quoted string literal. + + Snowflake treats ``\\`` as an escape character inside a single-quoted literal, + so ``\\\\`` decodes to one backslash and ``\\'`` decodes to one single quote. + An unescaped single quote closes the literal. This helper returns the decoded + literal value and raises if the literal is closed early -- which would mean the + path was not escaped correctly and the generated SQL is invalid. + """ + assert literal.startswith("'") and literal.endswith( + "'" + ), f"not a quoted literal: {literal!r}" + body = literal[1:-1] + out = [] + i = 0 + while i < len(body): + ch = body[i] + if ch == "\\" and i + 1 < len(body): + out.append(body[i + 1]) + i += 2 + elif ch == "'": + raise AssertionError( + f"unescaped quote closes literal early at index {i}: {literal!r}" + ) + else: + out.append(ch) + i += 1 + return "".join(out) + + +@pytest.mark.parametrize("is_local", [True, False]) +@pytest.mark.parametrize( + "raw_path", + [ + # Paths containing a backslash immediately followed by a single quote, + # plus parentheses, commas and a trailing ``--``. Before the fix the + # backslash was not escaped, so ``\'`` was written as ``\\'`` and closed + # the literal early, producing invalid SQL. + "@~/out\\' , (note) FILE_FORMAT=(TYPE=CSV) -- draft", + "report\\' , (v2) draft --", + # Plain special characters that must round-trip as literal data. + "@stage/o'clock/file.csv", + "@stage/back\\slash/file.csv", + "@stage/double\\\\back/file.csv", + '@stage/dquote"/file.csv', + "@stage/uniƩcode/file.csv", + "@stage/all\\'\"mix/file.csv", + ], +) +def test_normalize_path_escapes_backslash_and_quote(raw_path, is_local): + """``normalize_path`` must produce a Snowflake string literal that decodes back + to the original path. A backslash followed by a single quote must stay inside + the literal and not close it early, so the generated SQL is always valid and + the path is treated as literal data.""" + literal = utils.normalize_path(raw_path, is_local) + # The output must be a well-formed single-quoted literal: decoding it must not + # raise (i.e. the literal is not closed early). + decoded = _decode_snowflake_literal(literal) + # The decoded literal must end with the (stripped) raw path -- the prefix may + # differ only by an added ``@`` / ``file://`` scheme prefix. + expected_tail = raw_path.strip() + # Local paths on Windows are normalized (backslashes -> forward slashes) + # before escaping, so mirror that transform here. This only affects the + # round-trip comparison; the escaping guarantee checked above (the literal + # never closes early) still holds for every input on every platform. + if is_local and utils.OPERATING_SYSTEM == "Windows": + expected_tail = expected_tail.replace("\\", "/") + assert decoded.endswith( + expected_tail + ), f"decoded={decoded!r} does not end with {expected_tail!r}" + + def test__pandas_importer(): imported_pandas = _pandas_importer() try: