Skip to content

[LED][Longformer] Replace for-loop with unfold in _chunk ONNX-export path#46169

Open
guinik wants to merge 2 commits into
huggingface:mainfrom
guinik:led-longformer-onnx-unfold-chunk
Open

[LED][Longformer] Replace for-loop with unfold in _chunk ONNX-export path#46169
guinik wants to merge 2 commits into
huggingface:mainfrom
guinik:led-longformer-onnx-unfold-chunk

Conversation

@guinik

@guinik guinik commented May 23, 2026

Copy link
Copy Markdown

What does this PR do?

Removes a now-obsolete workaround in LEDEncoderSelfAttention._chunk
and the mirrored LongformerSelfAttention._chunk. Both files carry
the same blocking TODO:

# TODO replace this with
# > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
# once `unfold` is supported

The ONNX-export branch used a Python for loop because torch.unfold
was not supported by the ONNX exporter at the time. PyTorch's ONNX
exporters now handle aten::unfold, so the workaround can go.

Fixes the stale TODO in both modeling_led.py and
modeling_longformer.py — same change applied to both files.

Compatibility note (please read first)

This is technically a behavior change for one narrow export path, so
flagging it up front:

The legacy TorchScript ONNX exporter's unfold symbolic requires the
unfolded dimension to be statically known at export time. After this
change, anyone exporting LED/Longformer under the legacy exporter
with dynamic_axes on the sequence dimension will get a clear error:
"Unfold, input size not accessible".

Affected users have two paths forward:

  1. Export with static sequence shapes under the legacy exporter (works
    without changes).
  2. Switch to the dynamo-based exporter (default in PyTorch 2.9+;
    handles symbolic shapes natively).

The legacy TorchScript exporter is deprecated as of PyTorch 2.9, so
(2) is the documented migration path.

If you'd prefer to preserve a fallback for users still on the legacy
exporter + dynamic shapes
, happy to add one, but I don't think a
PyTorch ≥ 2.9 version check is the right shape for it. Two reasons:

  1. PyTorch 2.9 made dynamo=True the default for
    torch.onnx.export, but the legacy exporter is still callable and
    in active use. Users on 2.9+ pass dynamo=False for real reasons
    (avoiding onnxscript as a transitive dependency, older optimum
    pins, custom export wrappers that still drive the legacy path). A
    version-only check would silently break them.
  2. There's no public PyTorch API I'm aware of to detect mid-trace
    which exporter is currently running.
    torch.onnx.is_in_onnx_export() tells you an export is in flight,
    not which one, so a runtime capability check isn't really
    available either.

The cleanest alternative I can see is an explicit opt-in config flag,
mirroring the existing config.onnx_export flag this function already
reads:

@staticmethod
def _chunk(hidden_states, window_overlap, onnx_export: bool = False, legacy_onnx_chunk: bool = False):
    """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
    if not onnx_export:
        # ... unchanged as_strided path ...
        return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)

    # ONNX export path.
    #
    # PyTorch's modern dynamo-based ONNX exporter handles `aten::unfold`
    # cleanly, including with symbolic shapes. The legacy TorchScript
    # exporter's unfold symbolic requires the unfolded dimension to be
    # statically known at export time, so users on the legacy exporter
    # with `dynamic_axes` on the sequence dim need the slow-loop fallback.
    #
    # There is no public API to detect which exporter is currently
    # tracing, so we expose the choice as an explicit config flag.
    # Default is the fast `unfold` path; users hitting the legacy
    # exporter limitation can set `config.onnx_export_legacy_chunk = True`
    # to restore the original loop.
    if not legacy_onnx_chunk:
        return hidden_states.unfold(
            dimension=1, size=window_overlap * 2, step=window_overlap
        ).transpose(2, 3)

    # Legacy fallback: slow but symbolic-shape-safe under the legacy
    # TorchScript ONNX exporter. Allocates a fresh tensor per call.
    chunk_size = [
        hidden_states.size(0),
        torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
        window_overlap * 2,
        hidden_states.size(2),
    ]
    overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device)
    for chunk in range(chunk_size[1]):
        overlapping_chunks[:, chunk, :, :] = hidden_states[
            :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
        ]
    return overlapping_chunks

Call sites in _sliding_chunks_query_key_matmul would thread the flag
through alongside the existing onnx_export one:

query = self._chunk(
    query,
    window_overlap,
    getattr(self.config, "onnx_export", False),
    getattr(self.config, "onnx_export_legacy_chunk", False),
)

Why this shape:

  • Default behavior is the fast unfold path, so modern users get the
    memory win and cleaner graph with no action required.
  • Users on the legacy exporter with dynamic shapes have a one-line
    escape hatch (config.onnx_export_legacy_chunk = True) instead of a
    hard break.
  • The flag lives next to the existing onnx_export flag, so it's
    consistent with the surrounding code rather than new API surface in
    an unrelated spot.
  • When the legacy TorchScript exporter is eventually removed upstream,
    the flag and the loop can both be dropped without touching the fast
    path.

That said, I don't have strong feelings here, the legacy exporter is
deprecated as of 2.9 and the documented migration is the dynamo
exporter, so a clean break is also defensible. Happy to push the flag
as a follow-up commit on this PR if you want it, or leave the current
diff as-is. Let me know which you prefer.

Why make the change

  • Memory: the for-loop allocates a fresh torch.empty(...) tensor
    on every call. For led-base-16384 production shapes
    (bh=24, seq=16384, w=512) this is ~195 MB per call. The unfold
    replacement is a view that shares storage with the input, allocating
    nothing, matching the zero-copy semantics of the non-ONNX
    as_strided path that already exists in the same function.
  • Clarity: a one-line view replaces a 12-line loop, expressing
    intent directly.
  • Exported graph quality: the resulting ONNX graph contains a
    single view op rather than N scatter writes into a freshly allocated
    tensor (N = 31 at xl shapes), giving downstream runtimes a cleaner
    graph to optimize.

Verification

Verified locally with a parametrized test covering 24 combinations:

  • 4 shape configs: (bh=8, seq=512, w=64), (24, 2048, 256),
    (24, 4096, 512), (24, 16384, 512)
  • ONNX opsets 14, 17, 20
  • Both the legacy TorchScript exporter and the modern dynamo-based
    exporter (default since PyTorch 2.9)

Each case verifies (a) the exported model passes
onnx.checker.check_model, (b) runs under onnxruntime's CPU EP, and
(c) produces output numerically identical (atol=1e-6) to the existing
_chunk(..., onnx_export=True) reference computed on the same input.

The legacy exporter is tested with static shapes per its documented
unfold symbolic limitation; the dynamo exporter is tested with its
native symbolic-shape handling.

Testing code
# test_chunk_onnx_export.py
"""
Verifies that the unfold-based _chunk replacement:
(a) exports cleanly under both the legacy TorchScript exporter and the
    modern dynamo-based exporter, within each exporter's supported envelope
