@@ -862,5 +862,130 @@ def forward(self, src, values, idx0, idx1):
862862 ), f"Accumulate broadcast mismatch: max diff = { (result - torch_output ).abs ().max ()} "
863863
864864
+    # ------------------------------------------------------------------
+    # Duplicate-index tests for realistic use-case models
+    # These mirror the scenarios in experiments/bench_index_put_scatter_add.py
+    # and verify that _index_put_scatter_add correctly accumulates into
+    # duplicate positions when index_put is embedded in a larger graph.
+    # ------------------------------------------------------------------
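+    #
+    # A minimal sketch of the equivalence being exercised (illustrative
+    # names, not the pass's actual code): for 1-D indices into dim 0,
+    # index_put with accumulate=True matches scatter_add, which is the
+    # rewrite the pass name suggests:
+    #
+    #     out = base.index_put((idx,), vals, accumulate=True)
+    #     out == base.scatter_add(0, idx.unsqueeze(-1).expand_as(vals), vals)
+    #
+    # With accumulate=False, only the last write per duplicate index survives.
+    # ------------------------------------------------------------------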
+
+    def test_kv_cache_duplicate_slot_writes(self):
+        """KV-cache style: linear projection → index_put(accumulate=True) into
+        a flat cache with duplicate slot indices → output projection.
+
+        Multiple writes to the same cache slot must sum, not overwrite.
+        """
+        N, S, D = 6, 16, 32
+
+        # Positions with duplicates: slots 2 and 5 are each written twice
+        positions = torch.tensor([0, 2, 2, 5, 5, 7], dtype=torch.int64)
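+        # With accumulate=True, slot 2 must end up holding feats[1] + feats[2]
+        # and slot 5 feats[3] + feats[4] (feats is computed in forward below).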
+
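+        # assume_constant_result makes dynamo bake the returned indices into
+        # the traced graph as a constant rather than a runtime input, so the
+        # converter sees static indices.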
+        @torch._dynamo.assume_constant_result
+        def get_positions():
+            return positions
+
+        class KVCacheDupWrites(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.proj_in = torch.nn.Linear(D, D, bias=False)
+                self.proj_out = torch.nn.Linear(D, D, bias=False)
+
+            def forward(self, tokens, cache):
+                # tokens: (N, D), cache: (S, D)
+                feats = self.proj_in(tokens)
+                cache = cache.index_put((get_positions(),), feats, accumulate=True)
+                return self.proj_out(cache)
+
+        tokens = torch.randn(N, D)
+        cache = torch.zeros(S, D)
+
+        self.run_test(
+            KVCacheDupWrites().cuda(),
+            inputs=[tokens, cache],
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_sparse_embedding_duplicate_seq_ids(self):
+        """Sparse embedding accumulation: embedding lookup → index_put(accumulate=True)
+        into per-sequence accumulators where many tokens map to the same sequence → ReLU.
+
+        Multiple tokens per sequence must be summed, not overwritten.
+        """
+        B, N = 4, 20
+
+        # Many tokens map to the same sequence: each of the B sequences
+        # receives exactly five tokens
+        seq_ids = torch.tensor(
+            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
+            dtype=torch.int64,
+        )
+
+        @torch._dynamo.assume_constant_result
+        def get_seq_ids():
+            return seq_ids
+
+        class SparseEmbedAccum(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.embedding = torch.nn.Embedding(64, 16)
+                self.head = torch.nn.Linear(16, 8, bias=False)
+
+            def forward(self, token_ids, accum):
+                # token_ids: (N,) int64, accum: (B, 16)
+                embs = self.embedding(token_ids)
+                accum = accum.index_put((get_seq_ids(),), embs, accumulate=True)
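+                # Each row of accum now holds the sum of the five token
+                # embeddings assigned to that sequence.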
+                return self.head(torch.relu(accum))
+
+        token_ids = torch.randint(0, 64, (N,))
+        accum = torch.zeros(B, 16)
+
+        self.run_test(
+            SparseEmbedAccum().cuda(),
+            inputs=[token_ids, accum],
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_histogram_conv_duplicate_bin_ids(self):
+        """Histogram accumulation: Conv1d → index_put(accumulate=True) into histogram
+        bins where many frames land in the same bin → mean pool → linear.
+
+        Multiple frames writing to the same bin must accumulate, not overwrite.
+        """
+        C, L, n_bins = 4, 12, 6
+
+        # Skewed bin assignment: bins 0 and 1 each receive three frames,
+        # bins 2 and 3 two frames each, bins 4 and 5 one frame each
+        bin_ids = torch.tensor(
+            [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 5],
+            dtype=torch.int64,
+        )
+
+        @torch._dynamo.assume_constant_result
+        def get_bin_ids():
+            return bin_ids
+
+        class HistConvAccum(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv1d(C, 8, kernel_size=3, padding=1)
+                self.head = torch.nn.Linear(8, 4, bias=False)
+
+            def forward(self, signal, hist):
+                # signal: (1, C, L), hist: (n_bins, 8)
+                feat = self.conv(signal).squeeze(0).T  # (L, 8)
+                hist = hist.index_put((get_bin_ids(),), feat, accumulate=True)
+                return self.head(hist.mean(dim=0))
+
+        signal = torch.randn(1, C, L)
+        hist = torch.zeros(n_bins, 8)
+
+        self.run_test(
+            HistConvAccum().cuda(),
+            inputs=[signal, hist],
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+
 if __name__ == "__main__":
     run_tests()