|
13 | 13 | import cutlass.pipeline |
14 | 14 | from cutlass._mlir.dialects import llvm |
15 | 15 | from cutlass._mlir import ir |
| 16 | + |
| 17 | +from quack.utils import make_vector |
16 | 18 | from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir |
17 | 19 |
|
18 | 20 |
|
@@ -1029,3 +1031,87 @@ def copy_fn(src_idx, dst_idx, tma_bar_ptr: cute.Pointer): |
1029 | 1031 | tma_gather4_load_fn(smem_ptr, tma_bar_ptr, col_idx, row_indices) |
1030 | 1032 |
|
1031 | 1033 | return copy_fn |
| 1034 | + |
| 1035 | + |
| 1036 | +# --------------------------------------------------------------------------- |
| 1037 | +# Store helpers |
| 1038 | +# --------------------------------------------------------------------------- |
| 1039 | + |
| 1040 | + |
@dsl_user_op
@cute.jit
def store(
    ptr: cute.Pointer,
    val,
    pred: Optional[Boolean] = None,
    cop: cutlass.Constexpr = None,
    *,
    loc=None,
    ip=None,
):
    """Write one scalar to ``ptr`` through ``cute.arch.store``.

    Args:
        ptr: Destination ``cute.Pointer`` (any address space); its
            ``llvm_ptr`` is what gets handed to ``cute.arch.store``.
        val: DSL Numeric value to write. It is re-wrapped as
            ``type(val)(val)`` immediately before the store.
        pred: ``None`` issues the store unconditionally. A DSL Boolean
            guards the store, which is skipped when ``pred`` is false.
        cop: Cache operator forwarded unchanged to ``cute.arch.store`` —
            "wb" (default), "cg", "cs" (streaming), "wt".
        loc: Forwarded unchanged to ``cute.arch.store``.
        ip: Forwarded unchanged to ``cute.arch.store``.
    """
    if const_expr(pred is None):
        # No predicate supplied: emit the store without any guard.
        cute.arch.store(
            ptr.llvm_ptr,
            type(val)(val),
            cop=cop,
            loc=loc,
            ip=ip,
        )
    elif pred:
        # Predicate supplied: only store when it evaluates true.
        cute.arch.store(
            ptr.llvm_ptr,
            type(val)(val),
            cop=cop,
            loc=loc,
            ip=ip,
        )
| 1064 | + |
| 1065 | + |
@dsl_user_op
@cute.jit
def store_v2(
    ptr: cute.Pointer,
    v0,
    v1,
    pred: Optional[Boolean] = None,
    cop: cutlass.Constexpr = None,
    *,
    loc=None,
    ip=None,
):
    """Write two elements to ``ptr`` with a single vector store.

    ``v0`` and ``v1`` are packed by ``make_vector`` (element type taken from
    ``type(v0)``) into an MLIR <2 x T> vector, which is then passed to
    ``cute.arch.store``.

    Args:
        ptr: Destination ``cute.Pointer`` (any address space); must be
            aligned for the vector width.
        v0: First element; its type selects the vector element type.
        v1: Second element.
        pred: ``None`` issues the store unconditionally. Otherwise the
            store is issued only when ``pred`` is true. The vector is
            built before the predicate is tested in either case.
        cop: Cache operator forwarded unchanged to ``cute.arch.store`` —
            "wb" (default), "cg", "cs" (streaming), "wt".
        loc: Forwarded to ``make_vector`` and ``cute.arch.store``.
        ip: Forwarded to ``make_vector`` and ``cute.arch.store``.
    """
    packed = make_vector(type(v0), v0, v1, loc=loc, ip=ip)
    if const_expr(pred is None):
        # No predicate supplied: emit the store without any guard.
        cute.arch.store(
            ptr.llvm_ptr,
            packed,
            cop=cop,
            loc=loc,
            ip=ip,
        )
    elif pred:
        # Predicate supplied: only store when it evaluates true.
        cute.arch.store(
            ptr.llvm_ptr,
            packed,
            cop=cop,
            loc=loc,
            ip=ip,
        )
| 1090 | + |
| 1091 | + |
@dsl_user_op
@cute.jit
def store_v4(
    ptr: cute.Pointer,
    v0,
    v1,
    v2,
    v3,
    pred: Optional[Boolean] = None,
    cop: cutlass.Constexpr = None,
    *,
    loc=None,
    ip=None,
):
    """Write four elements to ``ptr`` with a single vector store.

    ``v0`` through ``v3`` are packed by ``make_vector`` (element type taken
    from ``type(v0)``) into an MLIR <4 x T> vector, which is then passed to
    ``cute.arch.store``.

    Args:
        ptr: Destination ``cute.Pointer`` (any address space); must be
            aligned for the vector width.
        v0: First element; its type selects the vector element type.
        v1: Second element.
        v2: Third element.
        v3: Fourth element.
        pred: ``None`` issues the store unconditionally. Otherwise the
            store is issued only when ``pred`` is true. The vector is
            built before the predicate is tested in either case.
        cop: Cache operator forwarded unchanged to ``cute.arch.store`` —
            "wb" (default), "cg", "cs" (streaming), "wt".
        loc: Forwarded to ``make_vector`` and ``cute.arch.store``.
        ip: Forwarded to ``make_vector`` and ``cute.arch.store``.
    """
    packed = make_vector(type(v0), v0, v1, v2, v3, loc=loc, ip=ip)
    if const_expr(pred is None):
        # No predicate supplied: emit the store without any guard.
        cute.arch.store(
            ptr.llvm_ptr,
            packed,
            cop=cop,
            loc=loc,
            ip=ip,
        )
    elif pred:
        # Predicate supplied: only store when it evaluates true.
        cute.arch.store(
            ptr.llvm_ptr,
            packed,
            cop=cop,
            loc=loc,
            ip=ip,
        )
0 commit comments