
Hang with StreamingDataLoader + ParquetLoader(low_memory=False) on Linux — fork() + Polars interaction #823

@Mengyan1


Summary

When StreamingDataset is configured with item_loader=ParquetLoader(low_memory=False) and iterated through StreamingDataLoader(num_workers > 0) on Linux, iteration hangs on the first batch. The DataLoader main thread sits in multiprocessing.connection.poll and the worker subprocesses are stuck inside polars.LazyFrame.collect().

This looks like the documented incompatibility between os.fork() (the start method PyTorch's DataLoader uses for its workers by default on Linux) and the Rust/Rayon thread pool that Polars uses internally.

In low_memory=False mode, ParquetLoader._get_item() calls pl.scan_parquet(chunk).collect() inside each DataLoader worker, which is exactly the call pattern described in the upstream Polars issue (pola-rs/polars#24162, quoted below). Passing multiprocessing_context=mp.get_context("spawn") to StreamingDataLoader resolves the hang.

Minimal reproduction

Self-contained, no cloud access needed. Reliably reproduces on the versions listed at the bottom.

#!/usr/bin/env python3
"""Reproduce litdata + Polars + fork() DataLoader hang with low_memory=False."""
import multiprocessing as mp
import os
import shutil
import signal
import sys
import time

DATA_DIR = "/tmp/litdata_polars_fork_repro"
TIMEOUT_SEC = 90


def main() -> None:
    mode = sys.argv[1] if len(sys.argv) > 1 else "fork"
    assert mode in ("fork", "spawn", "forkserver"), f"unknown mode: {mode!r}"

    import numpy as np
    import pyarrow as pa
    import pyarrow.parquet as pq
    import polars as pl
    import litdata as ld
    from litdata.streaming.item_loader import ParquetLoader

    print(f"[env] python={sys.version.split()[0]}  "
          f"default_mp_start_method={mp.get_start_method(allow_none=False)!r}  "
          f"litdata={ld.__version__}  polars={pl.__version__}")

    if os.path.exists(DATA_DIR):
        shutil.rmtree(DATA_DIR)
    os.makedirs(DATA_DIR, exist_ok=True)

    rng = np.random.default_rng(0)
    n_files, rows_per_file, seq_len = 8, 50_000, 1024
    for i in range(n_files):
        tbl = pa.table({
            "ids":    rng.integers(0, 32_000, size=(rows_per_file, seq_len), dtype=np.int32).tolist(),
            "labels": rng.integers(0, 2, size=rows_per_file, dtype=np.int64),
            "score":  rng.random(rows_per_file, dtype=np.float32),
            "tag":    [f"tag_{j}" for j in range(rows_per_file)],
        })
        pq.write_table(tbl, os.path.join(DATA_DIR, f"part-{i}.parquet"))

    ld.index_parquet_dataset(DATA_DIR)

    # Touch Polars in the parent first — mirrors what happens when a validation
    # sanity check runs before training workers are forked.
    _ = pl.scan_parquet(os.path.join(DATA_DIR, "part-0.parquet")).collect()

    dataset = ld.StreamingDataset(
        DATA_DIR,
        item_loader=ParquetLoader(low_memory=False),
    )

    loader_kwargs: dict = {"batch_size": 8, "num_workers": 4}
    if mode != "fork":
        loader_kwargs["multiprocessing_context"] = mp.get_context(mode)

    loader = ld.StreamingDataLoader(dataset, **loader_kwargs)
    t0 = time.time()
    for i, batch in enumerate(loader):
        print(f"    got batch {i}  ({time.time() - t0:.1f}s)")
        if i >= 5:
            break
    print(f"[OK] completed normally in {time.time() - t0:.1f}s")


if __name__ == "__main__":
    def _on_timeout(signum, frame):
        sys.stderr.write(f"\n[!] TIMED OUT after {TIMEOUT_SEC}s -- this is the hang.\n")
        os._exit(124)

    signal.signal(signal.SIGALRM, _on_timeout)
    signal.alarm(TIMEOUT_SEC)
    main()

Observed

Default (fork): hangs until the script's 90 s timeout fires:

$ python /tmp/repro_litdata_polars_fork.py
[env] python=3.11.11  default_mp_start_method='fork'  litdata=0.2.61  polars=1.33.1
[+] created 8 parquet files (775.3 MB total) under /tmp/litdata_polars_fork_repro
Indexing progress: 100%|██████████████████████| 8/8 [00:00<00:00, 2637.72step/s]
[+] indexed parquet dataset
[+] warmed up Polars in parent process (initializes Rayon pool)
You have set low_memory=False in ParquetLoader. This may result in high memory usage when processing large Parquet chunk files. Consider setting low_memory=True to reduce memory consumption.
[+] using PyTorch DataLoader default mp context (== 'fork' on this OS)
[+] iterating loader (timeout in 90s)...

[!] TIMED OUT after 90s -- this is the hang.

With multiprocessing_context=mp.get_context("spawn"): completes in ~4 s:

$ python /tmp/repro_litdata_polars_fork.py spawn
[env] python=3.11.11  default_mp_start_method='fork'  litdata=0.2.61  polars=1.33.1
[+] created 8 parquet files (775.3 MB total) under /tmp/litdata_polars_fork_repro
Indexing progress: 100%|██████████████████████| 8/8 [00:00<00:00, 2686.50step/s]
[+] indexed parquet dataset
[+] warmed up Polars in parent process (initializes Rayon pool)
You have set low_memory=False in ParquetLoader. This may result in high memory usage when processing large Parquet chunk files. Consider setting low_memory=True to reduce memory consumption.
[+] using multiprocessing_context='spawn'
[+] iterating loader (timeout in 90s)...
    got batch 0  (2.2s)
    got batch 1  (3.1s)
    got batch 2  (3.1s)
    got batch 3  (3.1s)
    got batch 4  (3.1s)
    got batch 5  (3.8s)
[OK] completed normally in 4.2s

Stack traces from the original hang

From an 8-rank DDP training job. After ~480 s the PyTorch NCCL watchdog fires on ranks 1–7 (waiting at a gradient AllReduce that rank 0 never reached); rank 0's watchdog is silent because it has no pending NCCL op.

DataLoader main thread on rank 0 (via py-spy dump):

select (selectors.py:415)
wait (multiprocessing/connection.py:948)
poll (multiprocessing/connection.py:257)
get (multiprocessing/queues.py:113)
_try_get_data (torch/utils/data/dataloader.py:1275)
_get_data (torch/utils/data/dataloader.py:1444)
_next_data (torch/utils/data/dataloader.py:1482)
__next__ (torch/utils/data/dataloader.py:732)
__iter__ (litdata/streaming/dataloader.py:675)
__next__ (lightning/pytorch/utilities/combined_loader.py:341)
__next__ (lightning/pytorch/loops/fetchers.py:134)
advance (lightning/pytorch/loops/training_epoch_loop.py:311)
...
fit (lightning/pytorch/trainer/trainer.py:584)

DataLoader worker subprocess (all four workers identical):

collect (polars/lazyframe/frame.py:2407)              ← stuck inside Polars
wrapper (polars/lazyframe/opt_flags.py:330)
wrapper (polars/_utils/deprecation.py:97)
_get_item (litdata/streaming/item_loader.py:776)      ← pl.scan_parquet(...).collect()
load_item_from_chunk (litdata/streaming/item_loader.py:697)
read (litdata/streaming/reader.py:460)
__getitem__ (litdata/streaming/cache.py:155)
__getitem__ (litdata/streaming/dataset.py:494)
__next__ (litdata/streaming/dataset.py:556)
fetch (torch/utils/data/_utils/fetch.py:33)
_worker_loop (torch/utils/data/_utils/worker.py:349)

Root cause (our reading)

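  1. ParquetLoader._get_item() (the low_memory=False path) calls into Polars in each DataLoader worker; see src/litdata/streaming/item_loader.py (the low_memory=True argument below is Polars's own scan_parquet option, unrelated to ParquetLoader's low_memory flag):
    self._df[chunk_index] = pl.scan_parquet(chunk_filepath, low_memory=True).collect()
  2. StreamingDataLoader is a thin subclass of torch.utils.data.DataLoader. When multiprocessing_context is not set, DataLoader falls back to the default start method, which is fork on Linux (see the default_mp_start_method in the [env] lines above).
  3. Polars's Rayon thread pool is initialized in the parent process the first time a Polars API is touched. fork() copies only the calling thread into the child, so the worker inherits the pool's bookkeeping (including any lock held at the moment of the fork) but none of its threads. The first LazyFrame.collect() in the worker dispatches work to that pool and waits forever for threads that do not exist in that process.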
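Step 3 can be illustrated without Polars. The sketch below is a hypothetical toy model, not Polars or Rayon code: a one-thread "pool" is started in the parent and used once (like the warm-up collect() in the repro), then the process forks and the child submits a job to its inherited copy of the pool. It only shows the missing-thread half of the problem, not inherited locks, and it is Linux-only because it calls os.fork(). The child waits with a timeout purely so the demo terminates; the real DataLoader workers have no such timeout.

```python
import os
import queue
import threading
import time


def fork_after_pool_started(timeout: float = 1.0) -> str:
    """Toy model (hypothetical): a one-thread pool started in the parent,
    then used from a fork()ed child. Returns "ran" or "stuck"."""
    jobs: queue.Queue = queue.Queue()

    def pool_worker() -> None:
        # Executes submitted jobs forever, like a pool thread.
        while True:
            fn, reply = jobs.get()
            reply.put(fn())

    # Parent "touches" the pool first: the worker thread starts here.
    threading.Thread(target=pool_worker, daemon=True).start()
    reply: queue.Queue = queue.Queue()
    jobs.put((lambda: "ran", reply))
    assert reply.get(timeout=timeout) == "ran"  # the pool works in the parent
    time.sleep(0.1)  # let the worker go back to waiting on the empty queue

    read_fd, write_fd = os.pipe()
    pid = os.fork()
    if pid == 0:
        # Child: `jobs` was copied into this process, the worker thread was not.
        try:
            os.close(read_fd)
            child_reply: queue.Queue = queue.Queue()
            jobs.put((lambda: "ran", child_reply))  # submitting succeeds...
            try:
                result = child_reply.get(timeout=timeout)
            except queue.Empty:
                # ...but no thread in this process will ever run the job.
                # Without a timeout, this get() would block forever.
                result = "stuck"
            os.write(write_fd, result.encode())
        finally:
            os._exit(0)

    # Parent: collect the child's verdict.
    os.close(write_fd)
    os.waitpid(pid, 0)
    with os.fdopen(read_fd) as f:
        return f.read()


if __name__ == "__main__":
    print(fork_after_pool_started())  # prints "stuck" on Linux
```

On Python 3.12+, CPython itself emits a DeprecationWarning when fork() is called from a multi-threaded process, warning that the child may deadlock, which is essentially this failure mode.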

From the maintainer reply on pola-rs/polars#24162:

"We don't support using Polars in combination with multiprocessing using fork: https://docs.pola.rs/user-guide/misc/multiprocessing/"

And the Polars docs:

"Polars is multithreaded as to provide strong performance out-of-the-box. Thus, it cannot be combined with fork. … One should use spawn, or forkserver, instead. … Using fork as the method, instead of spawn, will cause a dead lock."

In our testing, low_memory=True (the current default) did not reproduce the hang with the minimal script above. That path reads row groups with pyarrow (ParquetFile.read_row_group()) and converts them with pl.from_arrow(), so it dispatches less work to the Rayon pool. A similar hang may still be possible under heavier load.

Workaround

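import multiprocessing as mp
from litdata import StreamingDataLoader

# spawn (or forkserver) starts workers from a process that has not touched Polars,
# so each worker builds its own thread pool instead of inheriting one via fork().
loader = StreamingDataLoader(
    dataset,  # your StreamingDataset, e.g. with item_loader=ParquetLoader(low_memory=False)
    batch_size=...,
    num_workers=...,
    multiprocessing_context=mp.get_context("spawn"),
)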
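If loaders are constructed in several places, one way to avoid forgetting the flag is a small helper that attaches a non-fork context whenever worker processes will be created. This is a minimal sketch; loader_kwargs_without_fork is a hypothetical name, not part of litdata or PyTorch. It skips the context when num_workers is 0 because PyTorch's DataLoader raises a ValueError if multiprocessing_context is set without worker processes, and it leaves an explicitly passed context untouched.

```python
import multiprocessing as mp
from typing import Any


def loader_kwargs_without_fork(num_workers: int, start_method: str = "spawn", **kwargs: Any) -> dict:
    """Hypothetical helper: build DataLoader/StreamingDataLoader keyword
    arguments that use `start_method` instead of the platform default (fork on Linux)."""
    out: dict = {"num_workers": num_workers, **kwargs}
    # Only attach a context when worker processes will actually be started,
    # and respect a context the caller already chose.
    if num_workers > 0 and out.get("multiprocessing_context") is None:
        out["multiprocessing_context"] = mp.get_context(start_method)
    return out
```

Usage: `loader = StreamingDataLoader(dataset, **loader_kwargs_without_fork(num_workers=4, batch_size=8))`. With spawn, each worker re-imports the main module and unpickles the dataset, so loader construction should stay under an `if __name__ == "__main__":` guard, as in the repro script.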

Environment

  • litdata: 0.2.61
  • polars: 1.33.1
  • torch: 2.9.1+cu129
  • pytorch-lightning: 2.6.1
  • Python: 3.11.11
  • OS: Ubuntu 22.04.5 LTS, Linux 6.8.0-1030-gcp x86_64
  • Default multiprocessing start method: fork
  • Hardware: 8-GPU Linux x86_64 node (the minimal repro reproduces with no GPU involvement)
