Skip to content

[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829

Open
sudhakarsingh27 wants to merge 46 commits into
NVIDIA:mainfrom
sudhakarsingh27:cp_thd_swa_with_ag
Open

[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism#2829
sudhakarsingh27 wants to merge 46 commits into
NVIDIA:mainfrom
sudhakarsingh27:cp_thd_swa_with_ag

Conversation

@sudhakarsingh27
Copy link
Copy Markdown
Member

@sudhakarsingh27 sudhakarsingh27 commented Apr 3, 2026

Description

Add THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather. Previously, AllGather-based CP only supported fixed-length formats (bshd/sbhd). THD format packs variable-length sequences into a single [t, h, d] tensor tracked by cu_seqlens, which is needed for workloads with heterogeneous sequence lengths.

The key challenge is that AllGather CP splits Q across 2 steps (one per local chunk), but THD tensors can't be naively sliced like fixed-length formats. This PR uses an offset-based approach: the full Q tensor is passed to the cuDNN kernel each step, with per-step cu_seqlens_q_padded values directing the kernel to read the correct chunk. This avoids tensor slicing entirely and leverages cuDNN's back-padding convention (valid tokens at the beginning of each padded allocation).

Fixes # (issue)

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

Please list the changes introduced in this PR:

  • Offset-based Q chunking: Per-step cu_seqlens_q_padded selects which chunk the kernel reads from the full Q tensor, instead of slicing Q per step
  • Per-step KV cu_seqlens: Computes visible KV token counts per step for causal masking (chunks 0..chunk_id) and non-causal (all tokens)
  • THD reorder reuse: Reuses the existing reorder_seq_chunks_*_thd helpers (originally for A2A) to reorder all-gathered KV into contiguous per-sequence order
  • max_logit masking fix: Handles non-zero-starting cu_seqlens_q_padded in the valid-token mask (step 1's padded offsets don't start at 0)
  • Test gates: Enables THD+all_gather for FusedAttention tests; skips FlashAttention (no THD padding support)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

… cu_seqlens

- Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing
- Use padded cu_seqlens_kv for K/V reordering (ensures divisibility)
- Add cu_seqlens_kv and cu_seqlens_kv_padded to AllGather function signature
- Compute per-step Q and KV cu_seqlens correctly from actual seqlens
- Support non-causal attention (all KV visible)
- Zero-initialize out/dq for THD to avoid garbage in padding regions
- Save per-step cu_seqlens in ctx for backward (avoid recomputation)

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Remove skip gates that blocked THD format with all_gather CP comm type.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded

The interleaved valid mask computation assumed cu_seqlens_q_padded starts
at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at
a non-zero offset, causing a size mismatch. Use absolute positions from
cu_seqlens_q_padded to build the valid mask instead.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
if qkv_format == "thd":
# [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d]
chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
k_ag = reorder_seq_chunks_after_a2a_before_attn_thd(
Copy link
Copy Markdown
Member Author

@sudhakarsingh27 sudhakarsingh27 Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reorder_seq_chunks_after_a2a_before_attn_thd and the other releated method are not "a2a" specific now, rename them to something like dualchunk_to_contiguous_order_thd and the other one contiguous_to_dualchunk_order_thd

Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
@sudhakarsingh27 sudhakarsingh27 changed the title Cp thd swa with ag [PyTorch][CP] Add THD format support for AllGather-based Context Parallelism Apr 13, 2026
@sudhakarsingh27 sudhakarsingh27 marked this pull request as ready for review April 13, 2026 21:53
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 13, 2026

Greptile Summary

This PR adds THD (variable-length sequence) format support to AttnFuncWithCPAndKVAllGather, using an offset-based approach where the full Q tensor is passed to each cuDNN kernel step and cu_seqlens_q_padded directs the kernel to the correct per-step chunk — avoiding tensor slicing for non-uniform sequence lengths. The change also lifts the FusedAttention block for thd + all_gather, adds permutation-cache helpers for THD KV reordering, and extends get_attention_backend to allow FA3 with pad_between_seqs.

Confidence Score: 4/5

Safe to merge for non-THD workloads; THD+AllGather has unresolved P1 concerns from prior review rounds (dK/dV zeroing assumption, cache key missing CUDA device) that should be confirmed before enabling in production.

Several P1 findings from prior review rounds appear unresolved in the current diff (dK/dV kernel-zeroing assumption undocumented, _thd_reorder_perm_cache key missing device). The cp_stream.wait_stream fix was confirmed applied. New findings in this pass are P2 only (redundant out_f16 allocation, Python loop in max_logit masking). P1 ceiling applies.

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py — dK/dV backward zeroing assumption and permutation cache key require follow-up before production use of THD+AllGather.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Core THD+AllGather implementation: offset-based Q chunking, per-step cu_seqlens, THD KV reordering, and backward dQ/dK/dV accumulation. Several unresolved issues from previous review rounds remain open (dK/dV zeroing assumption, cache key missing device).
transformer_engine/pytorch/cpp_extensions/fused_attn.py max_logit masking refactored from vectorized repeat_interleave to a per-batch Python loop to handle non-zero cu_seqlens_q_padded start positions; logic is correct but slower for large batch sizes.
transformer_engine/pytorch/attention/dot_product_attention/utils.py Backend selection: FA2 page-size check tightened to modulo-256 divisibility, FA3 deterministic guard extended to head_dim_v, and THD+all_gather no longer disables FusedAttention. All changes look correct.
transformer_engine/pytorch/attention/dot_product_attention/backends.py FlashAttention.forward gains cu_seqlens_q/kv_padded and pad_between_seqs parameters to support FA3's seqused_q/k path for pad-between-seqs THD. Changes look correct and consistent.
tests/pytorch/attention/run_attention_with_cp.py Test harness extended with fa_pad_between_seqs flag; adds padding-zeroing logic for FA3 tile spillover and verifies CP backward tensors have clean padding. Logic is complex but appears intentional.
tests/pytorch/attention/test_attention_with_cp.py THD+all_gather skip gates updated: FlashAttention correctly skipped for THD+all_gather, FusedAttention now exercised. Deterministic OOM guard added for sm90. Changes look correct.

Sequence Diagram

sequenceDiagram
    participant Fwd as AllGather Forward
    participant AG as gather_along_first_dim
    participant Reorder as reorder_thd_to_contiguous
    participant Stream0 as current_stream (step 0)
    participant Stream1 as cp_stream (step 1)
    participant cuDNN as fused_attn_fwd
    participant Out as out [t,h,d] (zeros)

    Fwd->>AG: AllGather K,V → k_ag [cp*t, h, d]
    AG-->>Fwd: k_ag, v_ag (rank-concatenated)
    Fwd->>Reorder: reorder k_ag,v_ag using cu_seqlens_kv_padded (P_inv)
    Reorder-->>Fwd: k_ag, v_ag (contiguous sequence order)
    Fwd->>Stream1: cp_stream.wait_stream(current_stream)
    Note over Fwd: Pre-compute per-step cu_seqlens_q / cu_seqlens_q_padded / cu_seqlens_kv
    par Step 0 on current_stream
        Stream0->>cuDNN: fwd(q_full, k_ag, v_ag, cu_seqlens_q_padded_step0)
        cuDNN-->>Stream0: out_per_step[0] [t,h,d]
    and Step 1 on cp_stream
        Stream1->>cuDNN: fwd(q_full, k_ag, v_ag, cu_seqlens_q_padded_step1)
        cuDNN-->>Stream1: out_per_step[1] [t,h,d]
    end
    Stream0->>Out: copy valid ranges (step 0 offsets) → out[s0:s0+sz0]
    Stream1->>Out: copy valid ranges (step 1 offsets) → out[s1:s1+sz1]
    Fwd->>Fwd: current_stream.wait_stream(cp_stream)
    Fwd->>Fwd: all_reduce max_logit across CP ranks
Loading

Reviews (7): Last reviewed commit: "Merge branch 'main' of github.com:NVIDIA..." | Re-trigger Greptile

Comment on lines +3436 to +3440
# dK/dV: add full tensor (kernel zeros non-valid positions)
if i > 1:
flash_attn_streams[i - 1].wait_event(dkv_update_done)
dk.add_(dk_per_step[i - 1])
dv.add_(dv_per_step[i - 1])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 THD backward dK/dV relies on unverified cuDNN zeroing behavior

The comment says "kernel zeros non-valid positions", but this assumption is not documented in the cuDNN/TE spec. The A2A backward for THD uses tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, ...) specifically to handle the valid/padding boundary — a plain add_() was not considered sufficient there. If fused_attn_bwd leaves positions beyond cu_seqlens_kv_per_step[i] uninitialised in its output, both steps contribute garbage at non-overlapping KV ranges, which propagates through reduce_scatter_along_first_dim into the final dK/dV.

Before merging, either confirm (and document) that NVTE_F16_arbitrary_seqlen zeros non-valid dK/dV entries, or add explicit zeroing/use tex.thd_grad_correction if applicable to the contiguous-KV layout.

Comment on lines +3007 to +3011
# [AG+THD] Is this needed?
visible_actual = [
torch.minimum(actual_seqlens_kv, visible_padded_split)
for visible_padded_split in visible_padded
]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Unresolved development comment left in production code

# [AG+THD] Is this needed? reads like an open question from a debug session. The torch.minimum clamp is required: for sequences whose length is not a multiple of 2 * cp_size, padded_chunk_sizes_kv * (chunk_id + 1) can exceed actual_seqlens_kv[b], causing cu_seqlens_kv passed to the kernel to count padding as valid tokens. The comment should be resolved or removed.

if ctx.qkv_format == "thd":
cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded
thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step
cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Dead variable in backward pass

cu_seqlens_q_padded_rank is computed here but never read in the backward. The padded offsets are loaded from ctx.thd_cu_seqlens_q_padded_per_step a few lines later. This line can be removed.

sudhakarsingh27 and others added 2 commits April 16, 2026 11:28
…ific helpers

The AllGather THD path was not extending KV visibility beyond the causal
boundary when window_size had a right component > 0, meaning tokens right
of the diagonal were invisible to the kernel. Fix by adding window_size[1]
to visible_padded (clamped at actual seqlen) and max_seqlen_kv_.

Also rename reorder helpers to backend-neutral names since AllGather now
uses them too, and add a clarifying comment for non-causal KV cu_seqlens.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
pytest.skip(
"FlashAttention does not support THD padding; use FusedAttention for"
" THD+all_gather CP."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe swap the words a little bit so it doesn't sounds like FlashAttention doesn't support THD, but just our CP implementation with it doesn't? (Also, THD implies padding in our terminology?)

s = step_padded[b].item()
sz = (step_valid[b + 1] - step_valid[b]).item()
if sz > 0:
out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This "for" loop might be too costly. Could this logic be written in another way in Python, or simply in C++/CUDA? Do we have some thd kernels that already do this?

s = step_padded[b].item()
sz = (step_valid[b + 1] - step_valid[b]).item()
if sz > 0:
dq[s : s + sz].copy_(dq_per_step[i - 1][s : s + sz])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, regarding the "for" loop.

pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A general comment - please run the CP file with "test_essential=False" offline because the essential tests may not cover everything.

Comment on lines +3160 to +3170
elif qkv_format == "thd":
# Copy valid token ranges from this step's output.
# Each step writes at different positions (no overlap, no correction needed).
step_padded = thd_cu_seqlens_q_padded_per_step[i - 1]
step_valid = thd_cu_seqlens_q_per_step[i - 1]
batch_size = step_valid.shape[0] - 1
for b in range(batch_size):
s = step_padded[b].item()
sz = (step_valid[b + 1] - step_valid[b]).item()
if sz > 0:
out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 THD forward output copy runs on wrong stream for step 1

For step 1 (i == 2), the with torch.cuda.stream(flash_attn_streams[i - 1]) block streams the copy onto cp_stream. But out_per_step[1] is produced on cp_stream by the step-1 attention kernel, and the copy also runs on cp_stream, so there is no race between the kernel and the copy.

However, the if return_max_logit: block at line 3172–3173 runs outside the with block (i.e. on the default stream) and reads max_logit_per_step[i - 1] — which was produced on cp_stream for step 1. There is no current_stream.wait_stream(cp_stream) before this point (that sync only happens at line 3175, after the loop). As a result the torch.maximum kernel launched at line 3173 on the default stream can race against the step-1 attention kernel still running on cp_stream.

This is a pre-existing pattern, but it is newly exercised by THD+AllGather since this is the first path that has both return_max_logit=True and dual-stream THD step execution. Moving the max_logit merge inside the with torch.cuda.stream(flash_attn_streams[i - 1]): block would eliminate the race.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
sudhakarsingh27 and others added 21 commits April 24, 2026 15:35
Add test parametrization for pad_between_seqs in flash attention tests.
Update run_attention_with_cp.py to support the new parameter and fix
batch boundary alignment in the non-CP FA3 path. Run tests in parallel
when multiple GPUs are available.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add deterministic CP test runs to L3 FA versions test. Support TE_PATH
positional arg and fix GPU threshold for parallel test execution.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…raint

The previous check disabled FA3 for deterministic mode whenever
head_dim_qk > 128, which was overly conservative — FA3 forward supports
deterministic execution at any head dim. The actual constraint from
flash_api.cpp is that the backward pass does not support deterministic
mode when max(head_size, head_size_v) >= 256.

Narrow the gate to only disable FA3 during training (backward) and
raise the threshold to >= 256, checking both head_dim_qk and head_dim_v
to handle MLA configs with asymmetric head dimensions.

Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The pad_between_seqs gate in get_attention_backend only disabled
FlashAttention 2, letting FA4 leak through to the test-time
fused-vs-flash comparison. On B200 runners that install flash-attn-4,
this caused test_dpa_qkv_layout_thd to compare FusedAttention against
an FA4 output whose padded positions contain garbage, producing 48
numerics failures in L3_pytorch_FA_versions_test--B200_1GPU.

The log message already claimed FA4 would be disabled — this change
makes the code match the message: set use_flash_attention_4 = False
alongside use_flash_attention_2 when pad_between_seqs is True. FA3
continues to support pad_between_seqs via seqused_q/seqused_k.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass`
adds cutlass/base_dsl/ to sys.path. That directory contains a utils/
package that shadows tests/pytorch/utils.py, breaking collection of
test_attention_with_cp.py with:
  ImportError: cannot import name 'ModelConfig' from 'utils'

Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py
is always resolved first, regardless of what FA4 dependencies install.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…s its a known cudnn issue

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…ransformerEngine into flash_attn_pad_bw_seqs
…utation cache

The reorder_seq_chunks_{before,after}_a2a_*_thd functions used a Python
for-loop of torch.arange calls (2*cp_size*batch iterations) to build
index tensors, which dominated wall-clock time at high batch counts.

Replace with the existing thd_get_partitioned_indices CUDA kernel (one
call per CP rank) plus a permutation cache keyed on (cu_seqlens, cp_size).
This collapses thousands of tiny elementwise kernel launches into a
handful of kernel calls on first use, then a dict lookup thereafter.

Rename to reorder_thd_sequences_to_{rank_sharded,contiguous} since these
are used by both A2A and AllGather CP paths, not just A2A.

Measured speedups (cp=2, bf16, H100x2, 50 iters):
  cp_thd_2 (B=16): a2a 10.6x, all_gather 4.5x
  cp_thd_3 (B=8):  a2a 5.3x,  all_gather 3.3x
  bariamis_8k (B=2): a2a 3.0x, all_gather 2.2x
  bariamis_262k (B=2, S=262k): a2a 1.0x, all_gather 12.0x

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
The THD AllGather restructuring left cu_seqlens_kv_per_step[i] as None
for the SBHD/BSHD path, causing a TypeError in fused_attn_fwd which
expects a Tensor. Build it from batch_size and the KV slice length,
matching what the original prepare_outputs helper provided.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Comment on lines +3200 to +3202
v_ag = reorder_thd_sequences_to_contiguous(
v_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing cp_stream.wait_stream in THD AllGather forward

The non-THD path correctly calls cp_stream.wait_stream(torch.cuda.current_stream()) at line 3213 after preparing k_ag/v_ag, ensuring the cp_stream (used for step 1) sees the AllGathered and reordered tensors. The THD branch skips this synchronization, so when step 1 (i=1) launches on cp_stream, it may read partially-overwritten or stale k_ag/v_ag data that reorder_thd_sequences_to_contiguous hasn't finished writing on the current stream.

            v_ag = reorder_thd_sequences_to_contiguous(
                v_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size
            )
+           cp_stream.wait_stream(torch.cuda.current_stream())
        else:

The backward already handles this correctly — ctx.cp_stream.wait_stream(torch.cuda.current_stream()) at line 3782 is unconditional (outside the if ctx.qkv_format == "thd": block). The forward needs the same treatment.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorporated. Check again

The non-THD branch of AttnFuncWithCPAndKVAllGather.forward calls
cp_stream.wait_stream(current_stream) after preparing k_ag/v_ag, but the
THD branch (added by this PR) was missing that synchronization. Step 1
of the per-step attention loop launches on cp_stream and reads
k_ag/v_ag, so without the wait it can race against the AllGather +
reorder_thd_sequences_to_contiguous writes still in flight on the
current stream.

The race is masked at low cp_size because the reorder finishes before
cp_stream actually starts reading. At cp=8 with the largest THD config
(bucket128k: B=3, S_max=131072) the reorder takes long enough to
outlast cp_stream's launch, and the resulting partially-written
k_ag/v_ag feeds garbage values into cuDNN, surfacing async as
cudaErrorIllegalInstruction.

Move the wait_stream call out of the non-THD else-branch so both paths
get the synchronization unconditionally. Verified that bucket128k
all_gather @ cp=8 now runs cleanly (360.72 ms/iter).

Identified by Greptile review (P1, discussion_r3170978192).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…ransformerEngine into flash_attn_pad_bw_seqs
Comment on lines +286 to +299
global _thd_reorder_perm_cache
key = (tuple(cu_seqlens.tolist()), cp_size)
if key not in _thd_reorder_perm_cache:
total_tokens = int(cu_seqlens[-1].item())
P = torch.cat(
[
tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, cp_size, rank)
for rank in range(cp_size)
]
)
P_inv = torch.empty_like(P)
P_inv[P.long()] = torch.arange(total_tokens, dtype=P.dtype, device=P.device)
_thd_reorder_perm_cache[key] = (P, P_inv)
return _thd_reorder_perm_cache[key]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Cache key missing CUDA device

Every other shape/rank cache in this file (e.g. _seq_chunk_ids_cache_for_reordering_before_attn, _seq_chunk_ids_cache_for_reordering_after_attn, _softmax_offset_chunk_ids_cache) includes the device in its key as (cp_size, device). _thd_reorder_perm_cache only uses (tuple(cu_seqlens.tolist()), cp_size).

In any context where two CUDA devices share the same cu_seqlens shape — multi-GPU unit tests with a single process, or pipeline-parallel+data-parallel setups — the first call populates the cache with tensors on device A; a subsequent call from device B hits the same key and gets those tensors back, then passes them to x.index_select(seq_dim, P_inv) where x lives on device B, producing a cross-device RuntimeError at runtime.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants