Hardware: 4-Pi Distributed Setup

Ring topology

        ┌──────────┐          ┌──────────┐
        │  Pi 0    │──GPIO──▶ │  Pi 1    │
        │ layers   │          │ layers   │
        │ [0, L/3) │          │[L/3,2L/3)│
        └──────────┘          └──────────┘
             ▲                      │
             │ GPIO                 │ GPIO
             │                      ▼
        ┌──────────┐          ┌──────────┐
        │  Pi 3    │◀──GPIO── │  Pi 2    │
        │ embed +  │          │ layers   │
        │ head     │          │[2L/3, L) │
        └──────────┘          └──────────┘

Forward:  R3 embed → R0 → R1 → R2 → R3 head → argmax
Backward: R3 head  → R2 → R1 → R0 → R3 embed

R3 holds the embedding table and classifier head. R0/R1/R2 hold transformer layers. Each Pi loads only its shard from SD, so the 110M model (418 MB total) fits across 4 Pis where it wouldn't fit on one.
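
The forward and backward orders listed above follow directly from the ring: start at the coordinator rank, walk the ring once, and end back at the coordinator. A minimal sketch of that traversal, assuming ranks are numbered 0 to 3 as in the diagram (the names `WORLD_SIZE`, `COORD_RANK`, `forward_route`, and `backward_route` are illustrative, not taken from the project's code):

```python
WORLD_SIZE = 4   # ranks R0..R3, as in the diagram
COORD_RANK = 3   # R3 holds the embedding table and the classifier head


def next_rank(rank, world_size=WORLD_SIZE):
    """Downstream neighbour of a rank on the ring."""
    return (rank + 1) % world_size


def forward_route(world_size=WORLD_SIZE, coord=COORD_RANK):
    """Ranks visited by one forward pass: embed, compute ranks, head."""
    route = [coord]
    rank = next_rank(coord, world_size)
    while rank != coord:
        route.append(rank)
        rank = next_rank(rank, world_size)
    route.append(coord)  # activations return to the coordinator for the head
    return route


def backward_route(world_size=WORLD_SIZE, coord=COORD_RANK):
    """Gradients travel the same ring in the opposite order."""
    return forward_route(world_size, coord)[::-1]


print(forward_route())   # [3, 0, 1, 2, 3]
print(backward_route())  # [3, 2, 1, 0, 3]
```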

GPIO wiring

Each link uses 10 pins (8 data + CLK + ACK), half-duplex. Every Pi has a downstream bank (sends to next rank) and an upstream bank (receives from previous rank):

| Direction | Bank | Data | CLK | ACK |
|---|---|---|---|---|
| Downstream (→ next rank) | High | GPIO 16–23 | 24 | 25 |
| Upstream (← prev rank) | Low | GPIO 4–11 | 12 | 13 |

Wire Pi N's high bank to Pi N+1's low bank (Pi 3's high bank goes to Pi 0's low bank, closing the ring):

     Pi N  (sender)                     Pi N+1 (receiver)
     HIGH BANK                          LOW BANK
    ┌─────────────┐                    ┌─────────────┐
    │ GPIO 16 ────┼── D0  ─────────────┤ GPIO 4      │
    │ GPIO 17 ────┼── D1  ─────────────┤ GPIO 5      │
    │ GPIO 18 ────┼── D2  ─────────────┤ GPIO 6      │
    │ GPIO 19 ────┼── D3  ─────────────┤ GPIO 7      │
    │ GPIO 20 ────┼── D4  ─────────────┤ GPIO 8      │
    │ GPIO 21 ────┼── D5  ─────────────┤ GPIO 9      │
    │ GPIO 22 ────┼── D6  ─────────────┤ GPIO 10     │
    │ GPIO 23 ────┼── D7  ─────────────┤ GPIO 11     │
    │ GPIO 24 ────┼── CLK ─────────────┤ GPIO 12     │
    │ GPIO 25 ────┼── ACK ─────────────┤ GPIO 13     │
    │ GND     ────┼── GND ─────────────┤ GND         │
    └─────────────┘                    └─────────────┘
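
The pin pairing in the diagram is regular enough to generate, which can help when checking jumper wires against a printed list. The sketch below builds the wire list for each link from the two banks in the table. The function name `link_wiring` and the dictionary layout are hypothetical, and the wrap from Pi 3 to Pi 0 is assumed from the ring topology:

```python
# Pin banks from the table above (GPIO numbers).
DOWNSTREAM = {"data": list(range(16, 24)), "clk": 24, "ack": 25}  # high bank
UPSTREAM = {"data": list(range(4, 12)), "clk": 12, "ack": 13}     # low bank


def link_wiring(sender_rank, world_size=4):
    """Return (receiver_rank, wires) for the link leaving sender_rank.

    Each wire is (signal, sender_gpio, receiver_gpio). GND is not listed
    but must also be connected between the two boards.
    """
    receiver_rank = (sender_rank + 1) % world_size  # assumed ring wrap
    pairs = zip(DOWNSTREAM["data"], UPSTREAM["data"])
    wires = [(f"D{i}", tx, rx) for i, (tx, rx) in enumerate(pairs)]
    wires.append(("CLK", DOWNSTREAM["clk"], UPSTREAM["clk"]))
    wires.append(("ACK", DOWNSTREAM["ack"], UPSTREAM["ack"]))
    return receiver_rank, wires


for rank in range(4):
    rx_rank, wires = link_wiring(rank)
    for signal, tx_pin, rx_pin in wires:
        print(f"Pi {rank} GPIO {tx_pin:2d} -- {signal:3s} -- Pi {rx_rank} GPIO {rx_pin}")
```

Note that ACK is driven by the receiver, so on that wire the signal travels from the low bank back to the high bank.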

One byte is transferred per handshake cycle: the sender raises CLK once data is on the bus, the receiver raises ACK once it has read the byte, the sender lowers CLK, and the receiver lowers ACK. The link is self-clocked: there is no baud rate to configure, and neither side depends on fixed timing because each waits for the other's signal.
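
The four-phase handshake can be simulated without hardware to see why no timing agreement is needed. The following is a sketch only: it models the wires as fields on a shared object and runs the two sides as Python generators that yield whenever they would busy-wait on a pin. The names `Link`, `sender`, `receiver`, and `transfer` are illustrative and do not come from the project's firmware.

```python
from dataclasses import dataclass


@dataclass
class Link:
    """Shared wires of one link: 8 data lines plus CLK and ACK."""
    data: int = 0
    clk: int = 0
    ack: int = 0


def sender(link, payload):
    """Drive data and CLK; yield whenever we would poll ACK."""
    for byte in payload:
        link.data = byte & 0xFF   # put the byte on D0..D7
        link.clk = 1              # phase 1: data is valid
        while link.ack == 0:      # wait until the receiver has read it
            yield
        link.clk = 0              # phase 3: release the clock
        while link.ack == 1:      # wait until the receiver releases ACK
            yield


def receiver(link, count, out):
    """Read bytes and drive ACK; yield whenever we would poll CLK."""
    for _ in range(count):
        while link.clk == 0:      # wait for valid data
            yield
        out.append(link.data)
        link.ack = 1              # phase 2: byte has been read
        while link.clk == 1:      # wait for the sender to drop CLK
            yield
        link.ack = 0              # phase 4: ready for the next byte


def transfer(payload):
    """Interleave both sides one polling step at a time."""
    link, out = Link(), []
    tasks = [sender(link, payload), receiver(link, len(payload), out)]
    while tasks:
        for task in list(tasks):
            try:
                next(task)
            except StopIteration:
                tasks.remove(task)
    return bytes(out)


print(transfer(b"hi"))  # b'hi'
```

Because each side only advances after observing the other's signal, the result is the same however the two generators are interleaved, which mirrors the self-clocked property of the real link.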

USB serial ports

Each Pi connects to the laptop over UART for bootloading and log output. Edit devices.conf with your port suffixes:

Pi 0 → /dev/cu.usbserial-<suffix>
Pi 1 → /dev/cu.usbserial-<suffix>
Pi 2 → /dev/cu.usbserial-<suffix>
Pi 3 → /dev/cu.usbserial-<suffix>

Weight sharding

Split a full model into 4 per-rank shard files:

python3 tools/shard_weights.py weights/stories42M.bin  4 weights/shards/42M/
python3 tools/shard_weights.py weights/stories110M.bin 4 weights/shards/110M/

Layer assignment for world_size=4:

| Rank | Role | 42M | 110M |
|---|---|---|---|
| R0 | Compute | layers [0, 3) | layers [0, 4) |
| R1 | Compute | layers [3, 6) | layers [4, 8) |
| R2 | Compute | layers [6, 8) | layers [8, 12) |
| R3 | Coord | embed + head | embed + head |

R3 holds both the embedding table and the classifier head because they share the same weight matrix (weight tying). Keeping both on one rank avoids a cross-Pi gradient reduction during training.
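
The layer ranges in the table are consistent with a simple rule: divide the layers over the three compute ranks, rounding up, and let the last rank take whatever remains. The sketch below reproduces the table under that assumption. It is not taken from `tools/shard_weights.py`, whose actual splitting logic may differ, and `layer_ranges` is a hypothetical name.

```python
import math


def layer_ranges(n_layers, world_size=4):
    """Half-open [start, end) layer range for each compute rank.

    Assumes the last rank is the coordinator (embed + head) and holds
    no transformer layers, as in the table above.
    """
    compute_ranks = world_size - 1
    per_rank = math.ceil(n_layers / compute_ranks)
    ranges = []
    for rank in range(compute_ranks):
        start = min(rank * per_rank, n_layers)
        end = min(start + per_rank, n_layers)
        ranges.append((start, end))
    return ranges


print(layer_ranges(8))   # 42M:  [(0, 3), (3, 6), (6, 8)]
print(layer_ranges(12))  # 110M: [(0, 4), (4, 8), (8, 12)]
```

With 8 layers the split is uneven, so R2 carries two layers while R0 and R1 carry three each; with 12 layers every compute rank carries four.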

SD card setup

Each Pi gets its own shard file via initramfs:

bash tools/setup-sd-distributed.sh 0 PIE0 42M   # rank 0
bash tools/setup-sd-distributed.sh 1 PIE1 42M   # rank 1
bash tools/setup-sd-distributed.sh 2 PIE2 42M   # rank 2
bash tools/setup-sd-distributed.sh 3 PIE3 42M   # rank 3

Running

cd examples
./run.sh generate-distributed    # 4-Pi inference
./run.sh train-distributed       # 4-Pi training

Logs stream to examples/logs/pi{0,1,2,3}.log in real time. During inference, the console shows only the output of the head rank (R3).