Skip to content

Commit d09bc3f

Browse files
authored
Update all implicit type casts to be explicit (#237)
* Update all implicit type casts to be explicit reusing dtype var where possible * Remove redundant Int() casts around block_dim, block_idx, and thread_idx * Wrap both values in a cast for a single dtype across an operation.
1 parent d2cab32 commit d09bc3f

83 files changed

Lines changed: 546 additions & 536 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

book/src/puzzle_26/puzzle_26.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Learn XOR-based butterfly communication patterns for efficient tree algorithms a
115115
max_val = input[global_i]
116116
offset = WARP_SIZE // 2
117117
while offset > 0:
118-
max_val = max(max_val, shuffle_xor(max_val, offset))
118+
max_val = max(max_val, shuffle_xor(max_val, UInt32(offset)))
119119
offset //= 2
120120
# All lanes now have global maximum
121121
```

book/src/puzzle_26/warp_prefix_sum.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ if global_i < size:
391391
# Butterfly reduction to get total across the warp: dynamic for any WARP_SIZE
392392
offset = WARP_SIZE // 2
393393
while offset > 0:
394-
warp_left_total += shuffle_xor(warp_left_total, offset)
394+
warp_left_total += shuffle_xor(warp_left_total, UInt32(offset))
395395
offset //= 2
396396
397397
# Phase 4: Write to output positions

book/src/puzzle_26/warp_shuffle_xor.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ if global_i < size:
406406
# Butterfly reduction tree: dynamic for any WARP_SIZE
407407
offset = WARP_SIZE // 2
408408
while offset > 0:
409-
max_val = max(max_val, shuffle_xor(max_val, offset))
409+
max_val = max(max_val, shuffle_xor(max_val, UInt32(offset)))
410410
offset //= 2
411411
412412
output[global_i] = max_val # All lanes have global maximum
@@ -599,10 +599,10 @@ if global_i < size:
599599
# Butterfly reduction for both max and min (log_2(WARP_SIZE) steps)
600600
offset = WARP_SIZE // 2
601601
while offset > 0:
602-
neighbor_val = shuffle_xor(current_val, offset)
602+
neighbor_val = shuffle_xor(current_val, UInt32(offset))
603603
current_val = max(current_val, neighbor_val) # Max reduction
604604
605-
min_neighbor_val = shuffle_xor(min_val, offset)
605+
min_neighbor_val = shuffle_xor(min_val, UInt32(offset))
606606
min_val = min(min_val, min_neighbor_val) # Min reduction
607607
608608
offset //= 2
@@ -643,10 +643,10 @@ Final result: All lanes have current_val=7 (global max) and min_val=1 (global mi
643643
```mojo
644644
offset = WARP_SIZE // 2
645645
while offset > 0:
646-
neighbor_val = shuffle_xor(current_val, offset)
646+
neighbor_val = shuffle_xor(current_val, UInt32(offset))
647647
current_val = max(current_val, neighbor_val)
648648
649-
min_neighbor_val = shuffle_xor(min_val, offset)
649+
min_neighbor_val = shuffle_xor(min_val, UInt32(offset))
650650
min_val = min(min_val, min_neighbor_val)
651651
652652
offset //= 2
@@ -710,7 +710,7 @@ The `shuffle_xor()` primitive enables powerful butterfly communication patterns
710710
```mojo
711711
offset = WARP_SIZE // 2
712712
while offset > 0:
713-
neighbor_val = shuffle_xor(current_val, offset)
713+
neighbor_val = shuffle_xor(current_val, UInt32(offset))
714714
current_val = operation(current_val, neighbor_val)
715715
offset //= 2
716716
```

book/src/puzzle_27/block_prefix_sum.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ To classify a `Float32` value into bins:
9292

9393
```mojo
9494
my_value = input_data[global_i][0] # Extract SIMD like in dot product
95-
bin_number = Int(floor(my_value * num_bins))
95+
bin_number = Int(floor(my_value * Float32(num_bins)))
9696
```
9797

9898
**Edge case handling**: Values exactly 1.0 would go to bin `NUM_BINS`, but you only have bins 0 to `NUM_BINS-1`. Use an `if` statement to clamp the maximum bin.
@@ -140,7 +140,7 @@ The last thread (not thread 0!) computes the total count:
140140

141141
```mojo
142142
if local_i == tpb - 1: # Last thread in block
143-
total_count = offset[0] + belongs_to_target # Inclusive = exclusive + own contribution
143+
total_count = offset[0] + Int32(belongs_to_target) # Inclusive = exclusive + own contribution
144144
count_output[0] = total_count
145145
```
146146

@@ -322,7 +322,7 @@ Result: [0.00, 0.01, 0.02, ..., 0.12, ???, ???, ...] // Perfectly packed!
322322
```
323323
Last thread computes total (not thread 0!):
324324
if local_i == tpb - 1: // Thread 127 in our case
325-
total = write_offset[0] + belongs_to_target // Inclusive sum formula
325+
total = write_offset[0] + Int32(belongs_to_target) // Inclusive sum formula
326326
count_output[0] = total
327327
```
328328

book/src/puzzle_29/memory_barrier.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ stencil_count = 0
327327
for neighbor in valid_neighbors:
328328
stencil_sum += buffer[neighbor]
329329
stencil_count += 1
330-
result[i] = stencil_sum / stencil_count
330+
result[i] = stencil_sum / Float32(stencil_count)
331331
```
332332

333333
## **Buffer role alternation**

pixi.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ system-requirements = { macos = "15.0" }
4040
[dependencies]
4141
python = "==3.12"
4242
mojo = "<1.0.0" # includes `mojo-compiler`, lsp, debugger, formatter etc.
43-
max = "==26.3.0.dev2026032405"
43+
max = "==26.3.0.dev2026033105"
4444
bash = ">=5.2.21,<6"
4545
manim = ">=0.18.1,<0.19"
4646
mdbook = ">=0.4.48,<0.5"

problems/p01/p01.mojo

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def main() raises:
2929
a.enqueue_fill(0)
3030
with a.map_to_host() as a_host:
3131
for i in range(SIZE):
32-
a_host[i] = i
32+
a_host[i] = Scalar[dtype](i)
3333

3434
ctx.enqueue_function[add_10, add_10](
3535
out,
@@ -43,7 +43,7 @@ def main() raises:
4343
ctx.synchronize()
4444

4545
for i in range(SIZE):
46-
expected[i] = i + 10
46+
expected[i] = Scalar[dtype](i + 10)
4747

4848
with out.map_to_host() as out_host:
4949
print("out:", out_host)

problems/p02/p02.mojo

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def main() raises:
3434
expected.enqueue_fill(0)
3535
with a.map_to_host() as a_host, b.map_to_host() as b_host:
3636
for i in range(SIZE):
37-
a_host[i] = i
38-
b_host[i] = i
37+
a_host[i] = Scalar[dtype](i)
38+
b_host[i] = Scalar[dtype](i)
3939
expected[i] = a_host[i] + b_host[i]
4040

4141
ctx.enqueue_function[add, add](

problems/p03/p03.mojo

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ comptime dtype = DType.float32
1313
def add_10_guard(
1414
output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
1515
a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
16-
size: UInt,
16+
size: Int,
1717
):
1818
var i = thread_idx.x
1919
# FILL ME IN (roughly 2 lines)
@@ -30,12 +30,12 @@ def main() raises:
3030
a.enqueue_fill(0)
3131
with a.map_to_host() as a_host:
3232
for i in range(SIZE):
33-
a_host[i] = i
33+
a_host[i] = Scalar[dtype](i)
3434

3535
ctx.enqueue_function[add_10_guard, add_10_guard](
3636
out,
3737
a,
38-
UInt(SIZE),
38+
SIZE,
3939
grid_dim=BLOCKS_PER_GRID,
4040
block_dim=THREADS_PER_BLOCK,
4141
)
@@ -45,7 +45,7 @@ def main() raises:
4545
ctx.synchronize()
4646

4747
for i in range(SIZE):
48-
expected[i] = i + 10
48+
expected[i] = Scalar[dtype](i + 10)
4949

5050
with out.map_to_host() as out_host:
5151
print("out:", out_host)

problems/p04/p04.mojo

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ comptime dtype = DType.float32
1313
def add_10_2d(
1414
output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
1515
a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
16-
size: UInt,
16+
size: Int,
1717
):
1818
var row = thread_idx.y
1919
var col = thread_idx.x
@@ -35,13 +35,13 @@ def main() raises:
3535
# row-major
3636
for i in range(SIZE):
3737
for j in range(SIZE):
38-
a_host[i * SIZE + j] = i * SIZE + j
38+
a_host[i * SIZE + j] = Scalar[dtype](i * SIZE + j)
3939
expected[i * SIZE + j] = a_host[i * SIZE + j] + 10
4040

4141
ctx.enqueue_function[add_10_2d, add_10_2d](
4242
out,
4343
a,
44-
UInt(SIZE),
44+
SIZE,
4545
grid_dim=BLOCKS_PER_GRID,
4646
block_dim=THREADS_PER_BLOCK,
4747
)

0 commit comments

Comments
 (0)