Skip to content

Commit 41fb9bc

Browse files
Authored by: cyanguwa, pre-commit-ci[bot], ksivaman
[PyTorch] fix test_current_device test (NVIDIA#2398)
* fix test_current_device

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 05bfa3f commit 41fb9bc

1 file changed

Lines changed: 50 additions & 10 deletions

File tree

tests/pytorch/distributed/test_sanity.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,29 @@ def test_current_device(model, module):
5555
self_attn_mask_type="padding",
5656
device=f"cuda:{tensor_device}",
5757
)
58-
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
58+
seqlens_q = torch.randint(
59+
1,
60+
config.max_seqlen_q,
61+
[config.batch_size],
62+
dtype=torch.int32,
63+
device=f"cuda:{tensor_device}",
64+
)
65+
cu_seqlens_q = torch.zeros(
66+
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
67+
)
68+
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
69+
seqlens_kv = torch.randint(
70+
1,
71+
config.max_seqlen_kv,
72+
[config.batch_size],
73+
dtype=torch.int32,
74+
device=f"cuda:{tensor_device}",
75+
)
76+
cu_seqlens_kv = torch.zeros(
77+
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
78+
)
79+
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
80+
num_tokens = cu_seqlens_q[-1]
5981
args = [
6082
torch.randn(
6183
(num_tokens, config.hidden_size),
@@ -64,9 +86,6 @@ def test_current_device(model, module):
6486
requires_grad=True,
6587
)
6688
]
67-
cu_seqlens_q, cu_seqlens_kv = [
68-
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
69-
]
7089
kwargs["cu_seqlens_q"] = cu_seqlens_q
7190
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
7291
kwargs["max_seqlen_q"] = config.max_seqlen_q
@@ -75,26 +94,47 @@ def test_current_device(model, module):
7594
model = DotProductAttention(
7695
config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
7796
)
78-
num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
97+
seqlens_q = torch.randint(
98+
1,
99+
config.max_seqlen_q,
100+
[config.batch_size],
101+
dtype=torch.int32,
102+
device=f"cuda:{tensor_device}",
103+
)
104+
cu_seqlens_q = torch.zeros(
105+
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
106+
)
107+
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
108+
seqlens_kv = torch.randint(
109+
1,
110+
config.max_seqlen_kv,
111+
[config.batch_size],
112+
dtype=torch.int32,
113+
device=f"cuda:{tensor_device}",
114+
)
115+
cu_seqlens_kv = torch.zeros(
116+
config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
117+
)
118+
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
119+
num_tokens = cu_seqlens_q[-1]
79120
args = [
80121
torch.randn(
81122
num_tokens,
82123
config.num_heads,
83124
config.head_dim_qk,
84125
dtype=dtype,
85-
device=tensor_device,
126+
device=f"cuda:{tensor_device}",
86127
requires_grad=True,
87128
)
88129
for _ in range(3)
89130
]
90-
cu_seqlens_q, cu_seqlens_kv = [
91-
torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2)
92-
]
93131
kwargs["cu_seqlens_q"] = cu_seqlens_q
94132
kwargs["cu_seqlens_kv"] = cu_seqlens_kv
95133
kwargs["max_seqlen_q"] = config.max_seqlen_q
96134
kwargs["max_seqlen_kv"] = config.max_seqlen_kv
97-
bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
135+
bwd_args = [
136+
torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=f"cuda:{tensor_device}")
137+
]
98138
elif module == "Linear":
99139
model = Linear(
100140
config.hidden_size,

0 commit comments

Comments (0)