Skip to content

Commit 9fb9c25

Browse files
dunnoconnormodularbot
authored andcommitted
[EDU] Improve cross-puzzle consistency
MODULAR_ORIG_COMMIT_REV_ID: 233276c087fd504f86549164fe577e6b2837c7c2
1 parent b074462 commit 9fb9c25

7 files changed

Lines changed: 145 additions & 16 deletions

File tree

book/src/puzzle_18/puzzle_18.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ Our GPU implementation uses parallel reduction for both finding the maximum
3737
value and computing the sum of exponentials, making it highly efficient for
3838
large vectors.
3939

40+
> **Scope:** This puzzle runs in a single block (grid is \\(1 \times 1\\)). Both
41+
> reductions use shared memory and `barrier()` within that one block — there is
42+
> no cross-block communication here, which is why the vector fits in a single
43+
> block's threads.
44+
4045
## Key concepts
4146

4247
- Parallel reduction for efficient maximum and sum calculations

book/src/puzzle_19/puzzle_19.md

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ The computation involves three main steps:
3131
3. **Weighted Sum**: Combine value vectors using attention weights to produce
3232
the final output
3333

34+
> **Scope:** This puzzle composes existing kernels (transpose, tiled matmul,
35+
> softmax) into one attention op for a single query vector. Cross-block
36+
> coordination lives inside each reused kernel — your focus is the transpose
37+
> kernel and the host-side orchestration that connects the pieces.
38+
3439
## Understanding attention: a step-by-step breakdown
3540

