Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ int64_t obtainBlockSize(int64_t waveSize, int64_t mPerBlock, int64_t nPerBlock,
/// Overload of obtainBlockSize that receives the tuning parameters bundled in
/// a RockAccelTuningParamAttrInterface instead of as individual values.
/// NOTE(review): presumably derives the block size from `waveSize` and the
/// per-block sizes carried by `params` — confirm against the definition.
int64_t obtainBlockSize(int64_t waveSize,
RockAccelTuningParamAttrInterface params);

/// Raw convolution dimensions for classifier features.
///
/// Plain aggregate: every size, stride and dilation field defaults to 1 and
/// every padding field defaults to 0. Field order is part of the interface
/// (aggregate initialization), so new fields should only be appended.
/// NOTE(review): the per-field descriptions are inferred from the field names —
/// confirm against the code that populates this struct.
struct ConvMeta {
  // Batch count.
  int64_t batchN{1};

  // Channel counts (presumably input channels C and output channels K).
  int64_t cChannels{1};
  int64_t kChannels{1};

  // Input height / width.
  int64_t inH{1};
  int64_t inW{1};

  // Filter height / width.
  int64_t filterH{1};
  int64_t filterW{1};

  // Padding along height / width.
  int64_t padH{0};
  int64_t padW{0};

  // Stride along height / width.
  int64_t strideH{1};
  int64_t strideW{1};

  // Dilation along height / width.
  int64_t dilH{1};
  int64_t dilW{1};
};

/// Store information useful for populating perf configurations
struct PopulateParamsInfo {
GemmSize gemmSize;
Expand All @@ -67,6 +78,7 @@ struct PopulateParamsInfo {
int64_t batchSize;
uint32_t numCu;
bool hasFusedReduction;
std::optional<ConvMeta> convMeta;

PopulateParamsInfo(GemmSize gemmSize, StringRef arch,
GemmFeatures gemmFeatures, Type gemmAType, Type gemmBType,
Expand Down
54 changes: 54 additions & 0 deletions mlir/include/mlir/Dialect/Rock/Tuning/QuickTuningClassifier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===- QuickTuningClassifier.h - XGBoost-based perfconfig ranking ---------===//
//
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Copyright (c) 2025 Advanced Micro Devices Inc.
//===----------------------------------------------------------------------===//
//
// This file declares the QuickTuningClassifier, which uses XGBoost models to
// rank quick-tune perfconfigs and select the top-N most likely performant ones
// for a given problem.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ROCK_TUNING_QUICK_TUNING_CLASSIFIER_H
#define MLIR_DIALECT_ROCK_TUNING_QUICK_TUNING_CLASSIFIER_H

#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
#include "mlir/Dialect/Rock/Tuning/GridwiseGemmGemmParams.h"
#include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"
#include "llvm/ADT/ArrayRef.h"
#include <vector>

namespace mlir {
namespace rock {

/// Groups the classifier-based ranking helpers for quick-tune perfconfigs.
/// Every member is static; no instance state is declared here.
class QuickTuningClassifier {
public:
  /// Returns the top-N setting taken from the ROCMLIR_QUICK_TUNE_TOP_N
  /// environment variable. The default is 30; a value of 0 disables the
  /// classifier.
  static unsigned getTopN();

  /// Uses the classifier to reduce XDL/WMMA `candidates` to the top-N entries
  /// for the problem described by `info`. If no model is found, or top-N is 0,
  /// the full candidate list is returned.
  static std::vector<AccelGemmParamsAttr>
  filterTopN(const PopulateParamsInfo &info,
             llvm::ArrayRef<AccelGemmParamsAttr> candidates);

  /// Reduces non-accel `candidates` to the top-N entries for the problem
  /// described by `info`.
  /// NOTE(review): presumably falls back to the full list like the accel
  /// overload — confirm against the implementation.
  static std::vector<GeneralGemmParamsAttr>
  filterTopN(const PopulateParamsInfo &info,
             llvm::ArrayRef<GeneralGemmParamsAttr> candidates);

  /// Reduces gemm-gemm (attention) `candidates` to the top-N entries for the
  /// given wrapper op `op`.
  /// NOTE(review): presumably falls back to the full list like the accel
  /// overload — confirm against the implementation.
  static std::vector<GemmGemmParamsAttr>
  filterTopN(RockGemmGemmWrapperInterface op,
             llvm::ArrayRef<GemmGemmParamsAttr> candidates);
};

} // namespace rock
} // namespace mlir

#endif // MLIR_DIALECT_ROCK_TUNING_QUICK_TUNING_CLASSIFIER_H
Loading
Loading