Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3796,6 +3796,46 @@ const StringRef PopulateParamsGemmGemm::initParametersI8AttentionGfx1152[] = {
};
// END_ATTENTION_GemmGemm_i8_gfx1152_DEFS

// Quick-tuning perf configs ("attn:v3:..." strings) for f16 gemm+gemm on gfx908.
// The BEGIN/END-delimited regions below are maintained by quickTuningGen.py;
// do not hand-edit inside the markers. Each array's length must match the
// corresponding nInitParameters* count in the DECS section.
// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx908_DEFS
const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx908[] = {
"attn:v3:128,128,16,16,16,16,16,8,4,1,2,0,1",
"attn:v3:128,128,32,32,32,16,16,4,4,1,2,0,1",
"attn:v3:128,128,128,32,32,32,16,4,4,1,2,0,1",
"attn:v3:128,128,32,8,16,32,16,16,4,1,2,0,1",
"attn:v3:32,256,16,8,16,16,16,8,1,1,2,0,1"
};
// END_GEMM_GEMM_GemmGemm_f16_gfx908_DEFS

// Quick-tuning perf configs for f32 gemm+gemm on gfx908 (generated region).
// BEGIN_GEMM_GEMM_GemmGemm_f32_gfx908_DEFS
const StringRef PopulateParamsGemmGemm::initParametersF32GemmGemmGfx908[] = {
"attn:v3:64,128,32,16,16,16,16,4,4,1,2,0,1",
"attn:v3:64,128,32,32,16,16,16,4,4,1,2,0,1",
"attn:v3:64,64,128,8,64,16,16,8,4,1,2,0,1"
};
// END_GEMM_GEMM_GemmGemm_f32_gfx908_DEFS

// Quick-tuning perf configs for f16 gemm+gemm on gfx1200 (generated region).
// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1200_DEFS
const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1200[] = {
"attn:v3:128,128,64,32,16,16,16,4,4,1,2,0,1",
"attn:v3:64,64,32,16,16,16,16,8,4,2,2,0,1",
"attn:v3:128,128,16,4,16,16,16,8,4,1,2,0,1",
"attn:v3:128,128,32,4,16,16,16,8,1,1,2,0,1",
"attn:v3:128,128,32,8,16,16,16,8,1,1,2,0,1",
"attn:v3:128,128,64,16,32,32,16,8,1,1,2,0,1",
"attn:v3:32,32,32,8,16,16,16,8,1,2,2,0,1",
"attn:v3:64,256,32,16,16,16,16,8,1,1,2,0,1"
};
// END_GEMM_GEMM_GemmGemm_f16_gfx1200_DEFS

