
fix(perf): Don't use data-dependent op when batch_size and max_num_nodes are provided to to_dense_batch #10660

Open

akihironitta wants to merge 2 commits into master from aki/todensebatch

Conversation

@akihironitta
Member

Summary

This PR addresses two issues in to_dense_batch:

  1. Currently, to_dense_batch triggers a device-to-host (D2H) sync even when batch_size and max_num_nodes are provided, because it evaluates num_nodes.max() > max_num_nodes on the device to decide whether to apply boolean masking to x and idx.
  2. In addition, the boolean masking is a data-dependent op, which causes a graph break under torch.compile. An alternative would be to enable capture_dynamic_output_shape_ops=True, but the new implementation guarantees that the output shape is static for a given batch_size and max_num_nodes at compile time, so the data-dependent op is no longer needed (see the sketch after this list).
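
For illustration, here is a minimal sketch of the idea, not the exact code added by this PR: the helper name is hypothetical, and it assumes `batch` is sorted and that no graph exceeds `max_num_nodes`. When both sizes are known up front, nodes can be written into a preallocated `(batch_size, max_num_nodes, F)` buffer via index assignment, so no device-side comparison and no data-dependent boolean mask are needed.

```python
import torch


def to_dense_batch_static(x: torch.Tensor, batch: torch.Tensor,
                          batch_size: int, max_num_nodes: int):
    # Hypothetical sketch. Assumes `batch` is sorted and that no graph has
    # more than `max_num_nodes` nodes, so the clamp below never drops data.
    num_nodes = torch.zeros(batch_size, dtype=torch.long, device=x.device)
    num_nodes.scatter_add_(0, batch, torch.ones_like(batch))

    # Position of each node within its own graph.
    cum_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(0)[:-1]])
    pos = torch.arange(batch.numel(), device=x.device) - cum_nodes[batch]
    pos = pos.clamp(max=max_num_nodes - 1)  # static shape; no boolean mask

    # Flat index into the (batch_size * max_num_nodes) output buffer.
    flat_idx = batch * max_num_nodes + pos

    out = x.new_zeros(batch_size * max_num_nodes, x.size(-1))
    out[flat_idx] = x
    out = out.view(batch_size, max_num_nodes, x.size(-1))

    mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
                       device=x.device)
    mask[flat_idx] = True
    return out, mask.view(batch_size, max_num_nodes)
```

Every shape here is derived from batch_size, max_num_nodes, and the fixed number of input nodes, so under torch.compile the function traces into a single graph with no graph break and no D2H sync.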

Benchmark
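
As a reading aid, `bs` is `batch_size` and `mn` is `max_num_nodes`. A harness roughly like the following, based on torch.utils.benchmark, produces tables in this format; the exact script, input shapes, and the `overflow` variant are not included in this description, so treat the sketch below as an assumption.

```python
import torch
import torch.utils.benchmark as benchmark
from torch_geometric.utils import to_dense_batch


def make_inputs(batch_size: int, max_num_nodes: int, device: str):
    # Random graphs with at most `max_num_nodes` nodes each (assumed setup).
    num_nodes = torch.randint(1, max_num_nodes + 1, (batch_size,))
    batch = torch.repeat_interleave(torch.arange(batch_size), num_nodes).to(device)
    x = torch.randn(batch.numel(), 64, device=device)
    return x, batch


results = []
for device in (['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']):
    x, batch = make_inputs(1024, 64, device)
    for sub_label, fn in [('eager', to_dense_batch),
                          ('torch.compile', torch.compile(to_dense_batch))]:
        fn(x, batch, batch_size=1024, max_num_nodes=64)  # warm-up / compile
        results.append(benchmark.Timer(
            stmt='fn(x, batch, batch_size=1024, max_num_nodes=64)',
            globals={'fn': fn, 'x': x, 'batch': batch},
            label='bs=1024 mn=64', sub_label=sub_label, description=device,
            num_threads=8,
        ).blocked_autorange())

benchmark.Compare(results).print()
```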

Before (3a4b881)

[---------- bs=1024 mn=64 ----------]
                     |  cpu   |  cuda
8 threads: --------------------------
      eager          |  44.2  |  4.3
      overflow       |   9.2  |  2.9
      torch.compile  |  41.9  |  3.7

Times are in milliseconds (ms).

[---------- bs=1024 mn=512 ----------]
                     |   cpu   |  cuda
8 threads: ---------------------------
      eager          |  465.1  |  24.4
      overflow       |  252.9  |  14.7
      torch.compile  |  455.5  |  23.8

Times are in milliseconds (ms).

[---------- bs=1024 mn=2048 ----------]
                     |   cpu    |  cuda
8 threads: ----------------------------
      eager          |  1670.7  |  95.9
      overflow       |   941.0  |  56.2
      torch.compile  |  1651.8  |  93.2

Times are in milliseconds (ms).

After (this PR)

[---------- bs=1024 mn=64 ----------]
                     |  cpu   |  cuda
8 threads: --------------------------
      eager          |  34.7  |  2.8
      overflow       |  16.3  |  2.5
      torch.compile  |  33.3  |  1.4

Times are in milliseconds (ms).

[---------- bs=1024 mn=512 ----------]
                     |   cpu   |  cuda
8 threads: ---------------------------
      eager          |  290.3  |  16.8
      overflow       |  214.0  |  12.4
      torch.compile  |  294.4  |  15.5

Times are in milliseconds (ms).

[---------- bs=1024 mn=2048 ----------]
                     |   cpu    |  cuda
8 threads: ----------------------------
      eager          |  1096.5  |  67.1
      overflow       |   827.8  |  49.6
      torch.compile  |  1062.9  |  63.0

Times are in milliseconds (ms).

akihironitta marked this pull request as ready for review April 4, 2026 21:34