(b) runs in onnxruntime
(c) produces output numerically identical to the PyTorch reference
    (the existing _chunk(..., onnx_export=True) implementation)

Note on exporter coverage:
The dynamo exporter handles aten::unfold with symbolic (dynamic) shapes.
The legacy TorchScript exporter's unfold symbolic requires the unfolded
dimension to be statically known at export time, so this test exports
without dynamic_axes when use_dynamo=False.
"""
import io
import numpy as np
import onnx
import onnxruntime as ort
import pytest
import torch

from transformers.models.led.modeling_led import LEDEncoderSelfAttention


class ChunkModule(torch.nn.Module):
    """Mirrors the body of _chunk(..., onnx_export=True) after the fix."""

    def __init__(self, window_overlap):
        super().__init__()
        self.w = window_overlap

    def forward(self, x):
        return x.unfold(dimension=1, size=self.w * 2, step=self.w).transpose(2, 3)


CONFIGS = [
    (8, 512, 64, 64),
    (24, 2048, 64, 256),
    (24, 4096, 64, 512),
    (24, 16384, 64, 512),
]

OPSETS = [14, 17, 20]


@pytest.mark.parametrize("bh,seq,hd,w", CONFIGS)
@pytest.mark.parametrize("opset", OPSETS)
@pytest.mark.parametrize("use_dynamo", [False, True])
def test_chunk_onnx_roundtrip(bh, seq, hd, w, opset, use_dynamo):
    torch.manual_seed(0)
    x = torch.randn(bh, seq, hd)
    model = ChunkModule(w).eval()

    ref = LEDEncoderSelfAttention._chunk(x, w, onnx_export=True)

    buf = io.BytesIO()
    torch.onnx.export(
        model,
        (x,),
        buf,
        opset_version=opset,
        dynamo=use_dynamo,
        input_names=["x"],
        output_names=["chunks"],
    )

    buf.seek(0)

    model_proto = onnx.load_from_string(buf.getvalue())
    onnx.checker.check_model(model_proto)

    sess = ort.InferenceSession(buf.getvalue(), providers=["CPUExecutionProvider"])
    input_name = sess.get_inputs()[0].name
    (out_np,) = sess.run(None, {input_name: x.numpy()})

    np.testing.assert_allclose(out_np, ref.numpy(), atol=1e-6, rtol=1e-6)

Result: 24/24 PASSED on Windows 11 / PyTorch (current) / onnx /
onnxruntime.

I didn't add this to the repo because it would pull onnx and
onnxruntime into the test deps. Happy to add it gated behind
require_onnx / require_onnxruntime decorators (matching the
existing pattern for optional deps in transformers) if you'd prefer
it in-tree — one-line ask, just say the word.

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

Who can review?

Small ONNX-export cleanup, applied identically in LED and Longformer.
Per the tagging guide:

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: led, longformer

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant