Skip to content

Commit 15a32a8

Browse files
fix nextset() iteration for executemany with returning=True
1 parent b7c8502 commit 15a32a8

1 file changed

Lines changed: 126 additions & 0 deletions

File tree

drift/instrumentation/psycopg/instrumentation.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,6 +1914,80 @@ def fetchall():
19141914
except AttributeError:
19151915
pass
19161916

1917+
# Set up initial fetch methods for the first result set (for code that uses nextset() instead of results())
1918+
first_column_names = first_set.get("column_names")
1919+
FirstRowClass = create_row_class(first_column_names)
1920+
1921+
def make_fetchone_replay(cn, RC):
1922+
def fetchone():
1923+
if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue]
1924+
row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue]
1925+
cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
1926+
return transform_row(row, cn, RC)
1927+
return None
1928+
return fetchone
1929+
1930+
def make_fetchmany_replay(cn, RC):
1931+
def fetchmany(size=cursor.arraysize):
1932+
rows = []
1933+
for _ in range(size):
1934+
if cursor._mock_index < len(cursor._mock_rows): # pyright: ignore[reportAttributeAccessIssue]
1935+
row = cursor._mock_rows[cursor._mock_index] # pyright: ignore[reportAttributeAccessIssue]
1936+
cursor._mock_index += 1 # pyright: ignore[reportAttributeAccessIssue]
1937+
rows.append(transform_row(row, cn, RC))
1938+
else:
1939+
break
1940+
return rows
1941+
return fetchmany
1942+
1943+
def make_fetchall_replay(cn, RC):
1944+
def fetchall():
1945+
rows = cursor._mock_rows[cursor._mock_index:] # pyright: ignore[reportAttributeAccessIssue]
1946+
cursor._mock_index = len(cursor._mock_rows) # pyright: ignore[reportAttributeAccessIssue]
1947+
return [transform_row(row, cn, RC) for row in rows]
1948+
return fetchall
1949+
1950+
cursor.fetchone = make_fetchone_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue]
1951+
cursor.fetchmany = make_fetchmany_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue]
1952+
cursor.fetchall = make_fetchall_replay(first_column_names, FirstRowClass) # pyright: ignore[reportAttributeAccessIssue]
1953+
1954+
# Patch nextset() to work with _mock_result_sets
1955+
def patched_nextset():
1956+
"""Move to the next result set in _mock_result_sets."""
1957+
next_index = cursor._mock_result_set_index + 1 # pyright: ignore[reportAttributeAccessIssue]
1958+
if next_index < len(cursor._mock_result_sets): # pyright: ignore[reportAttributeAccessIssue]
1959+
cursor._mock_result_set_index = next_index # pyright: ignore[reportAttributeAccessIssue]
1960+
next_set = cursor._mock_result_sets[next_index] # pyright: ignore[reportAttributeAccessIssue]
1961+
cursor._mock_rows = next_set["rows"] # pyright: ignore[reportAttributeAccessIssue]
1962+
cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue]
1963+
1964+
# Update fetch methods for the new result set
1965+
next_column_names = next_set.get("column_names")
1966+
NextRowClass = create_row_class(next_column_names)
1967+
cursor.fetchone = make_fetchone_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue]
1968+
cursor.fetchmany = make_fetchmany_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue]
1969+
cursor.fetchall = make_fetchall_replay(next_column_names, NextRowClass) # pyright: ignore[reportAttributeAccessIssue]
1970+
1971+
# Update description for next result set
1972+
next_description_data = next_set.get("description")
1973+
if next_description_data:
1974+
next_desc = [
1975+
(col["name"], col.get("type_code"), None, None, None, None, None)
1976+
for col in next_description_data
1977+
]
1978+
try:
1979+
cursor._tusk_description = next_desc # pyright: ignore[reportAttributeAccessIssue]
1980+
except AttributeError:
1981+
try:
1982+
cursor.description = next_desc # pyright: ignore[reportAttributeAccessIssue]
1983+
except AttributeError:
1984+
pass
1985+
1986+
return True
1987+
return None
1988+
1989+
cursor.nextset = patched_nextset # pyright: ignore[reportAttributeAccessIssue]
1990+
19171991
def _finalize_query_span(
19181992
self,
19191993
span: trace.Span,
@@ -2216,6 +2290,58 @@ def patched_fetchall():
22162290

22172291
cursor.results = patched_results # pyright: ignore[reportAttributeAccessIssue]
22182292

2293+
# Set up the first result set immediately for user code that uses nextset() instead of results()
2294+
if all_rows_collected:
2295+
cursor._tusk_rows = all_rows_collected[0] # pyright: ignore[reportAttributeAccessIssue]
2296+
cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue]
2297+
cursor._tusk_result_set_index = 0 # pyright: ignore[reportAttributeAccessIssue]
2298+
2299+
# Create initial fetch methods for the first result set
2300+
def make_patched_fetchone_record():
2301+
def patched_fetchone():
2302+
if cursor._tusk_index < len(cursor._tusk_rows): # pyright: ignore[reportAttributeAccessIssue]
2303+
row = cursor._tusk_rows[cursor._tusk_index] # pyright: ignore[reportAttributeAccessIssue]
2304+
cursor._tusk_index += 1 # pyright: ignore[reportAttributeAccessIssue]
2305+
return row
2306+
return None
2307+
return patched_fetchone
2308+
2309+
def make_patched_fetchmany_record():
2310+
def patched_fetchmany(size=cursor.arraysize):
2311+
result = cursor._tusk_rows[cursor._tusk_index : cursor._tusk_index + size] # pyright: ignore[reportAttributeAccessIssue]
2312+
cursor._tusk_index += len(result) # pyright: ignore[reportAttributeAccessIssue]
2313+
return result
2314+
return patched_fetchmany
2315+
2316+
def make_patched_fetchall_record():
2317+
def patched_fetchall():
2318+
result = cursor._tusk_rows[cursor._tusk_index :] # pyright: ignore[reportAttributeAccessIssue]
2319+
cursor._tusk_index = len(cursor._tusk_rows) # pyright: ignore[reportAttributeAccessIssue]
2320+
return result
2321+
return patched_fetchall
2322+
2323+
cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue]
2324+
cursor.fetchmany = make_patched_fetchmany_record() # pyright: ignore[reportAttributeAccessIssue]
2325+
cursor.fetchall = make_patched_fetchall_record() # pyright: ignore[reportAttributeAccessIssue]
2326+
2327+
# Patch nextset() to work with _tusk_result_sets
2328+
def patched_nextset():
2329+
"""Move to the next result set in _tusk_result_sets."""
2330+
next_index = cursor._tusk_result_set_index + 1 # pyright: ignore[reportAttributeAccessIssue]
2331+
if next_index < len(cursor._tusk_result_sets): # pyright: ignore[reportAttributeAccessIssue]
2332+
cursor._tusk_result_set_index = next_index # pyright: ignore[reportAttributeAccessIssue]
2333+
cursor._tusk_rows = cursor._tusk_result_sets[next_index] # pyright: ignore[reportAttributeAccessIssue]
2334+
cursor._tusk_index = 0 # pyright: ignore[reportAttributeAccessIssue]
2335+
2336+
# Update fetch methods for the new result set
2337+
cursor.fetchone = make_patched_fetchone_record() # pyright: ignore[reportAttributeAccessIssue]
2338+
cursor.fetchmany = make_patched_fetchmany_record() # pyright: ignore[reportAttributeAccessIssue]
2339+
cursor.fetchall = make_patched_fetchall_record() # pyright: ignore[reportAttributeAccessIssue]
2340+
return True
2341+
return None
2342+
2343+
cursor.nextset = patched_nextset # pyright: ignore[reportAttributeAccessIssue]
2344+
22192345
else:
22202346
output_value = {"rowcount": cursor.rowcount if hasattr(cursor, "rowcount") else -1}
22212347

0 commit comments

Comments
 (0)