Skip to content

Commit 878ddd2

Browse files
committed
only support this in fb_numeric
1 parent eeb4702 commit 878ddd2

7 files changed

Lines changed: 168 additions & 353 deletions

File tree

src/firebolt/async_db/cursor.py

Lines changed: 6 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
TimeoutException,
1717
codes,
1818
)
19-
from sqlparse import parse as parse_sql # type: ignore
2019

2120
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
2221
from firebolt.common._types import ColType, ParameterType, SetParameter
@@ -47,6 +46,7 @@
4746
from firebolt.common.row_set.asynchronous.streaming import StreamingAsyncRowSet
4847
from firebolt.common.statement_formatter import create_statement_formatter
4948
from firebolt.utils.exception import (
49+
ConfigurationError,
5050
EngineNotRunningError,
5151
FireboltDatabaseError,
5252
FireboltError,
@@ -219,6 +219,7 @@ async def _do_execute(
219219
timeout: Optional[float] = None,
220220
async_execution: bool = False,
221221
streaming: bool = False,
222+
bulk_insert: bool = False,
222223
) -> None:
223224
await self._close_rowset_and_reset()
224225
self._row_set = StreamingAsyncRowSet() if streaming else InMemoryAsyncRowSet()
@@ -231,7 +232,7 @@ async def _do_execute(
231232
)
232233

233234
plan = statement_planner.create_execution_plan(
234-
raw_query, parameters, skip_parsing, async_execution, streaming
235+
raw_query, parameters, skip_parsing, async_execution, streaming, bulk_insert
235236
)
236237
await self._execute_plan(plan, timeout)
237238
self._state = CursorState.DONE
@@ -422,105 +423,9 @@ async def executemany(
422423
Returns:
423424
int: Query row count.
424425
"""
425-
if bulk_insert:
426-
return await self._executemany_bulk_insert(
427-
query, parameters_seq, timeout_seconds
428-
)
429-
await self._do_execute(query, parameters_seq, timeout=timeout_seconds)
430-
return self.rowcount
431-
432-
def _validate_bulk_insert_query(self, query: str) -> None:
433-
"""Validate that query is an INSERT statement for bulk_insert."""
434-
query_normalized = query.lstrip().lower()
435-
436-
if not query_normalized.startswith("insert"):
437-
raise ProgrammingError(
438-
"bulk_insert is only supported for INSERT statements"
439-
)
440-
441-
if ";" in query.strip().rstrip(";"):
442-
raise ProgrammingError(
443-
"bulk_insert does not support multi-statement queries"
444-
)
445-
446-
async def _executemany_bulk_insert(
447-
self,
448-
query: str,
449-
parameters_seq: Sequence[Sequence[ParameterType]],
450-
timeout_seconds: Optional[float],
451-
) -> int:
452-
"""Execute multiple INSERT queries as a single batch."""
453-
self._validate_bulk_insert_query(query)
454-
455-
if not parameters_seq:
456-
raise ProgrammingError("bulk_insert requires at least one parameter set")
457-
458-
from firebolt.async_db import paramstyle
459-
460-
try:
461-
parameter_style = ParameterStyle(paramstyle)
462-
except ValueError:
463-
raise ProgrammingError(f"Unsupported paramstyle: {paramstyle}")
464-
465-
concatenated_query = "; ".join([query] * len(parameters_seq))
466-
467-
await self._close_rowset_and_reset()
468-
self._row_set = InMemoryAsyncRowSet()
469-
470-
try:
471-
if parameter_style == ParameterStyle.FB_NUMERIC:
472-
flattened_params: List[ParameterType] = []
473-
for param_set in parameters_seq:
474-
flattened_params.extend(param_set)
475-
476-
Cursor._log_query(concatenated_query)
477-
timeout_controller = TimeoutController(timeout_seconds)
478-
timeout_controller.raise_if_timeout()
479-
480-
query_params = self._build_fb_numeric_query_params(
481-
[flattened_params],
482-
streaming=False,
483-
async_execution=False,
484-
extra_params={"merge_prepared_statement_batches": "true"},
485-
)
486-
487-
resp = await self._api_request(
488-
concatenated_query,
489-
query_params,
490-
timeout=timeout_controller.remaining(),
491-
)
492-
await self._raise_if_error(resp)
493-
await self._parse_response_headers(resp.headers)
494-
await self._append_row_set_from_response(resp)
495-
else:
496-
formatted_queries = []
497-
statements = parse_sql(query)
498-
for param_set in parameters_seq:
499-
formatted_query = self._formatter.format_statement(
500-
statements[0], param_set
501-
)
502-
formatted_queries.append(formatted_query)
503-
504-
concatenated_query = "; ".join(formatted_queries)
505-
506-
query_params = {
507-
"output_format": self._get_output_format(False),
508-
"merge_prepared_statement_batches": "true",
509-
}
510-
511-
Cursor._log_query(concatenated_query)
512-
resp = await self._api_request(
513-
concatenated_query, query_params, timeout=timeout_seconds
514-
)
515-
await self._raise_if_error(resp)
516-
await self._parse_response_headers(resp.headers)
517-
await self._append_row_set_from_response(resp)
518-
519-
self._state = CursorState.DONE
520-
except Exception:
521-
self._state = CursorState.ERROR
522-
raise
523-
426+
await self._do_execute(
427+
query, parameters_seq, timeout=timeout_seconds, bulk_insert=bulk_insert
428+
)
524429
return self.rowcount
525430

526431
@check_not_closed

src/firebolt/common/cursor/statement_planners.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
JSON_LINES_OUTPUT_FORMAT,
1313
JSON_OUTPUT_FORMAT,
1414
)
15-
from firebolt.utils.exception import FireboltError, ProgrammingError
15+
from firebolt.utils.exception import ConfigurationError, FireboltError, ProgrammingError
1616

1717
if TYPE_CHECKING:
1818
from firebolt.common.statement_formatter import StatementFormatter
@@ -44,6 +44,7 @@ def create_execution_plan(
4444
skip_parsing: bool = False,
4545
async_execution: bool = False,
4646
streaming: bool = False,
47+
bulk_insert: bool = False,
4748
) -> ExecutionPlan:
4849
"""Create an execution plan for the given statement and parameters."""
4950

@@ -65,13 +66,32 @@ def create_execution_plan(
6566
skip_parsing: bool = False,
6667
async_execution: bool = False,
6768
streaming: bool = False,
69+
bulk_insert: bool = False,
6870
) -> ExecutionPlan:
6971
"""Create execution plan for fb_numeric parameter style."""
70-
query_params = self._build_fb_numeric_query_params(
71-
parameters, streaming, async_execution
72-
)
72+
if bulk_insert:
73+
# Validate bulk_insert requirements
74+
query_normalized = raw_query.lstrip().lower()
75+
if not query_normalized.startswith("insert"):
76+
raise ConfigurationError("bulk_insert is only supported for INSERT statements")
77+
if ";" in raw_query.strip().rstrip(";"):
78+
raise ConfigurationError("bulk_insert does not support multi-statement queries")
79+
if not parameters:
80+
raise ConfigurationError("bulk_insert requires at least one parameter set")
81+
82+
# Prepare bulk insert query and parameters
83+
processed_query, processed_params = self._prepare_bulk_insert(raw_query, parameters)
84+
query_params = self._build_fb_numeric_query_params(
85+
processed_params, streaming, async_execution, {"merge_prepared_statement_batches": "true"}
86+
)
87+
else:
88+
processed_query = raw_query
89+
query_params = self._build_fb_numeric_query_params(
90+
parameters, streaming, async_execution
91+
)
92+
7393
return ExecutionPlan(
74-
queries=[raw_query],
94+
queries=[processed_query],
7595
query_params=query_params,
7696
is_multi_statement=False,
7797
async_execution=async_execution,
@@ -83,6 +103,7 @@ def _build_fb_numeric_query_params(
83103
parameters: Sequence[Sequence[ParameterType]],
84104
streaming: bool,
85105
async_execution: bool,
106+
extra_params: Optional[Dict[str, Any]] = None,
86107
) -> Dict[str, Any]:
87108
"""Build query parameters for fb_numeric style."""
88109
param_list = parameters[0] if parameters else []
@@ -101,8 +122,35 @@ def _build_fb_numeric_query_params(
101122
query_params["query_parameters"] = json.dumps(query_parameters)
102123
if async_execution:
103124
query_params["async"] = True
125+
if extra_params:
126+
query_params.update(extra_params)
104127
return query_params
105128

129+
def _prepare_bulk_insert(
130+
self, query: str, parameters_seq: Sequence[Sequence[ParameterType]]
131+
) -> tuple[str, Sequence[Sequence[ParameterType]]]:
132+
"""Execute multiple INSERT queries as a single batch."""
133+
if not parameters_seq:
134+
raise ProgrammingError("bulk_insert requires at least one parameter set")
135+
136+
# For bulk insert, we need to create unique parameter names for each INSERT
137+
# Example: ($1, $2); ($3, $4); ($5, $6) instead of ($1, $2); ($1, $2); ($1, $2)
138+
queries = []
139+
param_offset = 0
140+
for param_set in parameters_seq:
141+
# Replace parameter placeholders with unique numbers
142+
modified_query = query
143+
for i in range(len(param_set)):
144+
old_param = f"${i + 1}"
145+
new_param = f"${param_offset + i + 1}"
146+
modified_query = modified_query.replace(old_param, new_param)
147+
queries.append(modified_query)
148+
param_offset += len(param_set)
149+
150+
combined_query = "; ".join(queries)
151+
parameters = [param for param_set in parameters_seq for param in param_set]
152+
return combined_query, [parameters]
153+
106154

107155
class QmarkStatementPlanner(BaseStatementPlanner):
108156
"""Statement planner for qmark parameter style."""
@@ -114,8 +162,13 @@ def create_execution_plan(
114162
skip_parsing: bool = False,
115163
async_execution: bool = False,
116164
streaming: bool = False,
165+
bulk_insert: bool = False,
117166
) -> ExecutionPlan:
118167
"""Create execution plan for qmark parameter style."""
168+
# Validate bulk_insert is not used with qmark
169+
if bulk_insert:
170+
raise ConfigurationError("bulk_insert is only supported for fb_numeric")
171+
119172
queries: List[Union[SetParameter, str]] = (
120173
[raw_query]
121174
if skip_parsing

src/firebolt/db/cursor.py

Lines changed: 6 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
TimeoutException,
2525
codes,
2626
)
27-
from sqlparse import parse as parse_sql # type: ignore
2827

2928
from firebolt.client import Client, ClientV1, ClientV2
3029
from firebolt.common._types import ColType, ParameterType, SetParameter
@@ -56,6 +55,7 @@
5655
from firebolt.common.statement_formatter import create_statement_formatter
5756
from firebolt.utils.cache import ConnectionInfo, DatabaseInfo, EngineInfo
5857
from firebolt.utils.exception import (
58+
ConfigurationError,
5959
EngineNotRunningError,
6060
FireboltDatabaseError,
6161
FireboltError,
@@ -225,6 +225,7 @@ def _do_execute(
225225
timeout: Optional[float] = None,
226226
async_execution: bool = False,
227227
streaming: bool = False,
228+
bulk_insert: bool = False,
228229
) -> None:
229230
self._close_rowset_and_reset()
230231
self._row_set = StreamingRowSet() if streaming else InMemoryRowSet()
@@ -237,7 +238,7 @@ def _do_execute(
237238
)
238239

239240
plan = statement_planner.create_execution_plan(
240-
raw_query, parameters, skip_parsing, async_execution, streaming
241+
raw_query, parameters, skip_parsing, async_execution, streaming, bulk_insert
241242
)
242243
self._execute_plan(plan, timeout)
243244
self._state = CursorState.DONE
@@ -424,103 +425,9 @@ def executemany(
424425
Returns:
425426
int: Query row count.
426427
"""
427-
if bulk_insert:
428-
return self._executemany_bulk_insert(query, parameters_seq, timeout_seconds)
429-
self._do_execute(query, parameters_seq, timeout=timeout_seconds)
430-
return self.rowcount
431-
432-
def _validate_bulk_insert_query(self, query: str) -> None:
433-
"""Validate that query is an INSERT statement for bulk_insert."""
434-
query_normalized = query.lstrip().lower()
435-
436-
if not query_normalized.startswith("insert"):
437-
raise ProgrammingError(
438-
"bulk_insert is only supported for INSERT statements"
439-
)
440-
441-
if ";" in query.strip().rstrip(";"):
442-
raise ProgrammingError(
443-
"bulk_insert does not support multi-statement queries"
444-
)
445-
446-
def _executemany_bulk_insert(
447-
self,
448-
query: str,
449-
parameters_seq: Sequence[Sequence[ParameterType]],
450-
timeout_seconds: Optional[float],
451-
) -> int:
452-
"""Execute multiple INSERT queries as a single batch."""
453-
self._validate_bulk_insert_query(query)
454-
455-
if not parameters_seq:
456-
raise ProgrammingError("bulk_insert requires at least one parameter set")
457-
458-
from firebolt.db import paramstyle
459-
460-
try:
461-
parameter_style = ParameterStyle(paramstyle)
462-
except ValueError:
463-
raise ProgrammingError(f"Unsupported paramstyle: {paramstyle}")
464-
465-
concatenated_query = "; ".join([query] * len(parameters_seq))
466-
467-
self._close_rowset_and_reset()
468-
self._row_set = InMemoryRowSet()
469-
470-
try:
471-
if parameter_style == ParameterStyle.FB_NUMERIC:
472-
flattened_params: List[ParameterType] = []
473-
for param_set in parameters_seq:
474-
flattened_params.extend(param_set)
475-
476-
Cursor._log_query(concatenated_query)
477-
timeout_controller = TimeoutController(timeout_seconds)
478-
timeout_controller.raise_if_timeout()
479-
480-
query_params = self._build_fb_numeric_query_params(
481-
[flattened_params],
482-
streaming=False,
483-
async_execution=False,
484-
extra_params={"merge_prepared_statement_batches": "true"},
485-
)
486-
487-
resp = self._api_request(
488-
concatenated_query,
489-
query_params,
490-
timeout=timeout_controller.remaining(),
491-
)
492-
self._raise_if_error(resp)
493-
self._parse_response_headers(resp.headers)
494-
self._append_row_set_from_response(resp)
495-
else:
496-
formatted_queries = []
497-
statements = parse_sql(query)
498-
for param_set in parameters_seq:
499-
formatted_query = self._formatter.format_statement(
500-
statements[0], param_set
501-
)
502-
formatted_queries.append(formatted_query)
503-
504-
concatenated_query = "; ".join(formatted_queries)
505-
506-
query_params = {
507-
"output_format": self._get_output_format(False),
508-
"merge_prepared_statement_batches": "true",
509-
}
510-
511-
Cursor._log_query(concatenated_query)
512-
resp = self._api_request(
513-
concatenated_query, query_params, timeout=timeout_seconds
514-
)
515-
self._raise_if_error(resp)
516-
self._parse_response_headers(resp.headers)
517-
self._append_row_set_from_response(resp)
518-
519-
self._state = CursorState.DONE
520-
except Exception:
521-
self._state = CursorState.ERROR
522-
raise
523-
428+
self._do_execute(
429+
query, parameters_seq, timeout=timeout_seconds, bulk_insert=bulk_insert
430+
)
524431
return self.rowcount
525432

526433
@check_not_closed

0 commit comments

Comments
 (0)