3641
Think of attention as a **smart lookup mechanism**. Given a query (what you're
@@ -71,15 +76,22 @@ Step 3: Weights(1,16) @ V(16,16) → Output(1,16) → reshape → Output(16,)
7176

7277
**Key insight**: We reshape the query vector \\(Q\\) from shape \\((16,)\\) to
7378
\\((1,16)\\) so we can use matrix multiplication instead of manual dot products.
74-
This allows us to leverage the highly optimized tiled matmul kernel from Puzzle
75-
18!
79+
This allows us to leverage the highly optimized
80+
[tiled matmul kernel from Puzzle 16](../puzzle_16/tiled.md)!
81+
82+
In Mojo, you reshape a `LayoutTensor` by calling `reshape[new_layout]()` with the
83+
target layout as a compile-time parameter (for example,
84+
`q_tensor.reshape[layout_q_2d]()`) rather than copying or mutating data in place.
85+
You'll see this idiom in the orchestration code below.
7686

7787
Our GPU implementation
78-
**reuses and combines optimized kernels from previous puzzles**:
88+
**reuses and combines optimized kernels, mostly from previous puzzles**:
7989

8090
- **[Tiled matrix multiplication from Puzzle 16](../puzzle_16/puzzle_16.md)**
8191
for efficient \\(Q \cdot K^T\\) and \\(\text{weights} \cdot V\\) operations
82-
- **Shared memory transpose** for computing \\(K^T\\) efficiently
92+
- **[Shared memory transpose](#1-implement-the-transpose-kernel)** for computing
93+
\\(K^T\\) efficiently — this is the one kernel you implement in this puzzle
94+
(see below)
8395
- **[Parallel softmax from Puzzle 18](../puzzle_18/puzzle_18.md)** for
8496
numerically stable attention weight computation
8597

@@ -88,6 +100,13 @@ Our GPU implementation
88100
> Rather than writing everything from scratch, we leverage the
89101
> `matmul_idiomatic_tiled` from Puzzle 16 and `softmax_kernel` from Puzzle 18,
90102
> showcasing the power of modular GPU kernel design.
103+
>
104+
> **Reuse checkpoint**: Before continuing, revisit the kernels you're about to
105+
> compose — `matmul_idiomatic_tiled` in
106+
> [Puzzle 16's tiled solution](../puzzle_16/tiled.md) and `softmax_kernel` in
107+
> [Puzzle 18](../puzzle_18/puzzle_18.md). Treat this puzzle as a
108+
> composition/refactor exercise: your job is to wire these existing building
109+
> blocks together (plus the transpose you write here), not to reinvent them.
91110
92111
## Key concepts
93112

@@ -199,6 +218,22 @@ kernel in the Mojo file using shared memory.
199218

200219
### 2. Orchestrate the attention
201220

221+
So far you've written a single kernel. Attention, however, is a *pipeline* of
222+
kernels: the transpose you just implemented, the tiled matmul from Puzzle 16, the
223+
softmax from Puzzle 18, and a second matmul. **Orchestration** is the host-side
224+
code that runs these kernels in sequence and wires the output of each step into
225+
the input of the next:
226+
227+
```text
228+
K → transpose → Kᵀ → matmul(Q, Kᵀ) → scores → softmax → weights → matmul(weights, V) → output
229+
```
230+
231+
The orchestration function below allocates the intermediate buffers (`Kᵀ`,
232+
`scores`, `weights`), reshapes \\(Q\\) to \\((1, 16)\\) with `reshape[...]()` as
233+
shown above, and enqueues each kernel launch on the GPU. There's no new kernel
234+
math here — the work is choosing buffer layouts and calling the existing kernels
235+
in the right order.
236+
202237
```mojo
203238
{{#include ../../../problems/p19/op/attention.mojo:attention_orchestration}}
204239
```

book/src/puzzle_23/elementwise.md

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ modern GPU programming abstracts low-level details while preserving high
66
performance.
77

88
**Key insight:** _The
9-
[elementwise](https://docs.modular.com/mojo/std/algorithm/functional/elementwise/)
9+
[elementwise](https://mojolang.org/docs/std/algorithm/functional/elementwise/)
1010
function automatically handles thread management, SIMD vectorization, and memory
1111
coalescing for you._
1212

@@ -26,13 +26,24 @@ The mathematical operation is simple element-wise addition:
2626
The implementation covers fundamental patterns applicable to all GPU functional
2727
programming in Mojo.
2828

29+
**Where to start:** You begin from the `elementwise` template in the problem file
30+
— there is no manual shared memory or thread-index math here. The key shift from
31+
earlier puzzles is that each invocation of your nested function processes a whole
32+
SIMD vector, not a single element. That's why you load and store with
33+
`aligned_load[simd_width]` / `store[simd_width]` (vectorized) instead of indexing
34+
one scalar at a time.
35+
2936
## Configuration
3037

3138
- Vector size: `SIZE = 1024`
3239
- Data type: `DType.float32`
3340
- SIMD width: Target-dependent (determined by GPU architecture and data type)
3441
- Layout: `row_major[SIZE]()` (1D row-major)
3542

43+
> **Scope:** This is a single-kernel, per-element operation. The `elementwise`
44+
> abstraction handles thread, block, and grid configuration for you — there is no
45+
> cross-thread or cross-block communication to reason about here.
46+
3647
## Code to complete
3748

3849
```mojo
@@ -53,7 +64,9 @@ The `elementwise` function expects a nested function with this exact signature:
5364
```mojo
5465
@parameter
5566
@always_inline
56-
def your_function[simd_width: Int, rank: Int](indices: IndexList[rank]) capturing -> None:
67+
def your_function[
68+
simd_width: Int, alignment: Int = align_of[dtype]()
69+
](indices: Coord) capturing -> None:
5770
# Your implementation here
5871
```
5972

@@ -65,13 +78,13 @@ def your_function[simd_width: Int, rank: Int](indices: IndexList[rank]) capturin
6578
kernels
6679
- `capturing`: Allows access to variables from the outer scope (the input/output
6780
tensors)
68-
- `IndexList[rank]`: Provides multi-dimensional indexing (rank=1 for vectors,
69-
rank=2 for matrices)
81+
- `Coord`: Carries the per-dimension indices for the current SIMD chunk; use
82+
`indices[0]` for 1D operations
7083

7184
### 2. **Index extraction and SIMD processing**
7285

7386
```mojo
74-
idx = indices[0] # Extract linear index for 1D operations
87+
idx = Int(indices[0].value()) # Extract linear index for 1D operations
7588
```
7689

7790
This `idx` represents the **starting position** for a SIMD vector, not a single
@@ -239,25 +252,27 @@ elementwise[add_function, simd_width, target="gpu"](size, ctx)
239252
```mojo
240253
@parameter
241254
@always_inline
242-
def add[simd_width: Int, rank: Int](indices: IndexList[rank]) capturing -> None:
255+
def add[
256+
simd_width: Int, alignment: Int = align_of[dtype]()
257+
](indices: Coord) capturing -> None:
243258
```
244259

245260
**Parameter Analysis:**
246261

247262
- **`@parameter`**: This decorator provides **compile-time specialization**. The
248-
function is generated separately for each unique `simd_width` and `rank`,
249-
allowing aggressive optimization.
263+
function is generated separately for each unique `simd_width`, allowing
264+
aggressive optimization.
250265
- **`@always_inline`**: Critical for GPU performance - eliminates function call
251266
overhead by embedding the code directly into the kernel.
252267
- **`capturing`**: Enables **lexical scoping** - the inner function can access
253268
variables from the outer scope without explicit parameter passing.
254-
- **`IndexList[rank]`**: Provides **dimension-agnostic indexing** - the same
255-
pattern works for 1D vectors, 2D matrices, 3D tensors, etc.
269+
- **`Coord`**: Carries the per-dimension indices for the SIMD chunk being
270+
processed; `indices[0]` is the linear start position for 1D operations.
256271

257272
### 3. **SIMD execution model deep dive**
258273

259274
```mojo
260-
idx = indices[0] # Linear index: 0, 4, 8, 12... (GPU-dependent spacing)
275+
idx = Int(indices[0].value()) # Linear index: 0, 4, 8, 12... (GPU-dependent spacing)
261276
a_simd = a.aligned_load[simd_width](Index(idx)) # Load: [a[0:4], a[4:8], a[8:12]...] (4 elements per load)
262277
b_simd = b.aligned_load[simd_width](Index(idx)) # Load: [b[0:4], b[4:8], b[8:12]...] (4 elements per load)
263278
ret = a_simd + b_simd # SIMD: 4 additions in parallel (GPU-dependent)

book/src/puzzle_23/vectorize.md

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
This puzzle explores **advanced vectorization techniques** using manual
66
vectorization and
7-
[vectorize](https://docs.modular.com/mojo/std/algorithm/functional/vectorize/)
7+
[vectorize](https://mojolang.org/docs/std/algorithm/backend/vectorize/vectorize/)
88
that give you precise control over SIMD operations within GPU kernels. You'll
99
implement two different approaches to vectorized computation:
1010

@@ -42,6 +42,10 @@ But with sophisticated vectorization strategies for maximum performance.
4242
- SIMD width: GPU-dependent
4343
- Layout: `row_major[SIZE]()` (1D row-major)
4444

45+
> **Scope:** Both approaches operate within a single tile at a time; bounds
46+
> checking is per-tile and there is no cross-tile or cross-block communication.
47+
> The focus is SIMD control inside a tile, not coordination across them.
48+
4549
## 1. Manual vectorization approach
4650

4751
### Code to complete
@@ -235,6 +239,32 @@ for i in range(tile_size): # i = 0, 1, 2, ..., 31
235239

236240
<div class="solution-tips">
237241

242+
### 0. **From scalar to vectorized**
243+
244+
Start by writing the addition as a plain scalar loop over a tile, then convert it
245+
to `vectorize`. The transformation is mechanical: replace the per-element loop
246+
body with a SIMD load/add/store, and hand the loop to `vectorize`, which calls
247+
your body in `width`-sized steps and processes the leftover remainder for you.
248+
249+
```mojo
250+
# Before: scalar loop over the tile (one element at a time)
251+
for i in range(actual_tile_size):
252+
global_idx = tile_start + i
253+
out_lt[global_idx] = a_lt[global_idx] + b_lt[global_idx]
254+
255+
# After: same logic, but the body operates on a SIMD vector of `width`
256+
def vectorized_add[width: Int](i: Int) {read tile_start, read a_lt, read b_lt, mut out_lt}:
257+
global_idx = tile_start + i
258+
if global_idx + width <= size: # bounds check
259+
a_vec = a_lt.aligned_load[width](Index(global_idx))
260+
b_vec = b_lt.aligned_load[width](Index(global_idx))
261+
out_lt.store[width](Index(global_idx), a_vec + b_vec)
262+
263+
vectorize[simd_width](actual_tile_size, vectorized_add) # drives the loop + remainder
264+
```
265+
266+
The remaining tips break this down piece by piece.
267+
238268
### 1. **Tile boundary calculation**
239269

240270
```mojo

book/src/puzzle_24/warp_sum.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ programming in Mojo.
3535
- Grid configuration: `(1, 1)` blocks per grid
3636
- Layout: `row_major[SIZE]()` (1D row-major)
3737

38+
> **Scope:** This puzzle works within a single warp (`SIZE = WARP_SIZE`). The
39+
> reduction happens across lanes of one warp via `warp.sum()`; there is no
40+
> cross-warp or cross-block reduction here.
41+
3842
## The traditional complexity (from Puzzle 12)
3943

4044
Recall the complex approach from
@@ -55,6 +59,12 @@ memory, barriers, and tree reduction:
5559
This works, but it's verbose, error-prone, and requires deep understanding of
5660
GPU synchronization.
5761

62+
> **Note:** This is intentionally a *different* approach from the
63+
> [Puzzle 12 solution](../../../solutions/p12/p12.mojo). Puzzle 12 uses shared
64+
> memory, `barrier()`, and a tree reduction; this puzzle deliberately replaces
65+
> all of that with a single `warp.sum()`. The code below won't match the P12
66+
> solution line-for-line — that contrast is the point.
67+
5868
**Test the traditional approach:**
5969
<div class="code-tabs" data-tab-group="package-manager">
6070
<div class="tab-buttons">

book/src/puzzle_25/warp_shuffle_down.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ This transforms complex neighbor access patterns into simple warp-level
3737
operations, enabling efficient stencil computations without explicit memory
3838
indexing.
3939

40+
> **Scope:** `shuffle_down()` only moves data *within a warp*. In the multi-block
41+
> section below, each block's warp handles its own boundary lanes independently —
42+
> there is no cross-warp or cross-block data exchange. Lanes at the top of a warp
43+
> simply have no neighbor to read from, which is why boundary handling matters.
44+
4045
## 1. Basic neighbor difference
4146

4247
### Configuration
@@ -393,6 +398,7 @@ boundary lanes of each block.
393398
<div class="tab-buttons">
394399
<button class="tab-button">pixi NVIDIA (default)</button>
395400
<button class="tab-button">pixi AMD</button>
401+
<button class="tab-button">pixi Apple</button>
396402
<button class="tab-button">uv</button>
397403
</div>
398404
<div class="tab-content">
@@ -411,6 +417,13 @@ pixi run -e amd p25 --average
411417
</div>
412418
<div class="tab-content">
413419

420+
```bash
421+
pixi run -e apple p25 --average
422+
```
423+
424+
</div>
425+
<div class="tab-content">
426+
414427
```bash
415428
uv run poe p25 --average
416429
```

book/src/puzzle_26/warp_shuffle_xor.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ This transforms complex parallel algorithms into elegant butterfly communication
3838
patterns, enabling efficient tree reductions and sorting networks without
3939
explicit coordination.
4040

41+
> **Scope:** `shuffle_xor()` exchanges data *within a single warp*. Every
42+
> reduction and butterfly here is per-warp; the results are global only because
43+
> each section runs a single warp over the data. There is no cross-warp or
44+
> cross-block communication.
45+
4146
## 1. Basic butterfly pair swap
4247

4348
### Configuration
@@ -388,6 +393,7 @@ Result: All lanes have global maximum = 7
388393
<div class="tab-buttons">
389394
<button class="tab-button">pixi NVIDIA (default)</button>
390395
<button class="tab-button">pixi AMD</button>
396+
<button class="tab-button">pixi Apple</button>
391397
<button class="tab-button">uv</button>
392398
</div>
393399
<div class="tab-content">
@@ -406,6 +412,13 @@ pixi run -e amd p26 --parallel-max
406412
</div>
407413
<div class="tab-content">
408414

415+
```bash
416+
pixi run -e apple p26 --parallel-max
417+
```
418+
419+
</div>
420+
<div class="tab-content">
421+
409422
```bash
410423
uv run poe p26 --parallel-max
411424
```
@@ -595,6 +608,7 @@ This puzzle uses multiple blocks. Consider how this affects the reduction scope.
595608
<div class="tab-buttons">
596609
<button class="tab-button">pixi NVIDIA (default)</button>
597610
<button class="tab-button">pixi AMD</button>
611+
<button class="tab-button">pixi Apple</button>
598612
<button class="tab-button">uv</button>
599613
</div>
600614
<div class="tab-content">
@@ -613,6 +627,13 @@ pixi run -e amd p26 --conditional-max
613627
</div>
614628
<div class="tab-content">
615629

630+
```bash
631+
pixi run -e apple p26 --conditional-max
632+
```
633+
634+
</div>
635+
<div class="tab-content">
636+
616637
```bash
617638
uv run poe p26 --conditional-max
618639
```

0 commit comments

Comments
 (0)