Add an example for disabling contextual mask in training #395
geoffreyQiu wants to merge 1 commit into
Greptile Summary
This PR introduces a
Confidence Score: 3/5

Safe to merge for single-GPU or non-TP training, but the flag is silently ineffective when tensor parallelism is enabled.

The change works correctly for the FUSED and DEBUG layer paths, but NativeHSTULayer, which is selected automatically when tensor_model_parallel_size > 1, was not updated. A user who sets disable_contextual_mask = True with TP enabled will receive no error or warning, yet the contextual mask will still be applied, producing different model behavior than expected.

examples/hstu/modules/native_hstu_layer.py needs the same disable_contextual_mask guard that was applied to the fused and debug layers.

Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
GIN["kuairand_1k_ranking.gin\nNetworkArgs.disable_contextual_mask = True"] --> NA["NetworkArgs\ndisable_contextual_mask: bool"]
NA --> CU["create_hstu_config (utils.py)"]
CU --> HC["HSTUConfig\ndisable_contextual_mask: bool"]
HC --> FL["FusedHSTULayer\n✅ flag respected"]
HC --> DL["debug HSTULayer\n✅ flag respected"]
HC --> NL["NativeHSTULayer\n❌ flag ignored\n(TP > 1 path)"]
FL -->|"disable_contextual_mask=True"| FNONE["num_contextuals = None"]
DL -->|"disable_contextual_mask=True"| DNONE["num_contextuals = None"]
NL -->|"always"| CSEQ["num_contextuals = jd.contextual_seqlen"]
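
The guard that the flowchart shows for the fused and debug layers, and that the review asks to mirror in NativeHSTULayer, can be sketched as below. This is a minimal illustration, not the PR's code: the two dataclasses are stand-ins for the real HSTUConfig and for the batch object the flowchart calls jd, the helper name resolve_num_contextuals is hypothetical, and contextual_seqlen is modeled as a plain int even though the real field may be a tensor.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HSTUConfig:
    # Stand-in for the real config; only the field discussed in this PR.
    disable_contextual_mask: bool = False


@dataclass
class JaggedData:
    # Stand-in for the batch container (`jd` in the flowchart).
    # Modeled as an int here; the real field may be a tensor (assumption).
    contextual_seqlen: Optional[int] = None


def resolve_num_contextuals(config: HSTUConfig, jd: JaggedData) -> Optional[int]:
    """Hypothetical helper: pick the num_contextuals value for the attention call.

    Returns None when the flag is set, so the kernel runs without the
    contextual mask; otherwise forwards jd.contextual_seqlen unchanged.
    """
    if config.disable_contextual_mask:
        return None
    return jd.contextual_seqlen
```

Routing all three layer variants through one helper of this kind would keep the TP > 1 path from diverging from the fused and debug paths.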
Description
Add disable_contextual_mask in HSTUConfig to compute the HSTU attention kernel without the contextual mask, in line with the Meta implementation.

Checklist