|
132 | 132 | "source": [ |
133 | 133 | "---\n", |
134 | 134 | "\n", |
135 | | - "## 练习 2: 性能测试 - BLOCK_SIZE 的影响\n", |
136 | | - "\n", |
137 | | - "**目标**:探索不同 `BLOCK_SIZE` 对性能的影响,找出最优配置\n", |
138 | | - "\n", |
139 | | - "这个练习帮助你理解为什么 Triton 的 `BLOCK_SIZE` 通常比 CUDA 的 `blockDim` 大得多。\n", |
140 | | - "\n", |
141 | | - "**测试方案**:\n", |
142 | | - "- 使用向量加法作为基准测试\n", |
143 | | - "- 测试不同的 `BLOCK_SIZE`: [128, 256, 512, 1024, 2048, 4096]\n", |
144 | | - "- 测量执行时间和内存带宽" |
145 | | - ] |
146 | | - }, |
147 | | - { |
148 | | - "cell_type": "code", |
149 | | - "execution_count": null, |
150 | | - "metadata": {}, |
151 | | - "outputs": [], |
152 | | - "source": [ |
153 | | - "# 向量加法 Kernel(用于性能测试)\n", |
154 | | - "@triton.jit\n", |
155 | | - "def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n", |
156 | | - " pid = tl.program_id(axis=0)\n", |
157 | | - " block_start = pid * BLOCK_SIZE\n", |
158 | | - " offsets = block_start + tl.arange(0, BLOCK_SIZE)\n", |
159 | | - " mask = offsets < n_elements\n", |
160 | | - " x = tl.load(x_ptr + offsets, mask=mask)\n", |
161 | | - " y = tl.load(y_ptr + offsets, mask=mask)\n", |
162 | | - " output = x + y\n", |
163 | | - " tl.store(output_ptr + offsets, output, mask=mask)" |
164 | | - ] |
165 | | - }, |
166 | | - { |
167 | | - "cell_type": "code", |
168 | | - "execution_count": null, |
169 | | - "metadata": {}, |
170 | | - "outputs": [], |
171 | | - "source": [ |
172 | | - "def benchmark_block_size(block_size, x, y, output, warmup=10, repeat=100):\n", |
173 | | - " \"\"\"基准测试单个 BLOCK_SIZE\"\"\"\n", |
174 | | - " n_elements = x.numel()\n", |
175 | | - " grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n", |
176 | | - " \n", |
177 | | - " # Warmup\n", |
178 | | - " for _ in range(warmup):\n", |
179 | | - " add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)\n", |
180 | | - " \n", |
181 | | - " # Timing\n", |
182 | | - " torch.cuda.synchronize()\n", |
183 | | - " start_event = torch.cuda.Event(enable_timing=True)\n", |
184 | | - " end_event = torch.cuda.Event(enable_timing=True)\n", |
185 | | - " \n", |
186 | | - " start_event.record()\n", |
187 | | - " for _ in range(repeat):\n", |
188 | | - " add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)\n", |
189 | | - " end_event.record()\n", |
190 | | - " \n", |
191 | | - " torch.cuda.synchronize()\n", |
192 | | - " time_ms = start_event.elapsed_time(end_event) / repeat\n", |
193 | | - " \n", |
194 | | - " # 计算带宽 (读 x, 读 y, 写 output)\n", |
195 | | - " total_bytes = 3 * n_elements * 4 # float32 = 4 bytes\n", |
196 | | - " bandwidth_gb_s = total_bytes / (time_ms * 1e-3) / 1e9\n", |
197 | | - " \n", |
198 | | - " return time_ms, bandwidth_gb_s" |
199 | | - ] |
200 | | - }, |
201 | | - { |
202 | | - "cell_type": "code", |
203 | | - "execution_count": null, |
204 | | - "metadata": {}, |
205 | | - "outputs": [], |
206 | | - "source": [ |
207 | | - "# 运行基准测试\n", |
208 | | - "size = 1024 * 1024 * 10 # 10M elements\n", |
209 | | - "x = torch.randn(size, device='cuda', dtype=torch.float32)\n", |
210 | | - "y = torch.randn(size, device='cuda', dtype=torch.float32)\n", |
211 | | - "output = torch.empty_like(x)\n", |
212 | | - "\n", |
213 | | - "block_sizes = [128, 256, 512, 1024, 2048, 4096]\n", |
214 | | - "results = []\n", |
215 | | - "\n", |
216 | | - "print(f\"{'BLOCK_SIZE':<15} {'Time (ms)':<15} {'Bandwidth (GB/s)':<20}\")\n", |
217 | | - "print(\"-\" * 50)\n", |
218 | | - "\n", |
219 | | - "for bs in block_sizes:\n", |
220 | | - " time_ms, bandwidth = benchmark_block_size(bs, x, y, output)\n", |
221 | | - " results.append((bs, time_ms, bandwidth))\n", |
222 | | - " print(f\"{bs:<15} {time_ms:<15.3f} {bandwidth:<20.2f}\")" |
223 | | - ] |
224 | | - }, |
225 | | - { |
226 | | - "cell_type": "code", |
227 | | - "execution_count": null, |
228 | | - "metadata": {}, |
229 | | - "outputs": [], |
230 | | - "source": [ |
231 | | - "# 可视化结果\n", |
232 | | - "block_sizes_list = [r[0] for r in results]\n", |
233 | | - "bandwidths = [r[2] for r in results]\n", |
234 | | - "\n", |
235 | | - "plt.figure(figsize=(10, 5))\n", |
236 | | - "plt.plot(block_sizes_list, bandwidths, marker='o', linewidth=2, markersize=8)\n", |
237 | | - "plt.xlabel('BLOCK_SIZE', fontsize=12)\n", |
238 | | - "plt.ylabel('Bandwidth (GB/s)', fontsize=12)\n", |
239 | | - "plt.title('Triton BLOCK_SIZE vs Memory Bandwidth', fontsize=14)\n", |
240 | | - "plt.grid(True, alpha=0.3)\n", |
241 | | - "plt.xscale('log', base=2)\n", |
242 | | - "plt.xticks(block_sizes_list, block_sizes_list)\n", |
243 | | - "\n", |
244 | | - "# 标注最佳 BLOCK_SIZE\n", |
245 | | - "best_idx = bandwidths.index(max(bandwidths))\n", |
246 | | - "plt.axvline(x=block_sizes_list[best_idx], color='r', linestyle='--', alpha=0.5)\n", |
247 | | - "plt.text(block_sizes_list[best_idx], max(bandwidths) * 0.95, \n", |
248 | | - " f'Best: {block_sizes_list[best_idx]}', ha='center', fontsize=10, color='r')\n", |
249 | | - "\n", |
250 | | - "plt.tight_layout()\n", |
251 | | - "plt.show()\n", |
252 | | - "\n", |
253 | | - "print(f\"\\n🏆 最优 BLOCK_SIZE: {block_sizes_list[best_idx]}\")\n", |
254 | | - "print(f\"🏆 最高带宽: {max(bandwidths):.2f} GB/s\")" |
255 | | - ] |
256 | | - }, |
257 | | - { |
258 | | - "cell_type": "markdown", |
259 | | - "metadata": {}, |
260 | | - "source": [ |
261 | | - "**思考题**:\n", |
262 | | - "1. 为什么 `BLOCK_SIZE=128` 性能较差?(提示:GPU 利用率)\n", |
263 | | - "2. 为什么 `BLOCK_SIZE=4096` 可能也不理想?(提示:寄存器压力)\n", |
264 | | - "3. 对比 CUDA 的 `blockDim.x` 常用值(256),Triton 的最优 `BLOCK_SIZE` 为什么更大?" |
265 | | - ] |
266 | | - }, |
267 | | - { |
268 | | - "cell_type": "markdown", |
269 | | - "metadata": {}, |
270 | | - "source": [ |
271 | | - "---\n", |
272 | | - "\n", |
273 | | - "## 练习 3: 1D 卷积(挑战)\n", |
| 135 | + "## 练习 2: 1D 卷积(挑战)\n", |
274 | 136 | "\n", |
275 | 137 | "**目标**:实现简单的 1D 卷积(3-tap box filter):$Y[i] = X[i-1] + X[i] + X[i+1]$\n", |
276 | 138 | "\n", |
|
304 | 166 | " # ==================== 在下方编写代码 ====================\n", |
305 | 167 | " \n", |
306 | 168 | " \n", |
307 | | - " \n", |
308 | 169 | " # ========================================================\n", |
309 | 170 | " pass\n", |
310 | 171 | "\n", |
|
351 | 212 | " print(f\"Torch: {y_torch[:5].cpu().numpy()}\")" |
352 | 213 | ] |
353 | 214 | }, |
354 | | - { |
355 | | - "cell_type": "code", |
356 | | - "execution_count": null, |
357 | | - "metadata": {}, |
358 | | - "outputs": [], |
359 | | - "source": [ |
360 | | - "# 可视化卷积效果(可选)\n", |
361 | | - "size = 100\n", |
362 | | - "x = torch.randn(size, device='cuda', dtype=torch.float32)\n", |
363 | | - "y = run_conv1d(x)\n", |
364 | | - "\n", |
365 | | - "plt.figure(figsize=(12, 5))\n", |
366 | | - "plt.plot(x.cpu().numpy(), label='Input', alpha=0.7)\n", |
367 | | - "plt.plot(y.cpu().numpy(), label='Output (Smoothed)', alpha=0.7, linewidth=2)\n", |
368 | | - "plt.xlabel('Index')\n", |
369 | | - "plt.ylabel('Value')\n", |
370 | | - "plt.title('1D Convolution: Box Filter (3-tap)')\n", |
371 | | - "plt.legend()\n", |
372 | | - "plt.grid(True, alpha=0.3)\n", |
373 | | - "plt.tight_layout()\n", |
374 | | - "plt.show()" |
375 | | - ] |
376 | | - }, |
377 | 215 | { |
378 | 216 | "cell_type": "markdown", |
379 | 217 | "metadata": {}, |
380 | 218 | "source": [ |
381 | 219 | "**思考题**(高级):\n", |
382 | 220 | "1. 为什么这种方法效率不高?(提示:重复加载)\n", |
383 | | - "2. 如何优化?(提示:Shared Memory 或加载更大的块然后切片)" |
| 221 | + "2. 如何优化?(提示:加载更大的块然后切片)" |
384 | 222 | ] |
385 | 223 | }, |
386 | 224 | { |
|
391 | 229 | "\n", |
392 | 230 | "## 总结\n", |
393 | 231 | "\n", |
394 | | - "完成这三个练习后,你应该:\n", |
395 | | - "- 掌握了 Triton kernel 的基本写法\n", |
396 | | - "- 理解了 `BLOCK_SIZE` 对性能的重要影响\n", |
397 | | - "- 学会了如何处理复杂的内存访问模式\n", |
| 232 | + "完成这两个练习后,你应该掌握了 Triton kernel 的基本写法\n", |
| 233 | + "\n", |
| 234 | + "**下一步**:学习 Triton 的 Shared Memory 和 Block Reduction 操作!\n", |
| 235 | + "\n", |
| 236 | + "## 课后答案\n", |
| 237 | + "\n", |
| 238 | + "```python\n", |
| 239 | + "@triton.jit\n", |
| 240 | + "def axpy_kernel(\n", |
| 241 | + " x_ptr, y_ptr, z_ptr,\n", |
| 242 | + " n_elements,\n", |
| 243 | + " alpha, # 标量参数\n", |
| 244 | + " BLOCK_SIZE: tl.constexpr\n", |
| 245 | + "):\n", |
| 246 | + " \"\"\"\n", |
| 247 | + " TODO: 实现 AXPY 操作\n", |
| 248 | + " 1. 计算 pid 和 offsets\n", |
| 249 | + " 2. 创建 mask\n", |
| 250 | + " 3. 加载 x 和 y\n", |
| 251 | + " 4. 计算 z = alpha * x + y\n", |
| 252 | + " 5. 存储 z\n", |
| 253 | + " \"\"\"\n", |
| 254 | + " # ==================== 在下方编写代码 ====================\n", |
| 255 | + " pid = tl.program_id(0)\n", |
| 256 | + " offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n", |
| 257 | + " mask = offsets < n_elements\n", |
| 258 | + " x = tl.load(x_ptr + offsets, mask=mask, other=0.0)\n", |
| 259 | + " y = tl.load(y_ptr + offsets, mask=mask, other=0.0)\n", |
| 260 | + " z = alpha * x + y\n", |
| 261 | + " tl.store(z_ptr + offsets, z, mask=mask)\n", |
| 262 | + " # ========================================================\n", |
398 | 263 | "\n", |
399 | | - "**下一步**:学习 Triton 的 Shared Memory 和 Block Reduction 操作!" |
| 264 | + "@triton.jit\n", |
| 265 | + "def conv1d_kernel(\n", |
| 266 | + " x_ptr, y_ptr,\n", |
| 267 | + " n_elements,\n", |
| 268 | + " BLOCK_SIZE: tl.constexpr\n", |
| 269 | + "):\n", |
| 270 | + " \"\"\"\n", |
| 271 | + " TODO: 实现 3-tap 1D 卷积\n", |
| 272 | + " Y[i] = X[i-1] + X[i] + X[i+1]\n", |
| 273 | + " \"\"\"\n", |
| 274 | + " # ==================== 在下方编写代码 ====================\n", |
| 275 | + " pid = tl.program_id(0)\n", |
| 276 | + " offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n", |
| 277 | + " mask = offsets < n_elements\n", |
| 278 | + " \n", |
| 279 | + " x_center = tl.load(x_ptr + offsets, mask=mask, other=0.0)\n", |
| 280 | + "    x_left = tl.load(x_ptr + offsets - 1, mask=mask & (offsets > 0), other=0.0)\n", |
| 281 | + " x_right = tl.load(x_ptr + offsets + 1, mask=offsets < n_elements - 1, other=0.0)\n", |
| 282 | + " \n", |
| 283 | + " y = x_left + x_center + x_right\n", |
| 284 | + " tl.store(y_ptr + offsets, y, mask=mask)\n", |
| 285 | + " # =========================================================\n", |
| 286 | + "```" |
400 | 287 | ] |
401 | 288 | } |
402 | 289 | ], |
|
0 commit comments