Skip to content

Commit 10015ea

Browse files
author
Xun Wang
committed
change dtype
1 parent 5afcc47 commit 10015ea

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

problems/amd_distributed/gemm-rs/reference.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,13 +22,13 @@ def generate_input(RANK: int, world_size: int, m: int, n: int, k: int, has_bias:
2222
local_k = k // world_size
2323

2424
# Generate random inputs and weights
25-
input = (torch.rand((m, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01
26-
weight = (torch.rand((n, local_k), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01
25+
input = (torch.rand((m, local_k), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01
26+
weight = (torch.rand((n, local_k), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01
2727

2828
bias = None
2929
if has_bias:
3030
gen.manual_seed(seed)
31-
bias = (torch.rand((n,), dtype=torch.float16, device="cuda", generator=gen) * 2 - 1) * 0.01
31+
bias = (torch.rand((n,), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01
3232

3333
return (input, weight, bias)
3434

0 commit comments

Comments
 (0)