// Quick-tuning perf configs for f16 gemm+gemm on gfx1100 (generated region).
// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1100_DEFS
const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1100[] = {
"attn:v3:128,128,16,16,16,16,16,8,1,1,2,0,1",
"attn:v3:128,128,16,4,16,16,16,8,1,1,2,0,1",
"attn:v3:64,256,32,16,16,16,16,8,1,1,2,0,1",
"attn:v3:128,128,16,8,32,16,16,8,1,1,2,0,1",
"attn:v3:32,64,256,16,32,64,16,16,1,1,2,0,1"
};
// END_GEMM_GEMM_GemmGemm_f16_gfx1100_DEFS
// BEGIN_ATTENTION_GemmGemm_bf16_gfx1103_DEFS
const StringRef PopulateParamsGemmGemm::initParametersBf16AttentionGfx1103[] = {
"attn:v3:32,64,128,2,32,32,16,8,1,1,2,0,1",
Expand Down Expand Up @@ -4016,6 +4056,25 @@ static constexpr size_t nInitParametersI8AttentionGfx1152 = 10;
static const StringRef initParametersI8AttentionGfx1152[nInitParametersI8AttentionGfx1152];
// END_ATTENTION_GemmGemm_i8_gfx1152_DECS

// Declarations and element counts for the gemm+gemm quick-tuning tables.
// Counts are hard-coded and must match the corresponding definition arrays
// (f16/gfx908: 5, f32/gfx908: 3, f16/gfx1200: 8, f16/gfx1100: 5); both sides
// are regenerated together by quickTuningGen.py -- do not edit inside markers.
// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx908_DECS
static constexpr size_t nInitParametersF16GemmGemmGfx908 = 5;
static const StringRef initParametersF16GemmGemmGfx908[nInitParametersF16GemmGemmGfx908];
// END_GEMM_GEMM_GemmGemm_f16_gfx908_DECS

// BEGIN_GEMM_GEMM_GemmGemm_f32_gfx908_DECS
static constexpr size_t nInitParametersF32GemmGemmGfx908 = 3;
static const StringRef initParametersF32GemmGemmGfx908[nInitParametersF32GemmGemmGfx908];
// END_GEMM_GEMM_GemmGemm_f32_gfx908_DECS

// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1200_DECS
static constexpr size_t nInitParametersF16GemmGemmGfx1200 = 8;
static const StringRef initParametersF16GemmGemmGfx1200[nInitParametersF16GemmGemmGfx1200];
// END_GEMM_GEMM_GemmGemm_f16_gfx1200_DECS

// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1100_DECS
static constexpr size_t nInitParametersF16GemmGemmGfx1100 = 5;
static const StringRef initParametersF16GemmGemmGfx1100[nInitParametersF16GemmGemmGfx1100];
// END_GEMM_GEMM_GemmGemm_f16_gfx1100_DECS
// BEGIN_ATTENTION_GemmGemm_bf16_gfx1103_DECS
static constexpr size_t nInitParametersBf16AttentionGfx1103 = 13;
static const StringRef initParametersBf16AttentionGfx1103[nInitParametersBf16AttentionGfx1103];
Expand Down Expand Up @@ -4295,6 +4354,14 @@ static const StringRef initParametersI8AttentionGfx1103[nInitParametersI8Attenti

{"gfx1152_attention_i8", {PopulateParamsGemmGemm::initParametersI8AttentionGfx1152, PopulateParamsGemmGemm::nInitParametersI8AttentionGfx1152}},

{"gfx908_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx908, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx908}},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have i8 and bf16 gemm+gemm configs?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is all that has been generated so far


{"gfx908_gemmelementwisegemm_f32", {PopulateParamsGemmGemm::initParametersF32GemmGemmGfx908, PopulateParamsGemmGemm::nInitParametersF32GemmGemmGfx908}},

{"gfx1200_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1200, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx1200}},

{"gfx1100_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1100, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx1100}},

Comment on lines +4357 to +4364
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have bf16 and i8 gemm+gemm configs?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is all that has been generated so far

{"gfx1103_attention_bf16", {PopulateParamsGemmGemm::initParametersBf16AttentionGfx1103, PopulateParamsGemmGemm::nInitParametersBf16AttentionGfx1103}},

{"gfx1103_attention_f16", {PopulateParamsGemmGemm::initParametersF16AttentionGfx1103, PopulateParamsGemmGemm::nInitParametersF16AttentionGfx1103}},
Expand Down
43 changes: 38 additions & 5 deletions mlir/utils/performance/analysis/quickTuningGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
'TransQ', 'TransK', 'TransV', 'TransO', 'Causal', 'ReturnLSE', 'SplitKV', 'WithAttnScale',
'WithAttnBias', 'G', 'SeqLenQ', 'SeqLenK', 'NumHeadsQ', 'NumHeadsKV', 'HeadDimQK', 'HeadDimV'
]
# Problem-description columns for gemm+gemm configurations
# (returned by get_target_columns for op == "gemm_gemm").
GEMM_GEMM_COLUMNS = ['TransA', 'TransB', 'TransC', 'TransO', 'G', 'M', 'K', 'N', 'O']
# Problem-description columns for conv+gemm configurations
# (returned by get_target_columns for op == "conv_gemm").
CONV_GEMM_COLUMNS = [
    'FilterLayout', 'InputLayout', 'TransC', 'TransO', 'N', 'C', 'H', 'W', 'K', 'Y', 'X',
    'DilationH', 'DilationW', 'StrideH', 'StrideW', 'PaddingH', 'PaddingW', 'O'
]

# Regex pattern for lookup table entries: {"arch_op_dtype", {Class::params, Class::count}}, // optional comment
LOOKUP_ENTRY_PATTERN = re.compile(r'\{("(gfx\w+)_(\w+)_(\w+)"),\s*(\{[^}]+\})\},(\s*//[^\n]*)?')
Expand All @@ -35,7 +40,9 @@

def get_instruction_type(arch, dtype, op):
"""Determine instruction type based on architecture, data type, and operation."""
if op == "attention":
if op in ("attention", "gemm_gemm", "conv_gemm"):
if op == "gemm_gemm" and arch.startswith("gfx1") and dtype == "f32":
return "NonAccel"
Comment on lines +44 to +45
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Navi4x/gfx12, doesn't it require nonaccel as well ?

Comment on lines +44 to +45
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

f32 on Navi for "GemmGemm" instruction types is unsupported. Not sure if it's appropriate to return NonAccel here.

return "GemmGemm"
if arch.startswith("gfx9"):
return "XDL"
Expand All @@ -55,9 +62,15 @@ def get_class_name(arch, dtype, op):
return f"PopulateParams{instr}" if instr != "NonAccel" else "PopulateParams"


def _op_cap_for_param_name(op):
"""Format op for C++ param name: gemm_gemm -> GemmGemm, attention -> Attention."""
return "".join(part.capitalize() for part in op.split("_"))


def get_param_names(arch, dtype, op):
"""Generate array and count variable names."""
base = f"initParameters{dtype.capitalize()}{op.capitalize()}{arch.capitalize()}"
op_cap = _op_cap_for_param_name(op)
base = f"initParameters{dtype.capitalize()}{op_cap}{arch.capitalize()}"
return base, f"n{base[0].upper()}{base[1:]}"


Expand All @@ -69,6 +82,10 @@ def get_target_columns(op):
return CONV_COLUMNS
elif op == "attention":
return ATTENTION_COLUMNS
elif op == "gemm_gemm":
return GEMM_GEMM_COLUMNS
elif op == "conv_gemm":
return CONV_GEMM_COLUMNS
else:
raise ValueError(f"Unknown operation: {op}")

Expand Down Expand Up @@ -302,9 +319,23 @@ def add_lookup_entry(content, insert_marker, entry):
return content[:insert_pos] + f'{entry}\n\n' + content[insert_pos:]


def get_lookup_key_op(op):
    """Return the operation key used in the C++ lookup table (matches stringifyEnum(KernelType).lower())."""
    # C++ KernelType enum spells these as Attention, GemmElementwiseGemm and
    # ConvElementwiseGemm; the lookup key is the lower-cased enum name.
    translations = {
        "attention": "attention",
        "gemm_gemm": "gemmelementwisegemm",
        "conv_gemm": "convelementwisegemm",
    }
    # Ops without a dedicated translation (e.g. "gemm", "conv") pass through.
    return translations.get(op, op)


def get_lookup_endif(arch, op, dtype):
"""Get the appropriate lookup table #endif marker."""
if op == "attention":
# op may be script name (gemm_gemm) or C++ key form (gemmelementwisegemm) from .inc
gemm_gemm_ops = ("attention", "gemm_gemm", "conv_gemm", "gemmelementwisegemm",
"convelementwisegemm")
if op in gemm_gemm_ops:
return "#endif // GemmGemm_LOOKUP_TABLE_GEN"
elif is_accel(arch, dtype, op):
return "#endif // Accel_LOOKUP_TABLE_GEN"
Expand Down Expand Up @@ -350,7 +381,7 @@ def update_inc_file(results, arch, op):

# Add lookup entry
endif_marker = get_lookup_endif(arch, op, dtype)
key = f"{arch}_{op}_{dtype}"
key = f"{arch}_{get_lookup_key_op(op)}_{dtype}"
value = f"{{{class_name}::{param_name}, {class_name}::{count_name}}}"
entry = f'{{"{key}", {value}}},'
content = add_lookup_entry(content, endif_marker, entry)
Expand Down Expand Up @@ -449,7 +480,9 @@ def main(args=None):
nargs='*',
metavar='FILE',
help='.debug files produced by tuningRunner.py (reads TSV from stdin if none provided)')
parser.add_argument('--op', choices=['gemm', 'conv', 'attention'], help='Operation')
parser.add_argument('--op',
choices=['gemm', 'conv', 'attention', 'gemm_gemm', 'conv_gemm'],
help='Operation')
parser.add_argument('--th',
type=float,
default=0.93,
Expand Down
6 changes: 3 additions & 3 deletions mlir/utils/performance/perfRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,8 +1784,8 @@ def run_config_with_mlir(config: PerfConfiguration,
'--entry-point-result=void'
]
profiler_cmd = [ROCPROF] + get_metric_args_for_rocprof(arch) + [
'--kernel-trace', '--stats', '-f', 'csv', '-o', BENCHMARKING_RESULT_FILE_NAME, '--',
paths.mlir_paths.cpu_runner_path
'--kernel-trace', '--stats', '--output-format', 'csv', '-o',
BENCHMARKING_RESULT_FILE_NAME, '--', paths.mlir_paths.cpu_runner_path
] + mlir_cpu_runner_args

outs, noerr = run_pipeline([rocmlir_gen_cmd.split(), rocmlir_driver_cmd, profiler_cmd])
Expand Down Expand Up @@ -2068,7 +2068,7 @@ def run_fusion_kernel(filename, rocmlir_gen_args, paths: Paths):
'--entry-point-result=void'
]
profiler_cmd = [ROCPROF] + get_metric_args_for_rocprof(chip) + [
'--kernel-trace', '--stats', '-f', 'csv', '-o', BENCHMARKING_RESULT_FILE_NAME
'--kernel-trace', '--stats', '--output-format', 'csv', '-o', BENCHMARKING_RESULT_FILE_NAME
] + ['--', paths.mlir_paths.cpu_runner_path] + mlir_cpu_runner_args
commands.append(profiler_cmd)
outs, noerr = run_pipeline(commands)
Expand Down
56 changes: 53 additions & 3 deletions mlir/utils/performance/tuningRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,33 @@ class Options:
timeout: Optional[int]


def _is_navi_arch(arch: str) -> bool:
"""Return True if arch is Navi (gfx11xx or gfx12xx)."""
return arch.startswith("gfx1")


# Operations that have no f32 tuning support on Navi (gfx11xx/gfx12xx) - empty tuning range
_F32_NAVI_UNSUPPORTED_OPS = frozenset(
{'GemmGemmConfiguration', 'ConvGemmConfiguration', 'AttentionConfiguration'})


def _should_skip_f32_on_navi(arch: str, test_vector: str, conf_class: type) -> bool:
"""Return True if this op is f32 on Navi and has no tuning support (empty range)."""
if conf_class.__name__ not in _F32_NAVI_UNSUPPORTED_OPS:
return False
if not _is_navi_arch(arch):
return False
# Match -t f32 in the test vector (e.g. "-t f32 -transA" or " -t f32 ")
return '-t f32' in test_vector
Comment on lines +202 to +203
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this work for GemmGemm problem configs as well ?



@dataclass
class TuningResult:
"""Result of tuning a single configuration."""
test_vector: str
success: bool
timed_out: bool = False
skipped: bool = False
gpu_id: int = -1
duration_seconds: float = 0.0
timestamp: Optional[str] = None
Expand Down Expand Up @@ -500,6 +521,12 @@ def set_succeeded(self, test_vector: str) -> None:
self._state.remove(test_vector)
self._save_locked()

def remove(self, test_vector: str) -> None:
    """Remove test_vector from state (e.g. when skipping without marking failed).

    Unlike set_succeeded, this drops the entry without recording any outcome.
    """
    # Mutate and persist under the lock so concurrent workers see a
    # consistent state file.
    with self._lock:
        self._state.remove(test_vector)
        self._save_locked()

def finalize_interrupted(self) -> None:
"""Mark RUNNING configs as INTERRUPTED on clean shutdown."""
with self._lock:
Expand Down Expand Up @@ -715,11 +742,14 @@ class ETATracker:
success_times: List[float] = field(default_factory=list)
ok_count: int = 0
fail_count: int = 0
skip_count: int = 0
_processed: int = field(default=0, init=False)

def record(self, result: TuningResult) -> None:
self._processed += 1
if result.success:
if result.skipped:
self.skip_count += 1
elif result.success:
self.ok_count += 1
self.success_times.append(result.duration_seconds)
else:
Expand Down Expand Up @@ -760,7 +790,10 @@ def get_postfix_str(self) -> str:
rate = self._format_rate(median)
eta = self._format_eta(eta_seconds)

return f"ok={self.ok_count}, fail={self.fail_count}, rate={rate}, eta={eta}"
postfix = f"ok={self.ok_count}, fail={self.fail_count}"
if self.skip_count > 0:
postfix += f", skip={self.skip_count}"
return f"{postfix}, rate={rate}, eta={eta}"


@dataclass
Expand Down Expand Up @@ -1153,7 +1186,7 @@ def verify_perfconfig(perfconfig: str, config: PerfConfiguration, paths: Paths,
]
rocprof_command = [perfRunner.ROCPROF] + perfRunner.get_metric_args_for_rocprof(
options.arch) + [
'--kernel-trace', '--stats', '-f', 'csv', '-o',
'--kernel-trace', '--stats', '--output-format', 'csv', '-o',
perfRunner.BENCHMARKING_RESULT_FILE_NAME, '--', paths.mlir_paths.cpu_runner_path
] + mlir_cpu_runner_args

Expand Down Expand Up @@ -1520,6 +1553,13 @@ def execute_tuning_task(test_vector: str) -> TuningResult:

state_file.set_running(test_vector)

if _should_skip_f32_on_navi(ctx.options.chip, test_vector, ctx.conf_class):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to filter out these configs in the beginning and print something like "Skipping N unsupported configs". It would simplify the rest of the code.

state_file.remove(test_vector)
return TuningResult(test_vector=test_vector,
success=False,
skipped=True,
gpu_id=gpu_id)

timestamp = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
start_time = time.time()
compile_threads = ctx.get_compile_threads(gpu_id)
Expand Down Expand Up @@ -1550,6 +1590,16 @@ def execute_tuning_task(test_vector: str) -> TuningResult:
results_writer.write_result(result)
if debug_writer:
debug_writer.write_result(result)
elif result.skipped:
skip_msg = (f"'{result.test_vector}' on GPU {result.gpu_id} "
"(f32 on Navi has no tuning support for this op)")
if sys.stderr.isatty():
tqdm.write(
f"{_LOG_COLORS[logging.WARNING]}SKIPPED{_COLOR_RESET}: {skip_msg}",
file=sys.stderr,
)
else:
tqdm.write(f"SKIPPED: {skip_msg}", file=sys.stderr)
else:
has_errors = True
logger.error(
Expand Down
Loading