Skip to content

Commit 8636b38

Browse files
Merge remote-tracking branch 'origin/develop' into dpp-refactor-blockwise-reduce
2 parents 9b206aa + d2d62f3 commit 8636b38

50 files changed

Lines changed: 3236 additions & 1911 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/ci.yml

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Python Lint and Format Check
1+
name: rocMLIR GitHub Actions
22

33
on:
44
pull_request:
@@ -9,6 +9,8 @@ on:
99
- "mlir/**"
1010
- "external/**"
1111
- "!external/llvm-project/**"
12+
- ".github/workflows/**"
13+
- "pip_requirements.txt"
1214
push:
1315
branches:
1416
- develop
@@ -17,8 +19,11 @@ on:
1719
- "mlir/**"
1820
- "external/**"
1921
- "!external/llvm-project/**"
22+
- ".github/workflows/**"
23+
- "pip_requirements.txt"
2024
jobs:
21-
py-checks:
25+
format-and-lint-checks:
26+
name: Python format and lint checks
2227
runs-on: ubuntu-latest
2328
container:
2429
image: python:3.10
@@ -95,3 +100,30 @@ jobs:
95100
- name: No Python changes in mlir/
96101
if: steps.changes.outputs.files == ''
97102
run: echo "No changed *.py files under mlir/ – skipping."
103+
104+
python-tests:
105+
name: Python performance script tests
106+
runs-on: ubuntu-latest
107+
container:
108+
image: python:3.10
109+
options: --user root
110+
steps:
111+
- uses: actions/checkout@v4
112+
with:
113+
fetch-depth: 0
114+
115+
- name: Fix git ownership
116+
run: |
117+
git config --global --add safe.directory "$GITHUB_WORKSPACE"
118+
119+
- name: Install dependencies
120+
run: |
121+
python -m pip install --upgrade pip
122+
pip install -r pip_requirements.txt
123+
124+
- name: Run performance script tests (no GPU)
125+
run: |
126+
cd mlir/utils/performance && python -m pytest tests/ -v
127+
env:
128+
# Tests mock HIP/GPU; no ROCm or GPU required
129+
PYTHONPATH: ${{ github.workspace }}/mlir/utils/performance

mlir/include/mlir/Conversion/MIGraphXToLinalg/MIGraphXToLinalg.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ void populateMIGraphXToLinalgBoundaryDialectConversion(
3939
/// migraphx.mlir.as_logical_shape and migraphx.mlir.as_underlying_shape.
4040
void populateMIGraphXFuncBoundaryToLinalgConversionPatterns(
4141
RewritePatternSet &target, TypeConverter &typeConverter);
42+
43+
/// Populates conversion patterns for function boundaries mhal.launcher
44+
void populateMIGraphXToLinalgMHALLauncherConversion(
45+
RewritePatternSet &target, TypeConverter &typeConverter);
4246
} // namespace migraphx
4347
} // namespace mlir
4448

mlir/include/mlir/Conversion/RocMLIRPasses.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def MIGraphXToLinalgPass : Pass<"migraphx-to-linalg", "::mlir::func::FuncOp"> {
144144
}];
145145

146146
let dependentDialects = ["arith::ArithDialect", "tensor::TensorDialect",
147-
"linalg::LinalgDialect"];
147+
"linalg::LinalgDialect", "rock::RockDialect"];
148148
}
149149

150150
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ enum class MfmaTypeId : uint32_t {
3131
Fp8Fp8TyId,
3232
Fp8Bf8TyId,
3333
Bf8Fp8TyId,
34-
Bf8Bf8TyId
34+
Bf8Bf8TyId,
35+
// FP8 via scaled MFMA (uses mfma_scale_f32_16x16x128_f8f6f4 with cbsz=0)
36+
// These provide larger K dimension (128 for 16x16, 64 for 32x32)
37+
Fp8Fp8ScaledTyId,
38+
Fp8Bf8ScaledTyId,
39+
Bf8Fp8ScaledTyId,
40+
Bf8Bf8ScaledTyId
3541
};
3642

3743
struct MfmaInsnInfo {
@@ -71,7 +77,8 @@ class MfmaInsn {
7177
MfmaInsnAttr getAttr() const;
7278
Type getArgTypeFor(Type elementTypeA);
7379
VectorType getRetType(Type elementType);
74-
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock);
80+
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock,
81+
int64_t scheduleVersion);
7582
};
7683

7784
template <typename T>
@@ -138,7 +145,8 @@ class MfmaInsnGroup {
138145
public:
139146
static FailureOr<MfmaInsnGroup> select(Type elementTypeA, Type elementTypeB,
140147
StringRef arch, int64_t mnPerXdl,
141-
int64_t kPack, int64_t kPackPerBlock);
148+
int64_t kPack, int64_t kPackPerBlock,
149+
int64_t scheduleVersion);
142150
MfmaInsnGroup(Type elementTypeA, Type elementTypeB, const MfmaInsn &insn,
143151
const MfmaInsnGroupAttr &groupAttr);
144152
int64_t getMRepeats(int64_t mPerWave);
@@ -150,8 +158,13 @@ class MfmaInsnGroup {
150158
Type getArgTypeA();
151159
Type getArgTypeB();
152160
VectorType getRetType();
153-
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock);
161+
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock,
162+
int64_t scheduleVersion);
154163
SmallString<16> getROCDLIntrinsicName() { return groupAttr.insn; }
164+
165+
// Check if this is FP8 using scaled MFMA (mfma_scale with cbsz=0, blgp=0)
166+
// These instructions have larger K dimension (128 for 16x16, 64 for 32x32)
167+
bool isScaledFp8() const;
155168
};
156169

157170
} // namespace rock

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ def ConvOpBwdWeightType : I32EnumAttrCase<"BwdWeight", 2, "conv_bwd_weight">;
5959
def ConvOpTypes : Rock_I32Enum<"ConvOpType", "The type of a convolution operation",
6060
[ConvOpType, ConvOpBwdDataType, ConvOpBwdWeightType]>;
6161

62+
/// LinalgConvType
63+
def LinalgConv_1D : I32EnumAttrCase<"Conv1dNgchGkch", 0, "conv1d_ngch_gkch">;
64+
def LinalgConv_2D
65+
: I32EnumAttrCase<"Conv2dNgchwGkchw", 1, "conv2d_ngchw_gkchw">;
66+
def LinalgConv_3D
67+
: I32EnumAttrCase<"Conv3dNgchwdGkchwd", 2, "conv3d_ngchwd_gkchwd">;
68+
69+
def LinalgConvType
70+
: Rock_I32Enum<"LinalgConvType",
71+
"Hints for the linalg.generic convolution ops used by "
72+
"linalg-to-rock lowering",
73+
[LinalgConv_1D, LinalgConv_2D, LinalgConv_3D]>;
74+
75+
def LinalgConvTypeAttr
76+
: EnumAttr<Rock_Dialect, LinalgConvType, "LinalgConvType">;
77+
6278
/// Kerneltype
6379
def KernelTypeConv : I32EnumAttrCase<"Conv", 0>;
6480
def KernelTypeConvBwdData : I32EnumAttrCase<"ConvBwdData", 1>;

0 commit comments

Comments
 (0)