Skip to content

Commit 8645d65

Browse files
def- and claude authored
adapter: bound COPY FROM STDIN blocking-pool usage (#37037)
setup_copy_from_stdin spawned available_parallelism() workers on the shared tokio blocking pool eagerly at COPY start -- before any data -- and each worker held its blocking thread for the whole COPY by driving an idle batch_rx.recv().await via block_on. With no cap on worker count or concurrent COPYs, a low-privilege client could open ~512/cores idle, zero-byte COPY ... FROM STDIN connections and pin the entire 512-thread pool, stalling every other blocking-pool user (e.g. the mandatory "optimize peek" stage of any SELECT) process-wide until the client disconnected.

Run each worker as a regular async task instead, so it holds no thread while parked waiting for the next chunk; offload only the CPU-bound per-chunk processing (decode, column transform, constraint checks, and columnar encode) to the blocking pool for the duration of that chunk. Also cap workers per COPY at COPY_FROM_STDIN_MAX_WORKERS (8) so one COPY cannot reserve an unbounded share of the pool even while actively decoding.

Closes SQL-372.

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent f744141 commit 8645d65

2 files changed

Lines changed: 207 additions & 67 deletions

File tree

src/adapter/src/coord/sequencer/inner/copy_from.rs

Lines changed: 102 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ use crate::{AdapterError, ExecuteContext, ExecuteResponse};
4343
/// unbounded in-memory growth in a single giant batch.
4444
const COPY_FROM_STDIN_MAX_BATCH_BYTES: usize = 32 * 1024 * 1024;
4545

46+
/// Cap on the number of parallel decode workers spawned per COPY FROM STDIN.
47+
/// A single network-bound stream sees marginal gains past a handful of
48+
/// decoders, and capping bounds how much of the blocking pool any one COPY can
49+
/// occupy while actively decoding.
50+
const COPY_FROM_STDIN_MAX_WORKERS: usize = 8;
51+
4652
impl Coordinator {
4753
pub(crate) async fn sequence_copy_from(
4854
&mut self,
@@ -415,10 +421,14 @@ impl Coordinator {
415421
.collect::<Vec<_>>()
416422
.into();
417423

418-
// Determine number of parallel workers.
419-
let num_workers = std::thread::available_parallelism()
420-
.map(|n| n.get())
421-
.unwrap_or(1);
424+
// Determine number of parallel workers, capped so that a single COPY
425+
// cannot reserve an unbounded share of the shared blocking pool.
426+
let num_workers = std::cmp::min(
427+
std::thread::available_parallelism()
428+
.map(|n| n.get())
429+
.unwrap_or(1),
430+
COPY_FROM_STDIN_MAX_WORKERS,
431+
);
422432
tracing::info!(
423433
%target_id, num_workers,
424434
"starting parallel COPY FROM STDIN batch builders"
@@ -430,11 +440,12 @@ impl Coordinator {
430440
let collection_desc = Arc::new(collection_desc);
431441
let persist_client = self.persist_client.clone();
432442

433-
// Create per-worker channels and spawn workers on blocking threads.
434-
// Each worker does CPU-intensive TSV decoding + columnar encoding,
435-
// so they need dedicated OS threads (not tokio async tasks) for
436-
// true parallelism.
437-
let rt_handle = tokio::runtime::Handle::current();
443+
// Create per-worker channels and spawn one async task per worker. Each
444+
// worker offloads the CPU-intensive processing of a chunk (decode plus
445+
// the per-row transform/constraint-check/columnar encode) to the
446+
// blocking pool for the duration of that chunk (see
447+
// `copy_from_stdin_batch_builder`), so workers run in parallel while
448+
// doing CPU work but hold no thread while idle between chunks.
438449
let mut batch_txs = Vec::with_capacity(num_workers);
439450
let mut worker_handles = Vec::with_capacity(num_workers);
440451

@@ -464,24 +475,21 @@ impl Coordinator {
464475
// Only worker 0 receives the first chunk (round-robin), so only
465476
// it needs to skip the CSV header on its first chunk.
466477
let skip_header_on_first_chunk = worker_id == 0 && first_chunk_has_header;
467-
let rt = rt_handle.clone();
468478

469-
let handle = mz_ore::task::spawn_blocking(
479+
let handle = mz_ore::task::spawn(
470480
|| format!("copy_from_stdin_worker:{target_id}:{worker_id}"),
471-
move || {
472-
rt.block_on(Self::copy_from_stdin_batch_builder(
473-
persist_client,
474-
shard_id,
475-
collection_id,
476-
collection_desc,
477-
target_desc,
478-
column_transform,
479-
column_types,
480-
params,
481-
skip_header_on_first_chunk,
482-
batch_rx,
483-
))
484-
},
481+
Self::copy_from_stdin_batch_builder(
482+
persist_client,
483+
shard_id,
484+
collection_id,
485+
collection_desc,
486+
target_desc,
487+
column_transform,
488+
column_types,
489+
params,
490+
skip_header_on_first_chunk,
491+
batch_rx,
492+
),
485493
);
486494
worker_handles.push(handle);
487495
}
@@ -555,10 +563,11 @@ impl Coordinator {
555563
let mut batch_bytes: usize = 0;
556564
let mut proto_batches = Vec::new();
557565

566+
let rt = tokio::runtime::Handle::current();
558567
let mut is_first_chunk = true;
559568
while let Some(raw_bytes) = batch_rx.recv().await {
560-
// Decode raw bytes into rows. For the first chunk of worker 0,
561-
// re-enable header skipping so the real CSV header line is skipped.
569+
// For the first chunk of worker 0, re-enable header skipping so the
570+
// real CSV header line is skipped.
562571
let chunk_params = if is_first_chunk && skip_header_on_first_chunk {
563572
let mut p = params.clone();
564573
if let CopyFormatParams::Csv(ref mut csv) = p {
@@ -569,34 +578,73 @@ impl Coordinator {
569578
params.clone()
570579
};
571580
is_first_chunk = false;
572-
let rows = mz_pgcopy::decode_copy_format(&raw_bytes, &column_types, chunk_params)
573-
.map_err(|e| AdapterError::CopyFormatError(e.to_string()))?;
574-
575-
for row in rows {
576-
// Apply column transform if needed (add defaults, reorder).
577-
let full_row = if let Some(ref transform) = *column_transform {
578-
transform.apply(&row)
579-
} else {
580-
row
581-
};
582-
583-
// Check constraints.
584-
for (i, datum) in full_row.iter().enumerate() {
585-
target_desc.constraints_met(i, &datum).map_err(|e| {
586-
AdapterError::Unstructured(anyhow::anyhow!("constraint violation: {e}"))
587-
})?;
588-
}
589-
590-
let data = SourceData(Ok(full_row));
591-
batch_builder
592-
.add(&data, &(), &lower, &1)
593-
.await
594-
.map_err(|e| AdapterError::Unstructured(anyhow::anyhow!("persist add: {e}")))?;
595-
row_count += 1;
596-
row_count_in_batch += 1;
597-
}
581+
let raw_len = raw_bytes.len();
582+
583+
// Offload the entire CPU-bound per-chunk pipeline -- decode, column
584+
// transform, constraint checks, and the columnar persist encode
585+
// (`BatchBuilder::add` -> `PartBuilder::push`) -- to the blocking
586+
// pool. There is no yield point in the row loop until a batch fills
587+
// (`add` only awaits `flush_part`, and only once an *encoded* part
588+
// reaches `blob_target_size`, far beyond the 32 MiB *raw* batch
589+
// boundary), so left on the async runtime each chunk's rows would
590+
// run as one uninterrupted burst on a shared runtime worker thread,
591+
// starving other connections. The blocking thread is held only
592+
// while a chunk is in flight and released back to the pool between
593+
// chunks (during `recv().await`), so idle workers still hold no
594+
// thread. `block_on` is invoked once per chunk -- not per row -- to
595+
// drive the row loop and the rare `flush_part` it may await.
596+
let chunk_column_types = Arc::clone(&column_types);
597+
let chunk_transform = Arc::clone(&column_transform);
598+
let chunk_target_desc = Arc::clone(&target_desc);
599+
let chunk_rt = rt.clone();
600+
let (returned_builder, added_rows) = mz_ore::task::spawn_blocking(
601+
|| "copy_from_stdin_process_chunk",
602+
move || {
603+
let rows = mz_pgcopy::decode_copy_format(
604+
&raw_bytes,
605+
&chunk_column_types,
606+
chunk_params,
607+
)
608+
.map_err(|e| AdapterError::CopyFormatError(e.to_string()))?;
609+
610+
chunk_rt.block_on(async move {
611+
let mut added: u64 = 0;
612+
for row in rows {
613+
// Apply column transform if needed (add defaults, reorder).
614+
let full_row = if let Some(ref transform) = *chunk_transform {
615+
transform.apply(&row)
616+
} else {
617+
row
618+
};
619+
620+
// Check constraints.
621+
for (i, datum) in full_row.iter().enumerate() {
622+
chunk_target_desc.constraints_met(i, &datum).map_err(|e| {
623+
AdapterError::Unstructured(anyhow::anyhow!(
624+
"constraint violation: {e}"
625+
))
626+
})?;
627+
}
628+
629+
let data = SourceData(Ok(full_row));
630+
batch_builder
631+
.add(&data, &(), &lower, &1)
632+
.await
633+
.map_err(|e| {
634+
AdapterError::Unstructured(anyhow::anyhow!("persist add: {e}"))
635+
})?;
636+
added += 1;
637+
}
638+
Ok::<_, AdapterError>((batch_builder, added))
639+
})
640+
},
641+
)
642+
.await?;
643+
batch_builder = returned_builder;
644+
row_count += added_rows;
645+
row_count_in_batch += added_rows;
598646

599-
batch_bytes = batch_bytes.saturating_add(raw_bytes.len());
647+
batch_bytes = batch_bytes.saturating_add(raw_len);
600648
if batch_bytes >= COPY_FROM_STDIN_MAX_BATCH_BYTES {
601649
let batch = batch_builder.finish(upper.clone()).await.map_err(|e| {
602650
AdapterError::Unstructured(anyhow::anyhow!("persist finish: {e}"))

test/copy/mzcompose.py

Lines changed: 105 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
import random
1919
import string
20+
import threading
2021
import time
2122
from io import BytesIO, StringIO
2223
from textwrap import dedent
@@ -201,12 +202,24 @@ def workflow_nightly(c: Composition, parser: WorkflowArgumentParser) -> None:
201202

202203
def workflow_ci(c: Composition, _parser: WorkflowArgumentParser) -> None:
203204
"""
204-
Workflows to run during CI
205+
Run all workflows during CI.
206+
207+
Every workflow is run except for the exceptions below, so that a newly
208+
added regression test gets CI coverage automatically instead of silently
209+
needing to be added to a hand-maintained allowlist:
210+
- "default": meta-workflow that runs everything (would recurse).
211+
- "ci": this workflow itself (would recurse).
212+
- "nightly": heavy TPC-H suite run separately via the `nightly` pipeline
213+
step (`run: nightly`), not here.
205214
"""
206-
for name in ["auth", "http", "copy-from-csv-header", "copy-from-ssrf-redirect"]:
215+
excluded = {"default", "ci", "nightly"}
216+
217+
def process(name: str) -> None:
207218
with c.test_case(name):
208219
c.workflow(name)
209220

221+
c.test_parts([name for name in c.workflows.keys() if name not in excluded], process)
222+
210223

211224
def workflow_auth(c: Composition) -> None:
212225
c.up(Service("mc", idle=True), "materialized", "minio")
@@ -373,18 +386,18 @@ def workflow_test_column_dedup(c: Composition):
373386
c.testdrive(dedent("""
374387
$ postgres-execute connection=postgres://mz_system:materialize@${testdrive.materialize-internal-sql-addr}
375388
376-
> CREATE SECRET aws_secret AS '${arg.aws-secret-access-key}'
377-
> CREATE CONNECTION aws_conn
389+
> CREATE SECRET aws_secret_column_dedup AS '${arg.aws-secret-access-key}'
390+
> CREATE CONNECTION aws_conn_column_dedup
378391
TO AWS (
379392
ACCESS KEY ID = '${arg.aws-access-key-id}',
380-
SECRET ACCESS KEY = SECRET aws_secret,
393+
SECRET ACCESS KEY = SECRET aws_secret_column_dedup,
381394
ENDPOINT = '${arg.aws-endpoint}',
382395
REGION = 'us-east-1'
383396
)
384397
385398
> COPY (SELECT 1::int4 AS a, 2::int4 AS a, 3::int4 AS a2, 4::int4 AS a)
386399
TO 's3://copytos3/test/column_dedup/'
387-
WITH (AWS CONNECTION = aws_conn, FORMAT = 'parquet');
400+
WITH (AWS CONNECTION = aws_conn_column_dedup, FORMAT = 'parquet');
388401
389402
$ s3-verify-data bucket=copytos3 key=test/column_dedup
390403
1 2 3 4
@@ -405,17 +418,17 @@ def workflow_test_github_9627(c: Composition):
405418
> CREATE TABLE t (a int)
406419
> INSERT INTO t VALUES (1)
407420
408-
> CREATE SECRET aws_secret AS '${arg.aws-secret-access-key}'
409-
> CREATE CONNECTION aws_conn
421+
> CREATE SECRET aws_secret_github_9627 AS '${arg.aws-secret-access-key}'
422+
> CREATE CONNECTION aws_conn_github_9627
410423
TO AWS (
411424
ACCESS KEY ID = '${arg.aws-access-key-id}',
412-
SECRET ACCESS KEY = SECRET aws_secret,
425+
SECRET ACCESS KEY = SECRET aws_secret_github_9627,
413426
ENDPOINT = '${arg.aws-endpoint}',
414427
REGION = 'us-east-1'
415428
)
416429
417430
> COPY (SELECT * FROM t) TO 's3://copytos3/test/github_9627/'
418-
WITH (AWS CONNECTION = aws_conn, FORMAT = 'csv');
431+
WITH (AWS CONNECTION = aws_conn_github_9627, FORMAT = 'csv');
419432
"""))
420433

421434
# Check that the table's read frontier still advances.
@@ -533,7 +546,7 @@ def workflow_copy_from_csv_quoted_null(c: Composition) -> None:
533546
with cur.copy("COPY csv_null_default FROM STDIN WITH (FORMAT CSV)") as copy:
534547
copy.write('a,\nb,""\n"",c\n')
535548

536-
cur.execute("SELECT a, b FROM csv_null_default ORDER BY a NULLS LAST")
549+
cur.execute("SELECT a, b FROM csv_null_default ORDER BY a IS NULL, a = '', a")
537550
rows = cur.fetchall()
538551
assert rows == [
539552
("a", None),
@@ -549,7 +562,7 @@ def workflow_copy_from_csv_quoted_null(c: Composition) -> None:
549562
) as copy:
550563
copy.write('a,NULL\nb,"NULL"\nNULL,c\n')
551564

552-
cur.execute("SELECT a, b FROM csv_null_custom ORDER BY a NULLS LAST")
565+
cur.execute("SELECT a, b FROM csv_null_custom ORDER BY a IS NULL, a = '', a")
553566
rows = cur.fetchall()
554567
assert rows == [
555568
("a", None),
@@ -626,7 +639,8 @@ def workflow_copy_from_csv_crlf(c: Composition) -> None:
626639
) as copy:
627640
copy.write(f'a,{eol}b,""{eol}"",c{eol}')
628641
cur.execute(
629-
f"SELECT a, b FROM csv_{label}_null ORDER BY a NULLS LAST".encode()
642+
f"SELECT a, b FROM csv_{label}_null "
643+
"ORDER BY a IS NULL, a = '', a".encode()
630644
)
631645
rows = cur.fetchall()
632646
assert rows == [
@@ -700,3 +714,81 @@ def workflow_copy_from_csv_crlf_large_end_marker(c: Composition) -> None:
700714
f"expected count={rows_each_side}, max_id={rows_each_side - 1} "
701715
"(rows after the bare \\. leaked through parallel workers)"
702716
)
717+
718+
719+
# To catch a regression, the idle sessions must be able to exhaust the blocking
720+
# pool: _NUM_IDLE_SESSIONS * effective_cores >= 512. 256 leaves margin on the 4-core agent.
721+
_NUM_IDLE_SESSIONS = 256
722+
_SELECT_TIMEOUT_S = 30.0
723+
724+
725+
def _select_1_responsive(c: Composition, timeout_s: float) -> bool:
726+
box: dict[str, object] = {}
727+
728+
def run() -> None:
729+
try:
730+
conn = c.sql_connection()
731+
try:
732+
with conn.cursor() as cur:
733+
cur.execute("SELECT 1")
734+
cur.fetchall()
735+
box["ok"] = True
736+
finally:
737+
conn.close()
738+
except Exception as e:
739+
box["err"] = e
740+
741+
t = threading.Thread(target=run, daemon=True)
742+
t.start()
743+
t.join(timeout_s)
744+
if t.is_alive():
745+
return False
746+
if box.get("ok"):
747+
return True
748+
raise AssertionError(f"SELECT 1 probe errored unexpectedly: {box.get('err')!r}")
749+
750+
751+
def _open_idle_copy(c: Composition) -> tuple:
752+
conn = c.sql_connection()
753+
cur = conn.cursor()
754+
cm = cur.copy("COPY copy_idle_target FROM STDIN")
755+
cm.__enter__()
756+
return (conn, cur, cm)
757+
758+
759+
def _close_idle_copies(held: list) -> None:
760+
for conn, _cur, cm in held:
761+
try:
762+
conn.close()
763+
except Exception:
764+
pass
765+
gen = getattr(cm, "gen", None)
766+
if gen is not None:
767+
try:
768+
gen.close()
769+
except Exception:
770+
pass
771+
held.clear()
772+
773+
774+
def workflow_copy_from_stdin_many_idle_sessions(c: Composition) -> None:
775+
"""Many idle COPY FROM STDIN sessions must not prevent other queries from
776+
running."""
777+
c.up("materialized")
778+
779+
setup_conn = c.sql_connection()
780+
with setup_conn.cursor() as cur:
781+
cur.execute("DROP TABLE IF EXISTS copy_idle_target")
782+
cur.execute("CREATE TABLE copy_idle_target (a int4)")
783+
setup_conn.close()
784+
785+
held: list[tuple] = []
786+
try:
787+
for _ in range(_NUM_IDLE_SESSIONS):
788+
held.append(_open_idle_copy(c))
789+
assert _select_1_responsive(c, _SELECT_TIMEOUT_S), (
790+
f"SELECT 1 did not return within {_SELECT_TIMEOUT_S}s while "
791+
f"{len(held)} idle COPY FROM STDIN sessions were open"
792+
)
793+
finally:
794+
_close_idle_copies(held)

0 commit comments

Comments
 (0)