diff --git a/mlir/include/mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc b/mlir/include/mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc index 18a48591ccb8..37e9f2809aa7 100644 --- a/mlir/include/mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc +++ b/mlir/include/mlir/Dialect/Rock/Tuning/QuickTuningPerfconfigs.inc @@ -3796,6 +3796,46 @@ const StringRef PopulateParamsGemmGemm::initParametersI8AttentionGfx1152[] = { }; // END_ATTENTION_GemmGemm_i8_gfx1152_DEFS +// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx908_DEFS +const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx908[] = { + "attn:v3:128,128,16,16,16,16,16,8,4,1,2,0,1", + "attn:v3:128,128,32,32,32,16,16,4,4,1,2,0,1", + "attn:v3:128,128,128,32,32,32,16,4,4,1,2,0,1", + "attn:v3:128,128,32,8,16,32,16,16,4,1,2,0,1", + "attn:v3:32,256,16,8,16,16,16,8,1,1,2,0,1" +}; +// END_GEMM_GEMM_GemmGemm_f16_gfx908_DEFS + +// BEGIN_GEMM_GEMM_GemmGemm_f32_gfx908_DEFS +const StringRef PopulateParamsGemmGemm::initParametersF32GemmGemmGfx908[] = { + "attn:v3:64,128,32,16,16,16,16,4,4,1,2,0,1", + "attn:v3:64,128,32,32,16,16,16,4,4,1,2,0,1", + "attn:v3:64,64,128,8,64,16,16,8,4,1,2,0,1" +}; +// END_GEMM_GEMM_GemmGemm_f32_gfx908_DEFS + +// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1200_DEFS +const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1200[] = { + "attn:v3:128,128,64,32,16,16,16,4,4,1,2,0,1", + "attn:v3:64,64,32,16,16,16,16,8,4,2,2,0,1", + "attn:v3:128,128,16,4,16,16,16,8,4,1,2,0,1", + "attn:v3:128,128,32,4,16,16,16,8,1,1,2,0,1", + "attn:v3:128,128,32,8,16,16,16,8,1,1,2,0,1", + "attn:v3:128,128,64,16,32,32,16,8,1,1,2,0,1", + "attn:v3:32,32,32,8,16,16,16,8,1,2,2,0,1", + "attn:v3:64,256,32,16,16,16,16,8,1,1,2,0,1" +}; +// END_GEMM_GEMM_GemmGemm_f16_gfx1200_DEFS + +// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1100_DEFS +const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1100[] = { + "attn:v3:128,128,16,16,16,16,16,8,1,1,2,0,1", + "attn:v3:128,128,16,4,16,16,16,8,1,1,2,0,1", + 
"attn:v3:64,256,32,16,16,16,16,8,1,1,2,0,1", + "attn:v3:128,128,16,8,32,16,16,8,1,1,2,0,1", + "attn:v3:32,64,256,16,32,64,16,16,1,1,2,0,1" +}; +// END_GEMM_GEMM_GemmGemm_f16_gfx1100_DEFS // BEGIN_ATTENTION_GemmGemm_bf16_gfx1103_DEFS const StringRef PopulateParamsGemmGemm::initParametersBf16AttentionGfx1103[] = { "attn:v3:32,64,128,2,32,32,16,8,1,1,2,0,1", @@ -4016,6 +4056,25 @@ static constexpr size_t nInitParametersI8AttentionGfx1152 = 10; static const StringRef initParametersI8AttentionGfx1152[nInitParametersI8AttentionGfx1152]; // END_ATTENTION_GemmGemm_i8_gfx1152_DECS +// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx908_DECS +static constexpr size_t nInitParametersF16GemmGemmGfx908 = 5; +static const StringRef initParametersF16GemmGemmGfx908[nInitParametersF16GemmGemmGfx908]; +// END_GEMM_GEMM_GemmGemm_f16_gfx908_DECS + +// BEGIN_GEMM_GEMM_GemmGemm_f32_gfx908_DECS +static constexpr size_t nInitParametersF32GemmGemmGfx908 = 3; +static const StringRef initParametersF32GemmGemmGfx908[nInitParametersF32GemmGemmGfx908]; +// END_GEMM_GEMM_GemmGemm_f32_gfx908_DECS + +// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1200_DECS +static constexpr size_t nInitParametersF16GemmGemmGfx1200 = 8; +static const StringRef initParametersF16GemmGemmGfx1200[nInitParametersF16GemmGemmGfx1200]; +// END_GEMM_GEMM_GemmGemm_f16_gfx1200_DECS + +// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1100_DECS +static constexpr size_t nInitParametersF16GemmGemmGfx1100 = 5; +static const StringRef initParametersF16GemmGemmGfx1100[nInitParametersF16GemmGemmGfx1100]; +// END_GEMM_GEMM_GemmGemm_f16_gfx1100_DECS // BEGIN_ATTENTION_GemmGemm_bf16_gfx1103_DECS static constexpr size_t nInitParametersBf16AttentionGfx1103 = 13; static const StringRef initParametersBf16AttentionGfx1103[nInitParametersBf16AttentionGfx1103]; @@ -4295,6 +4354,14 @@ static const StringRef initParametersI8AttentionGfx1103[nInitParametersI8Attenti {"gfx1152_attention_i8", {PopulateParamsGemmGemm::initParametersI8AttentionGfx1152, 
PopulateParamsGemmGemm::nInitParametersI8AttentionGfx1152}}, +{"gfx908_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx908, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx908}}, + +{"gfx908_gemmelementwisegemm_f32", {PopulateParamsGemmGemm::initParametersF32GemmGemmGfx908, PopulateParamsGemmGemm::nInitParametersF32GemmGemmGfx908}}, + +{"gfx1200_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1200, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx1200}}, + +{"gfx1100_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1100, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx1100}}, + {"gfx1103_attention_bf16", {PopulateParamsGemmGemm::initParametersBf16AttentionGfx1103, PopulateParamsGemmGemm::nInitParametersBf16AttentionGfx1103}}, {"gfx1103_attention_f16", {PopulateParamsGemmGemm::initParametersF16AttentionGfx1103, PopulateParamsGemmGemm::nInitParametersF16AttentionGfx1103}}, diff --git a/mlir/utils/performance/analysis/quickTuningGen.py b/mlir/utils/performance/analysis/quickTuningGen.py index ed0b745363a5..96f86f1d90fc 100644 --- a/mlir/utils/performance/analysis/quickTuningGen.py +++ b/mlir/utils/performance/analysis/quickTuningGen.py @@ -24,6 +24,11 @@ 'TransQ', 'TransK', 'TransV', 'TransO', 'Causal', 'ReturnLSE', 'SplitKV', 'WithAttnScale', 'WithAttnBias', 'G', 'SeqLenQ', 'SeqLenK', 'NumHeadsQ', 'NumHeadsKV', 'HeadDimQK', 'HeadDimV' ] +GEMM_GEMM_COLUMNS = ['TransA', 'TransB', 'TransC', 'TransO', 'G', 'M', 'K', 'N', 'O'] +CONV_GEMM_COLUMNS = [ + 'FilterLayout', 'InputLayout', 'TransC', 'TransO', 'N', 'C', 'H', 'W', 'K', 'Y', 'X', + 'DilationH', 'DilationW', 'StrideH', 'StrideW', 'PaddingH', 'PaddingW', 'O' +] # Regex pattern for lookup table entries: {"arch_op_dtype", {Class::params, Class::count}}, // optional comment LOOKUP_ENTRY_PATTERN = re.compile(r'\{("(gfx\w+)_(\w+)_(\w+)"),\s*(\{[^}]+\})\},(\s*//[^\n]*)?') @@ -35,7 +40,9 @@ def get_instruction_type(arch, 
dtype, op): """Determine instruction type based on architecture, data type, and operation.""" - if op == "attention": + if op in ("attention", "gemm_gemm", "conv_gemm"): + if op == "gemm_gemm" and arch.startswith("gfx1") and dtype == "f32": + return "NonAccel" return "GemmGemm" if arch.startswith("gfx9"): return "XDL" @@ -55,9 +62,15 @@ def get_class_name(arch, dtype, op): return f"PopulateParams{instr}" if instr != "NonAccel" else "PopulateParams" +def _op_cap_for_param_name(op): + """Format op for C++ param name: gemm_gemm -> GemmGemm, attention -> Attention.""" + return "".join(part.capitalize() for part in op.split("_")) + + def get_param_names(arch, dtype, op): """Generate array and count variable names.""" - base = f"initParameters{dtype.capitalize()}{op.capitalize()}{arch.capitalize()}" + op_cap = _op_cap_for_param_name(op) + base = f"initParameters{dtype.capitalize()}{op_cap}{arch.capitalize()}" return base, f"n{base[0].upper()}{base[1:]}" @@ -69,6 +82,10 @@ def get_target_columns(op): return CONV_COLUMNS elif op == "attention": return ATTENTION_COLUMNS + elif op == "gemm_gemm": + return GEMM_GEMM_COLUMNS + elif op == "conv_gemm": + return CONV_GEMM_COLUMNS else: raise ValueError(f"Unknown operation: {op}") @@ -302,9 +319,23 @@ def add_lookup_entry(content, insert_marker, entry): return content[:insert_pos] + f'{entry}\n\n' + content[insert_pos:] +def get_lookup_key_op(op): + """Return the operation key used in the C++ lookup table (matches stringifyEnum(KernelType).lower()).""" + # C++ KernelType enum: Attention, GemmElementwiseGemm, ConvElementwiseGemm -> lower() + key_map = { + "attention": "attention", + "gemm_gemm": "gemmelementwisegemm", + "conv_gemm": "convelementwisegemm" + } + return key_map.get(op, op) + + def get_lookup_endif(arch, op, dtype): """Get the appropriate lookup table #endif marker.""" - if op == "attention": + # op may be script name (gemm_gemm) or C++ key form (gemmelementwisegemm) from .inc + gemm_gemm_ops = ("attention", 
"gemm_gemm", "conv_gemm", "gemmelementwisegemm", + "convelementwisegemm") + if op in gemm_gemm_ops: return "#endif // GemmGemm_LOOKUP_TABLE_GEN" elif is_accel(arch, dtype, op): return "#endif // Accel_LOOKUP_TABLE_GEN" @@ -350,7 +381,7 @@ def update_inc_file(results, arch, op): # Add lookup entry endif_marker = get_lookup_endif(arch, op, dtype) - key = f"{arch}_{op}_{dtype}" + key = f"{arch}_{get_lookup_key_op(op)}_{dtype}" value = f"{{{class_name}::{param_name}, {class_name}::{count_name}}}" entry = f'{{"{key}", {value}}},' content = add_lookup_entry(content, endif_marker, entry) @@ -449,7 +480,9 @@ def main(args=None): nargs='*', metavar='FILE', help='.debug files produced by tuningRunner.py (reads TSV from stdin if none provided)') - parser.add_argument('--op', choices=['gemm', 'conv', 'attention'], help='Operation') + parser.add_argument('--op', + choices=['gemm', 'conv', 'attention', 'gemm_gemm', 'conv_gemm'], + help='Operation') parser.add_argument('--th', type=float, default=0.93, diff --git a/mlir/utils/performance/perfRunner.py b/mlir/utils/performance/perfRunner.py index eb5716f5d063..7af1ceade2da 100644 --- a/mlir/utils/performance/perfRunner.py +++ b/mlir/utils/performance/perfRunner.py @@ -1784,8 +1784,8 @@ def run_config_with_mlir(config: PerfConfiguration, '--entry-point-result=void' ] profiler_cmd = [ROCPROF] + get_metric_args_for_rocprof(arch) + [ - '--kernel-trace', '--stats', '-f', 'csv', '-o', BENCHMARKING_RESULT_FILE_NAME, '--', - paths.mlir_paths.cpu_runner_path + '--kernel-trace', '--stats', '--output-format', 'csv', '-o', + BENCHMARKING_RESULT_FILE_NAME, '--', paths.mlir_paths.cpu_runner_path ] + mlir_cpu_runner_args outs, noerr = run_pipeline([rocmlir_gen_cmd.split(), rocmlir_driver_cmd, profiler_cmd]) @@ -2068,7 +2068,7 @@ def run_fusion_kernel(filename, rocmlir_gen_args, paths: Paths): '--entry-point-result=void' ] profiler_cmd = [ROCPROF] + get_metric_args_for_rocprof(chip) + [ - '--kernel-trace', '--stats', '-f', 'csv', '-o', 
BENCHMARKING_RESULT_FILE_NAME ] + ['--', paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args commands.append(profiler_cmd) outs, noerr = run_pipeline(commands) diff --git a/mlir/utils/performance/tuningRunner.py b/mlir/utils/performance/tuningRunner.py index 72b901e3bd64..69403ac2bc99 100755 --- a/mlir/utils/performance/tuningRunner.py +++ b/mlir/utils/performance/tuningRunner.py @@ -183,12 +183,33 @@ class Options: timeout: Optional[int] +def _is_navi_arch(arch: str) -> bool: +    """Return True if arch is Navi-class (any gfx1* prefix: gfx10xx/gfx11xx/gfx12xx).""" +    return arch.startswith("gfx1") + + +# Operations that have no f32 tuning support on Navi (any gfx1* arch) - empty tuning range +_F32_NAVI_UNSUPPORTED_OPS = frozenset( +    {'GemmGemmConfiguration', 'ConvGemmConfiguration', 'AttentionConfiguration'}) + + +def _should_skip_f32_on_navi(arch: str, test_vector: str, conf_class: type) -> bool: +    """Return True if this op is f32 on Navi and has no tuning support (empty range).""" +    if conf_class.__name__ not in _F32_NAVI_UNSUPPORTED_OPS: +        return False +    if not _is_navi_arch(arch): +        return False +    # Match -t f32 in the test vector (e.g. "-t f32 -transA" or " -t f32 ") +    return '-t f32' in test_vector + + @dataclass class TuningResult: """Result of tuning a single configuration.""" test_vector: str success: bool timed_out: bool = False + skipped: bool = False gpu_id: int = -1 duration_seconds: float = 0.0 timestamp: Optional[str] = None @@ -500,6 +521,12 @@ def set_succeeded(self, test_vector: str) -> None: self._state.remove(test_vector) self._save_locked() + def remove(self, test_vector: str) -> None: + """Remove test_vector from state (e.g.
when skipping without marking failed).""" + with self._lock: + self._state.remove(test_vector) + self._save_locked() + def finalize_interrupted(self) -> None: """Mark RUNNING configs as INTERRUPTED on clean shutdown.""" with self._lock: @@ -715,11 +742,14 @@ class ETATracker: success_times: List[float] = field(default_factory=list) ok_count: int = 0 fail_count: int = 0 + skip_count: int = 0 _processed: int = field(default=0, init=False) def record(self, result: TuningResult) -> None: self._processed += 1 - if result.success: + if result.skipped: + self.skip_count += 1 + elif result.success: self.ok_count += 1 self.success_times.append(result.duration_seconds) else: @@ -760,7 +790,10 @@ def get_postfix_str(self) -> str: rate = self._format_rate(median) eta = self._format_eta(eta_seconds) - return f"ok={self.ok_count}, fail={self.fail_count}, rate={rate}, eta={eta}" + postfix = f"ok={self.ok_count}, fail={self.fail_count}" + if self.skip_count > 0: + postfix += f", skip={self.skip_count}" + return f"{postfix}, rate={rate}, eta={eta}" @dataclass @@ -1153,7 +1186,7 @@ def verify_perfconfig(perfconfig: str, config: PerfConfiguration, paths: Paths, ] rocprof_command = [perfRunner.ROCPROF] + perfRunner.get_metric_args_for_rocprof( options.arch) + [ - '--kernel-trace', '--stats', '-f', 'csv', '-o', + '--kernel-trace', '--stats', '--output-format', 'csv', '-o', perfRunner.BENCHMARKING_RESULT_FILE_NAME, '--', paths.mlir_paths.cpu_runner_path ] + mlir_cpu_runner_args @@ -1520,6 +1553,13 @@ def execute_tuning_task(test_vector: str) -> TuningResult: state_file.set_running(test_vector) + if _should_skip_f32_on_navi(ctx.options.chip, test_vector, ctx.conf_class): + state_file.remove(test_vector) + return TuningResult(test_vector=test_vector, + success=False, + skipped=True, + gpu_id=gpu_id) + timestamp = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ') start_time = time.time() compile_threads = ctx.get_compile_threads(gpu_id) @@ -1550,6 +1590,16 @@ def 
execute_tuning_task(test_vector: str) -> TuningResult: results_writer.write_result(result) if debug_writer: debug_writer.write_result(result) + elif result.skipped: + skip_msg = (f"'{result.test_vector}' on GPU {result.gpu_id} " + "(f32 on Navi has no tuning support for this op)") + if sys.stderr.isatty(): + tqdm.write( + f"{_LOG_COLORS[logging.WARNING]}SKIPPED{_COLOR_RESET}: {skip_msg}", + file=sys.stderr, + ) + else: + tqdm.write(f"SKIPPED: {skip_msg}", file=sys.stderr) else: has_errors = True logger.error(