
Commit e8e622e
feat: Add tests on duplication
Parent: 314ec6c

1 file changed: tests/py/dynamo/conversion/test_index_put_aten.py (125 additions & 0 deletions)
@@ -862,5 +862,130 @@ def forward(self, src, values, idx0, idx1):
         ), f"Accumulate broadcast mismatch: max diff = {(result - torch_output).abs().max()}"


+    # ------------------------------------------------------------------
+    # Duplicate-index tests for realistic use-case models
+    # These mirror the scenarios in experiments/bench_index_put_scatter_add.py
+    # and verify that _index_put_scatter_add correctly accumulates into
+    # duplicate positions when index_put is embedded in a larger graph.
+    # ------------------------------------------------------------------
+
+    def test_kv_cache_duplicate_slot_writes(self):
+        """KV-cache style: linear projection → index_put(accumulate=True) into
+        a flat cache with duplicate slot indices → output projection.
+
+        Multiple writes to the same cache slot must sum, not overwrite.
+        """
+        N, S, D = 6, 16, 32
+
+        # Positions with duplicates: slots 2 and 5 are written twice
+        positions = torch.tensor([0, 2, 2, 5, 5, 7], dtype=torch.int64)
+
+        @torch._dynamo.assume_constant_result
+        def get_positions():
+            return positions
+
+        class KVCacheDupWrites(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.proj_in = torch.nn.Linear(D, D, bias=False)
+                self.proj_out = torch.nn.Linear(D, D, bias=False)
+
+            def forward(self, tokens, cache):
+                # tokens: (N, D), cache: (S, D)
+                feats = self.proj_in(tokens)
+                cache = cache.index_put((get_positions(),), feats, accumulate=True)
+                return self.proj_out(cache)
+
+        tokens = torch.randn(N, D)
+        cache = torch.zeros(S, D)
+
+        self.run_test(
+            KVCacheDupWrites().cuda(),
+            inputs=[tokens, cache],
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_sparse_embedding_duplicate_seq_ids(self):
+        """Sparse embedding accumulation: embedding lookup → index_put(accumulate=True)
+        into per-sequence accumulators where many tokens map to the same sequence → ReLU.
+
+        Multiple tokens per sequence must be summed, not overwritten.
+        """
+        B, N = 4, 20
+
+        # Many tokens map to the same sequence: each seq gets ~5 tokens
+        seq_ids = torch.tensor(
+            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3],
+            dtype=torch.int64,
+        )
+
+        @torch._dynamo.assume_constant_result
+        def get_seq_ids():
+            return seq_ids
+
+        class SparseEmbedAccum(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.embedding = torch.nn.Embedding(64, 16)
+                self.head = torch.nn.Linear(16, 8, bias=False)
+
+            def forward(self, token_ids, accum):
+                # token_ids: (N,) int64, accum: (B, 16)
+                embs = self.embedding(token_ids)
+                accum = accum.index_put((get_seq_ids(),), embs, accumulate=True)
+                return self.head(torch.relu(accum))
+
+        token_ids = torch.randint(0, 64, (N,))
+        accum = torch.zeros(B, 16)
+
+        self.run_test(
+            SparseEmbedAccum().cuda(),
+            inputs=[token_ids, accum],
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+    def test_histogram_conv_duplicate_bin_ids(self):
+        """Histogram accumulation: Conv1d → index_put(accumulate=True) into histogram
+        bins where many frames land in the same bin → mean pool → linear.
+
+        Multiple frames writing to the same bin must accumulate, not overwrite.
+        """
+        C, L, n_bins = 4, 12, 6
+
+        # Skewed bin assignment: bins 0 and 1 receive many frames
+        bin_ids = torch.tensor(
+            [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 5],
+            dtype=torch.int64,
+        )
+
+        @torch._dynamo.assume_constant_result
+        def get_bin_ids():
+            return bin_ids
+
+        class HistConvAccum(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv1d(C, 8, kernel_size=3, padding=1)
+                self.head = torch.nn.Linear(8, 4, bias=False)
+
+            def forward(self, signal, hist):
+                # signal: (1, C, L), hist: (n_bins, 8)
+                feat = self.conv(signal).squeeze(0).T  # (L, 8)
+                hist = hist.index_put((get_bin_ids(),), feat, accumulate=True)
+                return self.head(hist.mean(dim=0))
+
+        signal = torch.randn(1, C, L)
+        hist = torch.zeros(n_bins, 8)
+
+        self.run_test(
+            HistConvAccum().cuda(),
+            inputs=[signal, hist],
+            use_dynamo_tracer=True,
+            enable_passes=True,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
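
For reference, the eager-mode behavior these tests pin down is easy to reproduce outside the harness. Below is a minimal standalone sketch (plain PyTorch, no TensorRT; the tensor names are illustrative): with accumulate=True, index_put sums every value written to a duplicated index, while the default accumulate=False keeps only the last write.

    import torch

    cache = torch.zeros(8, 4)
    # Slots 2 and 5 are each written twice, mirroring the KV-cache test above
    positions = torch.tensor([0, 2, 2, 5, 5, 7], dtype=torch.int64)
    values = torch.ones(6, 4)

    # accumulate=True: duplicate writes sum, so slots 2 and 5 hold 2.0
    summed = cache.index_put((positions,), values, accumulate=True)
    assert summed[2, 0].item() == 2.0 and summed[5, 0].item() == 2.0

    # accumulate=False (the default): the last write wins, so every touched
    # slot holds 1.0
    overwritten = cache.index_put((positions,), values)
    assert overwritten[2, 0].item() == 1.0 and overwritten[5, 0].item() == 1.0

    # For 1-D indices on dim 0, index_put(accumulate=True) matches scatter_add,
    # which is presumably the pattern the _index_put_scatter_add lowering named
    # in the comment block above relies on.
    via_scatter = cache.scatter_add(0, positions.unsqueeze(1).expand_as(values), values)
    assert torch.equal(summed, via_scatter)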
