[FEA] Beam search #379

Open
z52527 wants to merge 75 commits into NVIDIA:main from z52527:fea-mask-beam-search

@z52527
Collaborator

@z52527 z52527 commented May 9, 2026

Description

z52527 and others added 30 commits March 11, 2026 09:17
Add generate_beam_decode() path using context/beam KV separation:
- BeamSearch: track parent_indices; add build_beam_topk_indices()
- JaggedGPTLayer/Block: add prefill() and decode_beam() methods
- SIDGRModel: add generate_beam_decode() with prefill + decode loop
- Tests: 11 integration tests for topk indices, prefill, decode pipeline
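
To illustrate why tracking parent indices matters, here is a minimal sketch that walks a final beam back to the beam slot its ancestor occupied at every step. The helper name `trace_ancestors` and the list-of-lists layout are assumptions made for this example; they are not the signature of `build_beam_topk_indices`.

```python
from typing import List


def trace_ancestors(parent_indices: List[List[int]], final_beam: int) -> List[int]:
    """Return the beam slot of `final_beam`'s ancestor at each step.

    Assumption: parent_indices[s][b] is the step s-1 beam that beam b of
    step s was expanded from. parent_indices[0] is an unused placeholder.
    """
    num_steps = len(parent_indices)
    chain = [0] * num_steps
    beam = final_beam
    # Walk backwards from the last step to step 0.
    for s in range(num_steps - 1, -1, -1):
        chain[s] = beam
        if s > 0:
            beam = parent_indices[s][beam]
    return chain


# Three steps with beam width 3 (hypothetical values).
parents = [
    [0, 0, 0],  # step 0: placeholder, no parents
    [2, 0, 0],  # step 1: beam 0 came from step-0 beam 2
    [1, 1, 0],  # step 2: beams 0 and 1 both came from step-1 beam 1
]
print(trace_ancestors(parents, final_beam=0))  # [0, 1, 0]
```

Per-step ancestor slots like these are the kind of information a decode kernel would need to gather the right per-beam KV entries.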

Use backend="3kernel" to work around fused-path JIT recompilation hang.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Bug fixes:
- Variable-length history: kernel attended padding K positions, contaminating
  K1 attention output. Add seqused_k support to gr-decode_atten interface.py
  (3-kernel path) and thread it through generate_beam_decode.
- Degenerate beam_width: use generated_sids.shape[1] (actual top-k) instead
  of beam_widths[0] when codebook is smaller than configured beam_width.
- dtype side effect: stop converting context_kv inside decode_beam every
  layer/step; assert caller passes fp16/bf16 cached tensors.

Refactors:
- Add SIDGRDecoder.get_jagged_flash_attn_block() to replace fragile
  self.decoder.decoder.block triple-attribute access.
- Make backend ("3kernel" / "dsl") configurable on generate_beam_decode and
  decode_beam; default "3kernel" with FIXME for fused-path bug.

Tests:
- E2E test for generate_beam_decode through full SIDGRModel (skipped when
  commons stack unavailable, e.g. outside Docker).
- Reference oracle: 12 cases comparing CuTe kernel output to PyTorch
  reference within bf16 tolerance.
- Variable-length history reproducer that exercises seqused_k masking.

54 tests pass (2 e2e skipped due to commons ABI mismatch in this env).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Adds a standalone benchmark that times both generation paths back-to-back
on the same model and batch, with warmup and median/p95 stats. Single-
config and sweep modes (sweep prints a markdown table over hist_len ×
beam_width).
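
A rough sketch of a warmup plus median/p95 timing loop is shown below. It is not the benchmark script from this PR: `bench` and `percentile` are hypothetical helpers, and wall-clock timing of GPU work would additionally need a `torch.cuda.synchronize()` before each timestamp, which is left out so the sketch stays dependency-free.

```python
import math
import statistics
import time
from typing import Callable, Dict, List


def percentile(samples: List[float], pct: float) -> float:
    """Nearest-rank percentile; pct is in (0, 100]."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(len(ordered) * pct / 100))
    return ordered[rank - 1]


def bench(fn: Callable[[], object], warmup: int = 3, iters: int = 20) -> Dict[str, float]:
    """Run `fn` a few times untimed, then report median and p95 in milliseconds."""
    for _ in range(warmup):
        fn()  # warmup runs absorb one-off costs such as JIT compilation
    times_ms = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - t0) * 1e3)
    return {
        "median_ms": statistics.median(times_ms),
        "p95_ms": percentile(times_ms, 95),
    }


stats = bench(lambda: sum(range(10_000)))
print(stats)
```

Timing both generation paths with the same harness, on the same model and batch, keeps the comparison like-for-like.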

Run inside the Docker container with quack-kernels installed and
gr-decode_atten on PYTHONPATH.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
When running via torchrun (not pytest), the benchmark script fails to
import tests.test_utils because pytest's auto-discovery isn't active.
Add the sid_gr root dir to sys.path explicitly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Captures speedups of generate_beam_decode (Jerry kernel, 3-kernel backend)
over the original generate() path on H100. Median speedup ≈ 1.27× across
11 configurations, range 1.14×–1.38×. Bigger beam widths and shorter
sequences benefit most; deeper models see smaller relative gains because
the FFN dominates total time.

hist_len=256 case didn't finish — likely a K1 cache-key bug in
gr-decode_atten/interface.py when num_splits>1 path is hit; documented
in RESULTS.md and not blocking the integration.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
- Add test_generate_vs_generate_beam_decode_equivalence: asserts the two
  paths produce close log_probs and overlapping top-K SID sets within
  bf16 tolerance (the topk decision boundary in beam search prevents
  bit-exact match, but mathematical equivalence is verified).
- generate_beam_decode now accepts an optional phase_times dict that the
  caller can populate with prefill_ms / decode_loop_ms via cuda events.
- Benchmark gains:
    - --dtype bf16/fp16 on the model dtype
    - --sweep_dtype to sweep over both
    - measure_phase_breakdown() to time prefill vs decode_loop
- Sweep markdown table now includes dtype + phase columns.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Captures speedups across hist_len {32,64,128,256} × beam_w {4,10,20}
× dtype {bf16,fp16} with prefill/decode phase breakdown. Median speedup
1.27x, range 1.22x-1.34x. fp16 and bf16 perform near-identically.

Includes:
- Phase breakdown showing prefill (~3ms) vs decode_loop (~5.6ms)
- Discussion of where the speedup comes from
- Equivalence test summary
- Known kernel issues (split-KV+seqused_k deadlock, fused-path cache bug)
  with documented workarounds

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
The previous refactor (commit 528cf77) moved mask construction out of
generate() into the caller's responsibility but the actual call sites
left attention_mask=None, falling back to decoder_step's plain causal
default. This silently broke beam isolation in the JaggedTransformerBlock
path: at hierarchy step i>=1, beam_b's tokens could attend to beam_a's
tokens (a<b) via the plain jagged-causal mask, contaminating logits.

Restore the original generate() behaviour:
- generate() builds padded_target_aware_causal_mask per step (same as
  pre-528cf77).
- decoder_step on the jagged FA path now auto-converts a dense mask to a
  flattened arbitrary_func via dense_mask_to_jagged_arbitrary_func
  (re-introduced in attention_mask.py from the earlier 74e07c8 version).
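
The difference between a plain causal mask and a beam-isolating one can be sketched with a small dense example. The function below is a simplified, hypothetical stand-in for `padded_target_aware_causal_mask` (single sample, no padding), and it assumes the sequence is laid out as history tokens followed by each beam's candidate tokens in turn.

```python
from typing import List


def beam_isolating_mask(hist_len: int, beam_width: int, cand_len: int) -> List[List[bool]]:
    """Dense mask for one sample; mask[q][k] is True when query q may attend key k.

    Assumed layout: [history tokens][beam 0 tokens][beam 1 tokens]...
    """
    total = hist_len + beam_width * cand_len
    mask = [[False] * total for _ in range(total)]
    for q in range(total):
        for k in range(q + 1):  # causal: never look ahead
            if k < hist_len:
                mask[q][k] = True  # history is visible to every later token
            else:
                # Both q and k are candidate tokens: allow only the same beam.
                q_beam = (q - hist_len) // cand_len
                k_beam = (k - hist_len) // cand_len
                mask[q][k] = q_beam == k_beam
    return mask


mask = beam_isolating_mask(hist_len=2, beam_width=2, cand_len=2)
# Row 4 is the first token of beam 1: it sees the history and itself,
# but not beam 0's tokens at positions 2 and 3.
print(mask[4])  # [True, True, False, False, True, False]
```

Under a plain causal mask the same row would be `True` at positions 2 and 3 as well, which is the cross-beam leak this commit describes.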

Tests:
- TestBeamIsolationMask: unit-test padded_target_aware_causal_mask
  geometrically (no cross-beam attention, full history visibility).
- test_generate_beam_perturbation_invariance: smoke check that generate
  is deterministic, prerequisite for the perturbation argument.
- test_generate_vs_generate_beam_decode_equivalence: thresholds tightened
  now that the baseline is correct (top-1 SIDs must match exactly,
  log_prob max diff < 0.15, top-K set overlap >= 70%).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
The previous fix used dense_mask_to_jagged_arbitrary_func which has a
nested Python loop (B × N × intervals); for typical seqlens this turned
generate() into a 700ms-per-call hot spot.

Add build_jagged_target_aware_arbitrary_func that constructs the same
flattened arbitrary_func directly from offsets / history_seqlens /
beam_width / candidate_length using vectorised tensor ops (single
torch.where, no per-position Python loop).

generate() now uses the fast builder on the jagged FA path; the dense
path still uses padded_target_aware_causal_mask. Benchmark recovers to
1.30-1.38x speedup with the correct beam-isolating baseline.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
…ices

The previous implementation used s * W to index into beam_kv, which is
only correct when beam_widths is uniform across hierarchy steps. Switch
to cumulative offsets (sum(beam_widths[:s])) so the math is also right
for non-uniform per-step widths.
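
The indexing difference is easy to check by hand. The sketch below uses a hypothetical helper `step_offsets`, not the code in `build_beam_topk_indices`, to compare cumulative offsets with the old `s * W` formula.

```python
from typing import List


def step_offsets(beam_widths: List[int]) -> List[int]:
    """Start of each step's slice in a flattened per-step buffer: sum(beam_widths[:s])."""
    return [sum(beam_widths[:s]) for s in range(len(beam_widths))]


widths = [3, 5, 7]
print(step_offsets(widths))                          # [0, 3, 8]
print([s * widths[0] for s in range(len(widths))])   # [0, 3, 6], wrong at step 2
```

With uniform widths such as `[4, 4, 4]` both formulas give `[0, 4, 8]`, which is why the old indexing only went wrong for non-uniform widths.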

This is currently latent — Jerry's kernel asserts uniform widths via
k_beam.shape[1] == decode_nums * beam_width, so non-uniform widths
can't be passed end-to-end. The BeamSearch side is now ready if/when
the kernel grows non-uniform support.

Tests:
- test_list_equal_widths_matches_int_path: BeamSearch(W=int) and
  BeamSearch([W,W,W]) must produce bit-identical topk_indices.
- test_nonuniform_widths_offsets: explicit check on cumsum offsets and
  ancestor flat-index computation for [3, 5, 7] beam widths.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Address remaining audit items (skip B — kernel patch portability is
deferred):

A. BeamSearch rejects non-uniform beam_widths at __init__ since Jerry's
   kernel asserts k_beam.shape[1] == decode_nums * beam_width. The
   indexing math in build_beam_topk_indices is already general (cumsum
   offsets) so this is forward-compatible if kernel adds support.

C. Renamed test_generate_vs_generate_beam_decode_equivalence to
   test_generate_vs_generate_beam_decode_regression_guard to reflect
   what the test actually verifies. Reference oracle remains the strong
   correctness check.

D. Added @torch.no_grad to generate_beam_decode (generate already had it).

E. RESULTS.md gets an explicit "correctness preconditions" section
   noting (a) numbers are vs the corrected baseline, (b) local kernel
   patches required, (c) regression guard != equivalence proof. Re-ran
   the sweep — speedup is now 1.31x-1.43x (was 1.22x-1.34x against the
   buggy baseline) since the corrected baseline does more work.

F. Fixed two "to perceive the mppy check" -> "to appease the mypy check"
   typos.

Test count: 45 -> 46 (added test_nonuniform_widths_rejected_at_init).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Address Cursor's follow-up findings:

1. top_k > codebook_size shrink path: BeamSearch.propagate clamps actual
   topk to min(beam_widths[s], topk_prev * codebook_size). When this
   triggers, k_beam.shape[1] no longer equals decode_nums * beam_width
   and the kernel assertion fails. generate_beam_decode now asserts
   beam_width <= min(codebook_sizes) before running.

2. BeamSearch global non-uniform reject was too aggressive: BeamSearch
   is a general class, only the kernel-using path needs uniformity.
   Move the assertion to generate_beam_decode (kernel-specific consumer)
   and let BeamSearch itself accept non-uniform widths.

3. backend="dsl" + seqused_k unsafe combination: decode_beam now asserts
   backend=="3kernel" when seqused_k is set, since seqused_k support
   only exists on the 3-kernel path.

4. build_jagged_target_aware_arbitrary_func oracle test: parametrised
   over 3 hist-length patterns, 3 beam-width settings (incl. 0), and 2
   candidate lengths. Compares vectorised builder output to the dense
   path (padded_target_aware_causal_mask -> dense_to_jagged) by
   expanding both to dense bool masks.

5. Defensive fix: build_beam_topk_indices uses parent_indices[s].shape[1]
   (actual topk after propagate) for offsets and W_d, not the configured
   beam_widths. Stays correct under top_k_for_generation > codebook_size
   even if the assert in generate_beam_decode is bypassed.

6. test_generate_beam_perturbation_invariance was a determinism check,
   not perturbation. Renamed to test_generate_is_deterministic.
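
The shrink path in item 1 can be reproduced with a few lines of arithmetic. This is a sketch under assumptions: `effective_topk` and `kernel_path_is_safe` are hypothetical names, and a single root beam is assumed before the first step.

```python
from typing import List


def effective_topk(beam_widths: List[int], codebook_sizes: List[int]) -> List[int]:
    """Beams that actually survive each step: min(beam_widths[s], topk_prev * codebook_size)."""
    actual = []
    prev = 1  # assumption: one root before the first step
    for width, codebook in zip(beam_widths, codebook_sizes):
        prev = min(width, prev * codebook)
        actual.append(prev)
    return actual


def kernel_path_is_safe(beam_width: int, codebook_sizes: List[int]) -> bool:
    """Mirrors the entry check described above: beam_width <= min(codebook_sizes)."""
    return beam_width <= min(codebook_sizes)


# A first codebook of size 4 cannot supply 10 distinct beams.
print(effective_topk([10, 10, 10], [4, 256, 256]))  # [4, 10, 10]
print(kernel_path_is_safe(10, [4, 256, 256]))       # False
```

When the first entry shrinks below the configured width, the per-step beam count is no longer uniform, which is the situation the entry assertion is meant to rule out.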

64 tests pass (was 46; +18 oracle parametrisations).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
1. RESULTS.md: BeamSearch.__init__ no longer rejects non-uniform widths;
   describe the actual gate (SIDGRModel.generate_beam_decode entry) and
   note BeamSearch itself stays general.

2. BeamSearch.__init__ now asserts len(beam_widths) == num_hierarchies
   (parallels the existing codebook_sizes check) so [2, 2] with 3 hierarchies
   fails fast at construction, not midway through propagate.

3. generate_beam_decode rejects backend != "3kernel" at entry, before
   prefill / step0 MLP. Previously decode_beam would assert later in the
   pipeline after meaningful work was done.

4. test_cumsum_offsets_via_attribute_override: rename + comment update —
   it now constructs BeamSearch with the non-uniform list directly
   instead of overriding the attribute, since BeamSearch accepts
   non-uniform widths now.

64 tests still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Document why we currently force backend="3kernel" everywhere:

- K1 (Context Attention, tensor core): Q × shared context KV.
- K2 (Beam Sparse, CUDA core FMA): Q × per-beam KV via topk gather.
- K3 (Combine): log-sum-exp merge.
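
The combine step relies on a standard identity: softmax attention over two disjoint key sets can be computed separately and merged using each part's log-sum-exp. The sketch below shows that identity with scalar values in plain Python. It is not the kernel code, and the function names are made up for the example.

```python
import math
from typing import List, Tuple


def partial_attention(scores: List[float], values: List[float]) -> Tuple[float, float]:
    """Softmax-weighted sum over one key subset, plus that subset's log-sum-exp."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / denom
    return out, m + math.log(denom)


def lse_merge(out_a: float, lse_a: float, out_b: float, lse_b: float) -> Tuple[float, float]:
    """Combine two partial results as if softmax had run over both key subsets."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    out = math.exp(lse_a - lse) * out_a + math.exp(lse_b - lse) * out_b
    return out, lse


# Hypothetical scores/values for the shared-context keys and the per-beam keys.
ctx_scores, ctx_values = [0.5, 1.5, -0.2], [1.0, 2.0, 3.0]
beam_scores, beam_values = [2.0, 0.1], [4.0, 5.0]

merged, _ = lse_merge(*partial_attention(ctx_scores, ctx_values),
                      *partial_attention(beam_scores, beam_values))
full, _ = partial_attention(ctx_scores + beam_scores, ctx_values + beam_values)
print(math.isclose(merged, full, rel_tol=1e-9))
```

Because the merge is exact up to floating-point rounding, the context and per-beam parts can run as separate kernels without changing the attention result.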

Per Jerry: fused is faster on SM80/90/120 at small decode_nums; SM100
prefers 3-kernel (kernel auto-routes there regardless). Once Jerry's
upstream cache-key bug is fixed (the JIT-deadlock when decode_nums
varies), the 3-kernel forcing in generate_beam_decode should be
relaxed so the arch-aware default picks the optimal path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Previous version exposed the kernel's internal K1/K2/K3 names which are
implementation details. The user-visible API offers two backends —
pipelined ("3kernel") and fused ("dsl") — selected via the `backend`
argument. Rewrite the docstrings to describe what each does behaviour-
wise, when each is preferred (per-arch), and why we currently pin to
"3kernel" (workaround for upstream JIT cache bug). No code changes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Previous comment claimed "upstream cache key bug" was the reason —
verified empirically that's not the whole story. With our local
cache-key patch in place (decode_nums included), backend="dsl" still
hangs on SM90 H100 PCIe even at the smallest workload (B=1, W=2,
decode_nums=1): the kernel launches, GPU sits at 100% utilisation, and
the call never returns. Root cause TBD (likely a CuTe-DSL fused-kernel
bug or an env/version mismatch we haven't pinned down).

3-kernel pipeline path runs cleanly. Benchmark sanity check after the
doc rewording: 1.39x speedup at hist=128, beam=10, 2 layers (matches
previous numbers).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Narrowed the dsl backend hang via a beam-width sweep on H100 PCIe:
- W in {8, 16, 24, 32, 40, 48, 64, 128}: returns in <3 s
- W in {1, 2, 4, 9, 10, 11, 12, 14, 15, 17, 20}: hangs forever (GPU 100%)

Pattern is exact: W % 8 != 0 hangs every time. Jerry's own test_fwd.py
passes for dsl because all his cases use W in {128, 256, 512, 1024}.
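
The reported pattern can be checked against the sweep data directly. The guard below is hypothetical: this PR keeps `"3kernel"` pinned rather than routing by width, and the rule is only what the H100 PCIe sweep above suggests.

```python
# Beam widths from the sweep above.
HANGS = [1, 2, 4, 9, 10, 11, 12, 14, 15, 17, 20]
RETURNS = [8, 16, 24, 32, 40, 48, 64, 128]


def dsl_width_is_safe(beam_width: int) -> bool:
    """Observed rule from the sweep: only multiples of 8 returned."""
    return beam_width % 8 == 0


def pick_backend(beam_width: int) -> str:
    """Hypothetical router: fall back to the 3-kernel path for risky widths."""
    return "dsl" if dsl_width_is_safe(beam_width) else "3kernel"


print(all(not dsl_width_is_safe(w) for w in HANGS))   # True
print(all(dsl_width_is_safe(w) for w in RETURNS))     # True
```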

Our local cache-key patch is irrelevant — bug reproduces on a clean
checkout. Likely a tile-size assumption in the SM90 fused kernel that
isn't validated for partial tiles.

Added docs/dsl_backend_hang_bug_report.md with the minimal reproducer
to share with Jerry. Updated the docstrings in generate_beam_decode and
decode_beam to point at the actual finding instead of vague "fused
hangs".

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
@z52527
Collaborator Author

z52527 commented May 14, 2026

/build devel

z52527 added 3 commits May 14, 2026 06:38
…1 bump

The merge of origin/main picked up the Megatron-LM core_v0.12.1 →
core_v0.13.1 bump and resolved the textual conflict in favor of main's
version. That dropped the `pip uninstall -y nvidia-cutlass-dsl ...`
line from a031b92, which is needed because the nvcr base image
preinstalls cutlass-dsl=4.3.0 via a .pth injection that `pip install
==4.4.1` doesn't fully replace.

Re-apply the uninstall on top of the merge, now placed before the
core_v0.13.1 git clone in the same RUN.

Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
@z52527
Collaborator Author

z52527 commented May 14, 2026

/build devel

@JacoCheung
Collaborator

JacoCheung commented May 14, 2026

Pipeline #51257019 -- canceling

| Job | Status |
| --- | --- |
| build_devel | ❔ canceling |
| build_inference_devel | ❌ failed |
| build_tritonserver_devel | ❌ failed |
| pre_check | 🚫 canceled |
| train_build | 🚫 canceled |
| inference_build | 🚫 canceled |
| tritonserver_build | 🚫 canceled |
| build_whl | 🚫 canceled |
| dynamicemb_test_fwd_bwd_8gpus | 🚫 canceled |
| dynamicemb_test_load_dump_8gpus | 🚫 canceled |
| unit_test_1gpu_a100 | 🚫 canceled |
| unit_test_1gpu_h100 | 🚫 canceled |
| unit_test_4gpu | 🚫 canceled |
| unit_test_tp_4gpu | 🚫 canceled |
| L20_unit_test_1gpu | 🚫 canceled |
| inference_unit_test_1gpu | 🚫 canceled |
| inference_test_1gpu | 🚫 canceled |

tile_m = 128 if D <= 128 else 64
tile_n = 128 if D <= 128 else 64

key = ("fused", cc, D, qhead_per_kv, cutlass_dtype, is_split_kv)
Contributor

P1 Fused-path compile-cache key omits decode_nums, W, and num_splits

_fused_context_beam specialises the CuTe kernel at compile time on the concrete integer arguments W, decode_nums, and num_splits. The cache key ("fused", cc, D, qhead_per_kv, cutlass_dtype, is_split_kv) omits all three. In generate_beam_decode, decode_nums increments from 1 to num_hierarchies-2 on successive decode steps. After the first call compiles for decode_nums=1, every later step gets a cache hit and reuses that stale kernel — the same root cause as the K2 bug but in the fused/DSL path used on SM80/SM90/SM120.

Suggested change
- key = ("fused", cc, D, qhead_per_kv, cutlass_dtype, is_split_kv)
+ key = ("fused", cc, D, qhead_per_kv, cutlass_dtype, is_split_kv, W, decode_nums, num_splits)
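
The failure mode described in this comment can be reproduced with an ordinary dictionary cache. The sketch below uses a made-up `fake_compile` in place of the real CuTe JIT and a reduced key, so it only illustrates the stale-hit mechanism and not the actual interface.

```python
from typing import Dict, Tuple


def fake_compile(decode_nums: int) -> str:
    """Stand-in for JIT compilation; the result is specialised on decode_nums."""
    return f"kernel<decode_nums={decode_nums}>"


def get_kernel(cache: Dict[Tuple, str], decode_nums: int, key_includes_decode_nums: bool) -> str:
    """Look up a compiled kernel, compiling on a cache miss."""
    key = ("fused", decode_nums) if key_includes_decode_nums else ("fused",)
    if key not in cache:
        cache[key] = fake_compile(decode_nums)
    return cache[key]


stale_cache: Dict[Tuple, str] = {}
fixed_cache: Dict[Tuple, str] = {}

get_kernel(stale_cache, 1, key_includes_decode_nums=False)
print(get_kernel(stale_cache, 2, key_includes_decode_nums=False))  # kernel<decode_nums=1>

get_kernel(fixed_cache, 1, key_includes_decode_nums=True)
print(get_kernel(fixed_cache, 2, key_includes_decode_nums=True))   # kernel<decode_nums=2>
```

Any argument the kernel is specialised on at compile time has to appear in the key, which is what the suggested change adds for `W`, `decode_nums`, and `num_splits`.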

z52527 and others added 6 commits May 14, 2026 07:50
…ices

The previous diagram numbered beams (beam 0..3) at each step, which
made cloning hard to read — the same SID tuple would carry different
beam indices across rounds. Restructure each step's subgraph into
per-parent sub-subgraphs; a group tagged "cloned" means that parent
fed multiple children to the next step. Step labels now show the
top-K math explicitly (e.g. "pick top 4 of 4 × 256 = 1024").

Node labels are just the SID tuples now; cloning, parentage, and
top-K math are carried by the subgraph titles instead.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Annotate each step's top-K denominator with where the multiplier comes
from: "1 (BOS context)" for Step 1, "4 (every Step 1 beam) × 256" for
Step 2, etc. Caption explicitly notes that dashed beams are still
expanded — pruning happens only at the next step's top-K, not at the
current step's forward pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
GitHub's mermaid base theme renders nodes with a #ECECFF (light
purple) fill, #9999CC border, and 1 px edges — looks washed out.
Inject a theme override at the top of the block:
  - white node fill, dark gray border
  - subgraph clusters get a subtle gray-blue background
  - edges go to 1.5 px in a darker gray (linkStyle default)
  - pruned style lightened slightly to read as "ghost", not "broken"

No structural change to the diagram itself.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
…rmaid

Both flowchart blocks inside the HTML table now get the same theme
init + linkStyle as the beam-expansion diagram: white node fill,
dark gray border, 1.5 px edges. Visually consistent across the three
diagrams in this section. No structural change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
`pip uninstall -y nvidia-cutlass-dsl` did not actually evict the
base-image 4.3.0 install: the dist-info / .pth / package tree at
/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/ survived
the uninstall (whether due to a broken RECORD file, a non-pip system
install, or pip refusing to delete files outside what it tracked).
Subsequent `pip install ==4.4.1` recorded a new dist-info but did not
replace the 4.3.0 files, so `import cutlass.cute.arch` continued to
resolve to 4.3.0 (no `ProxyKind` symbol) and flash_attn.cute import
failed at the Layer 7 verify.

Fix: combine pip uninstall with an explicit `rm -rf` of the 4.3.0
tree (.pth + dist-info + package dir) before installing 4.4.1, and
add an immediate `python -c "from cutlass.cute.arch import ProxyKind,
SharedSpace"` verify at the end of the same RUN. If the upgrade is
ever re-shadowed by a future base-image change, this will now fail
at Layer 2 (the cutlass install layer) with a precise message,
instead of much later when flash_attn.cute is first imported.

Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Functional behavior unchanged. The session's iterative debugging left
20+ lines of inline comments explaining the cutlass shadow and the
flash_attn.cute pkg layout; collapse each to 3-line blurbs. Drop
--no-build-isolation on the flash_attn/cute install — it was needed
when we used to install the root flash-attention pkg (its setup.py
imports torch); the cute subpackage's pyproject.toml needs only
setuptools, so default build isolation works.

Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Comment thread examples/sid_gr/model/attention_mask.py Outdated
z52527 and others added 2 commits May 14, 2026 23:36
nvcr.io/nvidia/pytorch:26.02-py3's pre-populated pip cache contains
an nvcr-built nvidia-cutlass-dsl-libs-base==4.4.1 wheel whose
cute/arch/__init__.py is 9 bytes shorter than PyPI's public 4.4.1
wheel and omits the top-level ProxyKind / SharedSpace re-export
that flash_attn.cute requires. Plain `pip install
'nvidia-cutlass-dsl[cu13]==4.4.1'` hits the bad cached wheel via
pip's extra-resolution code path, even with --no-cache-dir.

Switch to --no-deps + the three cutlass-dsl subpackages spelled
out explicitly — that routes pip through the simpler explicit-args
install path where the cache trap doesn't apply. Re-pin all three
subpackages on the bundled `pip install` too, otherwise other
packages' deps (quack-kernels, apache-tvm-ffi) cascade and bump
cutlass-dsl to a mismatched newer minor.

The verify-line `python -c "from cutlass.cute.arch import
ProxyKind, SharedSpace"` fail-fasts the build if the upgrade
ever stops taking effect.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Replace the per-row Python loop with a cumsum + nonzero scatter so the
function issues a single host sync (for `max_intervals`) instead of one
per row × per interval × per .item() call.

Why
---
Greptile flagged this as P1: the loop has 4 host-syncing ops in the
inner body — `row.any()`, two `.nonzero()` materialisations, and
`start_pos[iv].item()` / `end_pos[iv].item()`. For B=64, seqlen=1024,
~2 intervals/row, that's ≈500 k forced GPU→CPU syncs per call. The
function is on the jagged-FA fallback path in `SIDGRModel.decoder_step`
(when the caller passes a dense `attention_mask` instead of a
prebuilt `arbitrary_func`), so this dominates training step time on
that path.

How
---
- `starts` / `ends` boundary detection was already vectorised; keep
  that.
- Mask out positions outside each sample's `[0, seq_len)` so padded
  rows/cols don't produce spurious intervals.
- `starts.cumsum(dim=-1)` assigns each transition a 1-based interval
  index without any sync.
- `starts.nonzero()` gives all (b, q, k) coordinates in one shot; index
  into `af` via vectorised assignment. One nonzero call per side
  replaces ~N × seq_len of them.
- Same for `ends`, with the existing `+1` (exclusive) offset preserved.
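
The boundary-detection and cumsum-indexing idea can be followed on a single mask row in plain Python. This sketch is not the tensor implementation (which applies the same steps to the whole batch with `cumsum` and `nonzero`), and `row_intervals` is a hypothetical name.

```python
from itertools import accumulate
from typing import List, Tuple


def row_intervals(row: List[bool]) -> List[Tuple[int, int]]:
    """Return [start, end) for every run of True in one mask row."""
    n = len(row)
    # A run starts where the mask is True and the previous position is not.
    starts = [row[k] and (k == 0 or not row[k - 1]) for k in range(n)]
    # A run ends where the mask is True and the next position is not.
    ends = [row[k] and (k == n - 1 or not row[k + 1]) for k in range(n)]
    # A running count of starts gives each transition its 1-based interval index.
    interval_id = list(accumulate(int(s) for s in starts))
    num_intervals = interval_id[-1] if interval_id else 0
    out = [[0, 0] for _ in range(num_intervals)]
    for k in range(n):
        if starts[k]:
            out[interval_id[k] - 1][0] = k
        if ends[k]:
            out[interval_id[k] - 1][1] = k + 1  # +1 makes the end exclusive
    return [(a, b) for a, b in out]


row = [True, True, False, True, False, False, True, True, True]
print(row_intervals(row))  # [(0, 2), (3, 4), (6, 9)]
```

Because the interval index comes from a running sum rather than a per-element lookup, the tensor version can place every boundary with one scatter instead of calling `.item()` per interval.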

Verification
------------
Add `TestDenseMaskToJaggedVectorisedMatchesLoop` comparing the new
vectorised path against the existing loop-based test helper across:
jagged causal, target-grouped (4 beam_width × candidate_len cases),
all-zero mask, multi-interval per row, uneven seq_lens.

Local: 27/27 pass (was 20), pre-commit clean, no behaviour change for
the existing 20 tests.

Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
@z52527 z52527 force-pushed the fea-mask-beam-search branch from 7c69059 to ed60ab2 Compare May 15, 2026 06:37
z52527 and others added 2 commits May 15, 2026 00:36
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>

# Conflicts:
#	docker/Dockerfile
Merge bf77212 brought in main's Layer 2 with nvidia-cutlass-dsl==4.3.0
in the bundled pip install, overriding the fix from 5efdadc. Following
the merge-then-redo-fix convention (resolve in favor of main, then re-
apply the fix as a follow-up commit), this re-applies:

* Split cutlass-dsl into its own pip install with --no-deps +
  --no-cache-dir + three subpackages explicitly named at 4.4.1
  (bypasses the nvcr base-image poisoned cache wheel — see 5efdadc).
* Drop nvidia-cutlass-dsl==4.3.0 from the bundled install and re-pin
  the three subpackages there at 4.4.1 too, so quack-kernels /
  apache-tvm-ffi deps can't cascade and bump cutlass-dsl to a
  mismatched newer minor.
* Add quack-kernels / apache-tvm-ffi / torch-c-dlpack-ext back to the
  bundled install — these are runtime deps of flash_attn.cute that
  main dropped along with the cutlass-dsl version bump.
* Add the python -c "from flash_attn.cute.interface import
  flash_attn_func" verify-line to the flash-attention RUN (main
  dropped it) so namespace failures fail-fast at build time.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
@NVIDIA NVIDIA deleted a comment from JacoCheung May 15, 2026
@NVIDIA NVIDIA deleted a comment from JacoCheung May 15, 2026
@z52527
Collaborator Author

z52527 commented May 15, 2026

/build devel

1 similar comment
@z52527
Collaborator Author

z52527 commented May 15, 2026

/build devel

@JacoCheung
Collaborator

JacoCheung commented May 15, 2026

Pipeline #51381430 -- failed

| Job | Status |
| --- | --- |
| build_devel | ✅ success |
| build_inference_devel | ✅ success |
| build_tritonserver_devel | ✅ success |
| pre_check | ❌ failed |
| train_build | ✅ success |
| inference_build | ✅ success |
| tritonserver_build | ✅ success |
| build_whl | ❌ failed |
| dynamicemb_test_fwd_bwd_8gpus | ✅ success |
| dynamicemb_test_load_dump_8gpus | ✅ success |
| unit_test_1gpu_a100 | ✅ success |
| unit_test_1gpu_h100 | ✅ success |
| unit_test_4gpu | ✅ success |
| unit_test_tp_4gpu | ❌ failed |
| L20_unit_test_1gpu | ✅ success |
| inference_unit_test_1gpu | ✅ success |
| inference_test_1gpu | ✅ success |

Result: 14/17 jobs passed

@JacoCheung JacoCheung self-requested a review May 18, 2026 07:34
Comment thread docker/Dockerfile
cloudpickle triton==3.6.0 nvidia-cutlass-dsl==4.3.0 --no-cache pre-commit
pip install --no-cache-dir torchx gin-config torchmetrics==1.0.3 typing-extensions iopath pyvers expiring_dict \
cloudpickle triton==3.6.0 \
'nvidia-cutlass-dsl==4.4.1' 'nvidia-cutlass-dsl-libs-base==4.4.1' 'nvidia-cutlass-dsl-libs-cu13==4.4.1' \
Collaborator

nvidia-cutlass-* duplicated?

@@ -0,0 +1,187 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Collaborator

Should this file be moved to utest?
