Commit 05c3478

[New Doc] add Triton Memory and Data Movement (#73)
1 parent 70260cb commit 05c3478

7 files changed, +1277 -18 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@
 ### Triton Series 💡
 
 + [Introduction to Triton Programming Paradigms](./docs/18_triton/01_triton_programming_paradigms/README.md)
++ [Triton Memory and Data Transfer](./docs/18_triton/02_triton_memory_and_data_movement/README.md)
 
 ### LLM Inference Techniques 🤖

docs/18_triton/01_triton_programming_paradigms/homework.ipynb

Lines changed: 4 additions & 18 deletions
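@@ -6,12 +6,7 @@
 "source": [
 "# Triton Programming Paradigms - Homework\n",
 "\n",
-"This notebook contains three exercises to help you consolidate Triton's core concepts.\n",
-"\n",
-"**Learning objectives**:\n",
-"- Master Triton's basic syntax and vectorized operations\n",
-"- Understand how `BLOCK_SIZE` affects performance\n",
-"- Learn to handle complex data access patterns in a vectorized way"
+"This notebook contains two exercises to help you consolidate Triton's core concepts"
 ]
 },
 {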
@@ -212,15 +207,6 @@
 " print(f\"Torch: {y_torch[:5].cpu().numpy()}\")"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"**Thought questions** (advanced):\n",
-"1. Why is this approach inefficient? (Hint: repeated loads)\n",
-"2. How can it be optimized? (Hint: load a larger block, then slice)"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -229,9 +215,9 @@
 "\n",
 "## Summary\n",
 "\n",
-"After completing these three exercises, you should have a grasp of the basics of writing a Triton kernel\n",
+"After completing these two exercises, you should have a grasp of the basics of writing a Triton kernel\n",
 "\n",
-"**Next step**: learn Triton's Shared Memory and Block Reduction operations!\n",
+"**Next step**: learn about Triton memory and data movement\n",
 "\n",
 "## Homework answers\n",
 "\n",
@@ -277,7 +263,7 @@
 " mask = offsets < n_elements\n",
 " \n",
 " x_center = tl.load(x_ptr + offsets, mask=mask, other=0.0)\n",
-" x_left = tl.load(x_ptr + offsets - 1, mask=offsets > 0, other=0.0)\n",
+" x_left = tl.load(x_ptr + offsets - 1, mask=mask & (offsets > 0), other=0.0)\n",
 " x_right = tl.load(x_ptr + offsets + 1, mask=offsets < n_elements - 1, other=0.0)\n",
 " \n",
 " y = x_left + x_center + x_right\n",

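The `x_left` mask change in the last hunk is easiest to see on the final block of the grid, where some lanes lie past `n_elements`. The following is a minimal pure-Python sketch of the index arithmetic only (no Triton or GPU involved). The values of `n_elements`, `BLOCK_SIZE`, and `pid`, and the helper `out_of_bounds_reads`, are illustrative assumptions and do not come from the notebook.

```python
# Emulates the offset arithmetic of the last program instance in a 1D grid.
# n_elements, BLOCK_SIZE and pid are illustrative values, not from the notebook.
n_elements = 10
BLOCK_SIZE = 8
pid = 1  # last block: covers offsets 8..15, but only 8 and 9 are valid

offsets = [pid * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]
mask = [o < n_elements for o in offsets]

# Old mask: only guards the left edge of the whole array.
old_left_mask = [o > 0 for o in offsets]
# New mask: additionally requires the lane itself to be in bounds.
new_left_mask = [m and (o > 0) for m, o in zip(mask, offsets)]


def out_of_bounds_reads(left_mask):
    """Return the indices (offset - 1) that an enabled lane would read
    outside the valid range [0, n_elements)."""
    return [
        o - 1
        for o, keep in zip(offsets, left_mask)
        if keep and not (0 <= o - 1 < n_elements)
    ]


old_oob = out_of_bounds_reads(old_left_mask)
new_oob = out_of_bounds_reads(new_left_mask)
print("old mask reads out of bounds at:", old_oob)
print("new mask reads out of bounds at:", new_oob)
```

With these values the old mask leaves lanes 11 through 15 enabled, so the left-neighbor load would touch indices 10 through 14, which lie past the end of the array. The new mask enables only lanes 8 and 9, and the out-of-bounds list is empty. The `x_right` mask needs no such change, because `offsets < n_elements - 1` already implies `offsets < n_elements`.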