From 50ef68d0995d009557713208e50e07de122b7e3f Mon Sep 17 00:00:00 2001 From: Bob Kline Date: Mon, 15 Dec 2025 08:59:38 -0500 Subject: [PATCH] Avoid heap corruption with NULLs in TVPs (#1450) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `GetParamType()` function calls `SQLDescribeParam()` and caches the type information it gets from the driver in a dynamically allocated array with one slot for each top-level parameter. This should not be done for nested table column "parameters" because: - that corrupts the heap when a TVP has more columns than slots; - it overwrites top-level parameter types with the wrong values; - `SQLDescribeParam()` fails for nested column "parameters" anyway. So for NULL TVP columns we fall back on SQL_VARCHAR. With the fix contained in this commit, the new test never fails (for my testing, at least). Without the fix, that test sometimes passes and sometimes crashes the Python interpreter (and *always* crashed Python when the `tests/__pycache__` directory was not present in my testing--not sure why 🤷‍♂️). Closes #1450 --- src/params.cpp | 12 ++++++++--- tests/sqlserver_test.py | 48 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/src/params.cpp b/src/params.cpp index 089202a8..78029701 100644 --- a/src/params.cpp +++ b/src/params.cpp @@ -534,9 +534,15 @@ static void FreeInfos(ParamInfo* a, Py_ssize_t count) PyMem_Free(a); } -static bool GetNullInfo(Cursor* cur, Py_ssize_t index, ParamInfo& info) +static bool GetNullInfo(Cursor* cur, Py_ssize_t index, ParamInfo& info, bool isTVP) { - if (!GetParamType(cur, index, info.ParameterType)) + // GetParamType won't work for TVP columns, so we fall back on SQL_VARCHAR. + if (isTVP) + { + if (info.ParameterType == SQL_UNKNOWN_TYPE) + info.ParameterType = SQL_VARCHAR; + } + else if (!GetParamType(cur, index, info.ParameterType)) return false; info.ValueType = SQL_C_DEFAULT; @@ -1023,7 +1029,7 @@ bool GetParameterInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& // Populates `info`. if (param == Py_None) - return GetNullInfo(cur, index, info); + return GetNullInfo(cur, index, info, isTVP); if (param == null_binary) return GetNullBinaryInfo(cur, index, info); diff --git a/tests/sqlserver_test.py b/tests/sqlserver_test.py index 9d13ddd9..856a3583 100755 --- a/tests/sqlserver_test.py +++ b/tests/sqlserver_test.py @@ -1,5 +1,6 @@ #!/usr/bin/python +import gc import os import re import uuid @@ -1617,6 +1618,53 @@ def test_tvp_diffschema(cursor: pyodbc.Cursor): _test_tvp(cursor, True) +def _test_tvp_with_nulls_cleanup(cursor: pyodbc.Cursor, procname: str, typename: str): + """Leave the forest as pristine as you found it.""" + + cursor.execute(f"""\ + IF OBJECT_ID(N'dbo.{procname}', N'P') IS NOT NULL + DROP PROCEDURE dbo.{procname}; + """) + cursor.execute(f""" + IF TYPE_ID(N'dbo.{typename}') IS NOT NULL + DROP TYPE dbo.{typename}; + """) + + +@pytest.mark.skipif(SQLSERVER_YEAR < 2008, reason="TVP not supported until 2008") +@pytest.mark.skipif(IS_FREEDTS, reason="FreeTDS does not support TVP") +def test_tvp_with_nulls(cursor: pyodbc.Cursor): + """Make sure NULL values in a TVP don't crash the interpreter.""" + + # Start with a clean slate. + typename = "typeTestNullsInTVP" + procname = "spTestNullsInTVP" + _test_tvp_with_nulls_cleanup(cursor, procname, typename) + + # Create the custom type and stored procedure. + ncols = 100 + cols = ", ".join([f"col_{c:03d} DECIMAL(36,20)" for c in range(1, ncols+1)]) + cursor.execute(f"CREATE TYPE dbo.{typename} AS TABLE ({cols})") + cursor.execute(f"""\ + CREATE PROCEDURE dbo.{procname} + @data dbo.{typename} READONLY + AS + BEGIN + RETURN 0; + END; + """) + cursor.commit() + + # Invoke the stored procedure. + tvp: list[list] = [[3.14159] * ncols, [None] * ncols] + cursor.execute(f"EXEC [dbo].{procname} @data=?", [tvp]) + gc.collect() + + # Be a good digital citizen. + _test_tvp_with_nulls_cleanup(cursor, procname, typename) + cursor.commit() + + @pytest.mark.skipif(SQLSERVER_YEAR < 2000, reason='sql_variant not supported until 2000') def test_sql_variant(cursor: pyodbc.Cursor): """