Add an example for disabling contextual mask in training #395

Open
geoffreyQiu wants to merge 1 commit into NVIDIA:main from geoffreyQiu:fix_consistency

@geoffreyQiu geoffreyQiu commented May 18, 2026

Description

  • Add disable_contextual_mask to HSTUConfig so that the HSTU attention kernel can be computed without the contextual mask, for consistency with the Meta implementation.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.


greptile-apps Bot commented May 18, 2026

Greptile Summary

This PR introduces a disable_contextual_mask flag that allows users to opt out of the contextual attention mask in HSTU layers during training, propagating the option through the config, gin args, and layer implementations.

  • Adds disable_contextual_mask: bool = False to HSTUConfig, get_hstu_config, and NetworkArgs, and threads it through create_hstu_config so it can be set from gin configs.
  • Updates FusedHSTULayer and debug_hstu_layer.HSTULayer to conditionally pass None for num_contextuals when the flag is enabled; enables the flag in the kuairand_1k ranking gin config as a usage example.
  • NativeHSTULayer (used when tensor parallelism size > 1) is not updated and will silently ignore the flag.
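
The flag handling summarized above can be sketched in isolation. This is a minimal illustration, not the PR's code: `SketchHSTUConfig` and `select_num_contextuals` are hypothetical names, and the contextual lengths are modeled as a plain list, whereas the real layers presumably pass a tensor.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchHSTUConfig:
    """Hypothetical stand-in for HSTUConfig; only the new flag is modeled."""

    disable_contextual_mask: bool = False


def select_num_contextuals(
    config: SketchHSTUConfig, contextual_seqlen: Optional[List[int]]
) -> Optional[List[int]]:
    """Return the value a layer would pass as num_contextuals.

    When the flag is set, None is passed so the attention computation
    runs without the contextual mask; otherwise the per-sample
    contextual lengths are forwarded unchanged.
    """
    if config.disable_contextual_mask:
        return None
    return contextual_seqlen


# Default behavior keeps the contextual lengths; the flag drops them.
kept = select_num_contextuals(SketchHSTUConfig(), [2, 3])
dropped = select_num_contextuals(SketchHSTUConfig(disable_contextual_mask=True), [2, 3])
```

Because the default is False, existing configs that never mention the flag keep their current behavior.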

Confidence Score: 3/5

Safe to merge for single-GPU or non-TP training, but the flag is silently ineffective when tensor parallelism is enabled.

The change works correctly for the FUSED and DEBUG layer paths, but NativeHSTULayer — selected automatically when tensor_model_parallel_size > 1 — was not updated. A user who sets disable_contextual_mask = True with TP enabled will receive no error or warning, yet the contextual mask will still be applied, producing different model behavior than expected.
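
One way to surface the silent no-op described above is a check at config-creation time. The sketch below is an assumption for illustration only: `warn_if_flag_ignored` is a hypothetical helper that is not part of this PR or the repository, and it assumes both values are available where the config is built.

```python
import warnings


def warn_if_flag_ignored(
    disable_contextual_mask: bool, tensor_model_parallel_size: int
) -> bool:
    """Warn when the flag is set on a path that would not honor it.

    Hypothetical helper: returns True if the flag would be ignored,
    i.e. it is enabled while tensor parallelism selects the native layer.
    """
    ignored = disable_contextual_mask and tensor_model_parallel_size > 1
    if ignored:
        warnings.warn(
            "disable_contextual_mask=True has no effect when "
            "tensor_model_parallel_size > 1; the contextual mask is still applied.",
            RuntimeWarning,
            stacklevel=2,
        )
    return ignored
```

A warning keeps existing TP runs working; raising an error instead would be the stricter choice if a mismatch between requested and actual masking is considered unacceptable.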

examples/hstu/modules/native_hstu_layer.py needs the same disable_contextual_mask guard that was applied to the fused and debug layers.

Important Files Changed

| Filename | Overview |
| --- | --- |
| examples/hstu/modules/native_hstu_layer.py | Not changed in this PR, but the num_contextuals call is not guarded by disable_contextual_mask, making the flag a no-op when tensor parallelism is used (NATIVE layer path). |
| examples/hstu/modules/fused_hstu_layer.py | Correctly reads disable_contextual_mask from config and conditionally passes None for num_contextuals in the forward pass. |
| examples/hstu/modules/debug/debug_hstu_layer.py | Correctly reads and applies disable_contextual_mask in the debug/reference HSTU layer. |
| examples/hstu/configs/hstu_config.py | New disable_contextual_mask field added to HSTUConfig and get_hstu_config; field is missing from the class docstring. |
| examples/hstu/utils/gin_config_args.py | Adds disable_contextual_mask: bool = False to NetworkArgs with correct docstring entry; also back-fills previously undocumented scaling_seqlen and embedding_backend entries. |
| examples/hstu/training/trainer/utils.py | Passes disable_contextual_mask from NetworkArgs to get_hstu_config cleanly. |
| examples/hstu/training/configs/kuairand_1k_ranking.gin | Adds NetworkArgs.disable_contextual_mask = True to the kuairand_1k ranking config as an example of the new feature. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    GIN["kuairand_1k_ranking.gin\nNetworkArgs.disable_contextual_mask = True"] --> NA["NetworkArgs\ndisable_contextual_mask: bool"]
    NA --> CU["create_hstu_config (utils.py)"]
    CU --> HC["HSTUConfig\ndisable_contextual_mask: bool"]
    HC --> FL["FusedHSTULayer\n✅ flag respected"]
    HC --> DL["debug HSTULayer\n✅ flag respected"]
    HC --> NL["NativeHSTULayer\n❌ flag ignored\n(TP > 1 path)"]
    FL -->|"disable_contextual_mask=True"| FNONE["num_contextuals = None"]
    DL -->|"disable_contextual_mask=True"| DNONE["num_contextuals = None"]
    NL -->|"always"| CSEQ["num_contextuals = jd.contextual_seqlen"]

Comments Outside Diff (1)

  1. examples/hstu/configs/hstu_config.py, lines 106-108

    P2 The new disable_contextual_mask field is not documented in the HSTUConfig docstring, unlike every other field. Adding a short entry keeps the docstring consistent with the rest of the class.
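
    A docstring entry along these lines would address the comment. This is a sketch under assumptions: `DocumentedHSTUConfig` is a hypothetical, trimmed stand-in, and the wording and `Args:` layout are guesses at the style of the real HSTUConfig docstring.

```python
from dataclasses import dataclass


@dataclass
class DocumentedHSTUConfig:
    """Hypothetical, trimmed stand-in for HSTUConfig.

    Args:
        disable_contextual_mask (bool): If True, compute HSTU attention
            without the contextual mask. Defaults to False.
    """

    # Assumed default, matching the `bool = False` described in the summary.
    disable_contextual_mask: bool = False
```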

Reviews (1): Last reviewed commit: "Add an example for disabling contextual ..."
