GEMM+GEMM and CONV+GEMM support to quickTuningGen and GEMM+GEMM quick tuning list #2262
base: develop
Changes from all commits
902e56b
25e21c1
b65f493
64cab37
892f6d9
10ba5ef
e009729
6efb66d
5f35955
68a466e
5b0a41e
4d2a00f
6be3f3c
ee4ee61
16a57b9
6558492
```diff
@@ -3796,6 +3796,46 @@ const StringRef PopulateParamsGemmGemm::initParametersI8AttentionGfx1152[] = {
 };
 // END_ATTENTION_GemmGemm_i8_gfx1152_DEFS

+// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx908_DEFS
+const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx908[] = {
+    "attn:v3:128,128,16,16,16,16,16,8,4,1,2,0,1",
+    "attn:v3:128,128,32,32,32,16,16,4,4,1,2,0,1",
+    "attn:v3:128,128,128,32,32,32,16,4,4,1,2,0,1",
+    "attn:v3:128,128,32,8,16,32,16,16,4,1,2,0,1",
+    "attn:v3:32,256,16,8,16,16,16,8,1,1,2,0,1"
+};
+// END_GEMM_GEMM_GemmGemm_f16_gfx908_DEFS
+
+// BEGIN_GEMM_GEMM_GemmGemm_f32_gfx908_DEFS
+const StringRef PopulateParamsGemmGemm::initParametersF32GemmGemmGfx908[] = {
+    "attn:v3:64,128,32,16,16,16,16,4,4,1,2,0,1",
+    "attn:v3:64,128,32,32,16,16,16,4,4,1,2,0,1",
+    "attn:v3:64,64,128,8,64,16,16,8,4,1,2,0,1"
+};
+// END_GEMM_GEMM_GemmGemm_f32_gfx908_DEFS
+
+// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1200_DEFS
+const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1200[] = {
+    "attn:v3:128,128,64,32,16,16,16,4,4,1,2,0,1",
+    "attn:v3:64,64,32,16,16,16,16,8,4,2,2,0,1",
+    "attn:v3:128,128,16,4,16,16,16,8,4,1,2,0,1",
+    "attn:v3:128,128,32,4,16,16,16,8,1,1,2,0,1",
+    "attn:v3:128,128,32,8,16,16,16,8,1,1,2,0,1",
+    "attn:v3:128,128,64,16,32,32,16,8,1,1,2,0,1",
+    "attn:v3:32,32,32,8,16,16,16,8,1,2,2,0,1",
+    "attn:v3:64,256,32,16,16,16,16,8,1,1,2,0,1"
+};
+// END_GEMM_GEMM_GemmGemm_f16_gfx1200_DEFS
+
+// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1100_DEFS
+const StringRef PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1100[] = {
+    "attn:v3:128,128,16,16,16,16,16,8,1,1,2,0,1",
+    "attn:v3:128,128,16,4,16,16,16,8,1,1,2,0,1",
+    "attn:v3:64,256,32,16,16,16,16,8,1,1,2,0,1",
+    "attn:v3:128,128,16,8,32,16,16,8,1,1,2,0,1",
+    "attn:v3:32,64,256,16,32,64,16,16,1,1,2,0,1"
+};
+// END_GEMM_GEMM_GemmGemm_f16_gfx1100_DEFS
 // BEGIN_ATTENTION_GemmGemm_bf16_gfx1103_DEFS
 const StringRef PopulateParamsGemmGemm::initParametersBf16AttentionGfx1103[] = {
     "attn:v3:32,64,128,2,32,32,16,8,1,1,2,0,1",
```
```diff
@@ -4016,6 +4056,25 @@ static constexpr size_t nInitParametersI8AttentionGfx1152 = 10;
 static const StringRef initParametersI8AttentionGfx1152[nInitParametersI8AttentionGfx1152];
 // END_ATTENTION_GemmGemm_i8_gfx1152_DECS

+// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx908_DECS
+static constexpr size_t nInitParametersF16GemmGemmGfx908 = 5;
+static const StringRef initParametersF16GemmGemmGfx908[nInitParametersF16GemmGemmGfx908];
+// END_GEMM_GEMM_GemmGemm_f16_gfx908_DECS
+
+// BEGIN_GEMM_GEMM_GemmGemm_f32_gfx908_DECS
+static constexpr size_t nInitParametersF32GemmGemmGfx908 = 3;
+static const StringRef initParametersF32GemmGemmGfx908[nInitParametersF32GemmGemmGfx908];
+// END_GEMM_GEMM_GemmGemm_f32_gfx908_DECS
+
+// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1200_DECS
+static constexpr size_t nInitParametersF16GemmGemmGfx1200 = 8;
+static const StringRef initParametersF16GemmGemmGfx1200[nInitParametersF16GemmGemmGfx1200];
+// END_GEMM_GEMM_GemmGemm_f16_gfx1200_DECS
+
+// BEGIN_GEMM_GEMM_GemmGemm_f16_gfx1100_DECS
+static constexpr size_t nInitParametersF16GemmGemmGfx1100 = 5;
+static const StringRef initParametersF16GemmGemmGfx1100[nInitParametersF16GemmGemmGfx1100];
+// END_GEMM_GEMM_GemmGemm_f16_gfx1100_DECS
 // BEGIN_ATTENTION_GemmGemm_bf16_gfx1103_DECS
 static constexpr size_t nInitParametersBf16AttentionGfx1103 = 13;
 static const StringRef initParametersBf16AttentionGfx1103[nInitParametersBf16AttentionGfx1103];
```
```diff
@@ -4295,6 +4354,14 @@ static const StringRef initParametersI8AttentionGfx1103[nInitParametersI8Attenti

 {"gfx1152_attention_i8", {PopulateParamsGemmGemm::initParametersI8AttentionGfx1152, PopulateParamsGemmGemm::nInitParametersI8AttentionGfx1152}},

+{"gfx908_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx908, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx908}},
+
+{"gfx908_gemmelementwisegemm_f32", {PopulateParamsGemmGemm::initParametersF32GemmGemmGfx908, PopulateParamsGemmGemm::nInitParametersF32GemmGemmGfx908}},
+
+{"gfx1200_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1200, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx1200}},
+
+{"gfx1100_gemmelementwisegemm_f16", {PopulateParamsGemmGemm::initParametersF16GemmGemmGfx1100, PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx1100}},
+
 {"gfx1103_attention_bf16", {PopulateParamsGemmGemm::initParametersBf16AttentionGfx1103, PopulateParamsGemmGemm::nInitParametersBf16AttentionGfx1103}},

 {"gfx1103_attention_f16", {PopulateParamsGemmGemm::initParametersF16AttentionGfx1103, PopulateParamsGemmGemm::nInitParametersF16AttentionGfx1103}},
```

Comment on lines +4357 to +4364
Contributor: Do we have bf16 and i8 gemm+gemm configs?
Author: This was all that was generated so far.
```diff
@@ -24,6 +24,11 @@
     'TransQ', 'TransK', 'TransV', 'TransO', 'Causal', 'ReturnLSE', 'SplitKV', 'WithAttnScale',
     'WithAttnBias', 'G', 'SeqLenQ', 'SeqLenK', 'NumHeadsQ', 'NumHeadsKV', 'HeadDimQK', 'HeadDimV'
 ]
+GEMM_GEMM_COLUMNS = ['TransA', 'TransB', 'TransC', 'TransO', 'G', 'M', 'K', 'N', 'O']
+CONV_GEMM_COLUMNS = [
+    'FilterLayout', 'InputLayout', 'TransC', 'TransO', 'N', 'C', 'H', 'W', 'K', 'Y', 'X',
+    'DilationH', 'DilationW', 'StrideH', 'StrideW', 'PaddingH', 'PaddingW', 'O'
+]

 # Regex pattern for lookup table entries: {"arch_op_dtype", {Class::params, Class::count}}, // optional comment
 LOOKUP_ENTRY_PATTERN = re.compile(r'\{("(gfx\w+)_(\w+)_(\w+)"),\s*(\{[^}]+\})\},(\s*//[^\n]*)?')
```
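For orientation, here is how that regex is expected to pick apart one of the lookup entries this PR adds to the .inc file (the entry text is copied from the diff above; the snippet itself is only an illustration, not code from the PR):

```python
import re

LOOKUP_ENTRY_PATTERN = re.compile(r'\{("(gfx\w+)_(\w+)_(\w+)"),\s*(\{[^}]+\})\},(\s*//[^\n]*)?')

entry = ('{"gfx908_gemmelementwisegemm_f16", '
         '{PopulateParamsGemmGemm::initParametersF16GemmGemmGfx908, '
         'PopulateParamsGemmGemm::nInitParametersF16GemmGemmGfx908}},')

m = LOOKUP_ENTRY_PATTERN.search(entry)
# Groups 2-4 are the arch, the op key, and the dtype; group 5 is the {params, count} pair.
print(m.group(2), m.group(3), m.group(4))  # gfx908 gemmelementwisegemm f16
```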
```diff
@@ -35,7 +40,9 @@

 def get_instruction_type(arch, dtype, op):
     """Determine instruction type based on architecture, data type, and operation."""
-    if op == "attention":
+    if op in ("attention", "gemm_gemm", "conv_gemm"):
+        if op == "gemm_gemm" and arch.startswith("gfx1") and dtype == "f32":
+            return "NonAccel"
         return "GemmGemm"
     if arch.startswith("gfx9"):
         return "XDL"
```

Comment on lines +44 to +45
Member: Navi4x/gfx12, doesn't it require NonAccel as well?
Contributor: f32 on Navi for "GemmGemm" instruction types is unsupported. Not sure if it's appropriate to return NonAccel here.
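A quick illustration of what the new branch returns (expected behavior read off the code above; these calls are not part of the PR and assume the function is in scope):

```python
print(get_instruction_type("gfx908", "f16", "gemm_gemm"))   # GemmGemm
print(get_instruction_type("gfx1100", "f32", "gemm_gemm"))  # NonAccel, the case the reviewers question
print(get_instruction_type("gfx1100", "f16", "conv_gemm"))  # GemmGemm
```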
```diff
@@ -55,9 +62,15 @@ def get_class_name(arch, dtype, op):
     return f"PopulateParams{instr}" if instr != "NonAccel" else "PopulateParams"


+def _op_cap_for_param_name(op):
+    """Format op for C++ param name: gemm_gemm -> GemmGemm, attention -> Attention."""
+    return "".join(part.capitalize() for part in op.split("_"))
+
+
 def get_param_names(arch, dtype, op):
     """Generate array and count variable names."""
-    base = f"initParameters{dtype.capitalize()}{op.capitalize()}{arch.capitalize()}"
+    op_cap = _op_cap_for_param_name(op)
+    base = f"initParameters{dtype.capitalize()}{op_cap}{arch.capitalize()}"
     return base, f"n{base[0].upper()}{base[1:]}"

```
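The helper exists because `op.capitalize()` would turn `gemm_gemm` into `Gemm_gemm`, which does not match the C++ array names declared earlier in this diff. A small sanity check (illustrative only, assuming the functions above are in scope):

```python
print(_op_cap_for_param_name("gemm_gemm"))           # GemmGemm
print(get_param_names("gfx908", "f16", "gemm_gemm"))
# ('initParametersF16GemmGemmGfx908', 'nInitParametersF16GemmGemmGfx908')
```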
```diff
@@ -69,6 +82,10 @@ def get_target_columns(op):
         return CONV_COLUMNS
     elif op == "attention":
         return ATTENTION_COLUMNS
+    elif op == "gemm_gemm":
+        return GEMM_GEMM_COLUMNS
+    elif op == "conv_gemm":
+        return CONV_GEMM_COLUMNS
     else:
         raise ValueError(f"Unknown operation: {op}")
```
```diff
@@ -302,9 +319,23 @@ def add_lookup_entry(content, insert_marker, entry):
     return content[:insert_pos] + f'{entry}\n\n' + content[insert_pos:]


+def get_lookup_key_op(op):
+    """Return the operation key used in the C++ lookup table (matches stringifyEnum(KernelType).lower())."""
+    # C++ KernelType enum: Attention, GemmElementwiseGemm, ConvElementwiseGemm -> lower()
+    key_map = {
+        "attention": "attention",
+        "gemm_gemm": "gemmelementwisegemm",
+        "conv_gemm": "convelementwisegemm"
+    }
+    return key_map.get(op, op)
+
+
 def get_lookup_endif(arch, op, dtype):
     """Get the appropriate lookup table #endif marker."""
-    if op == "attention":
+    # op may be script name (gemm_gemm) or C++ key form (gemmelementwisegemm) from .inc
+    gemm_gemm_ops = ("attention", "gemm_gemm", "conv_gemm", "gemmelementwisegemm",
+                     "convelementwisegemm")
+    if op in gemm_gemm_ops:
         return "#endif // GemmGemm_LOOKUP_TABLE_GEN"
     elif is_accel(arch, dtype, op):
         return "#endif // Accel_LOOKUP_TABLE_GEN"
```

Comment on lines +325 to +329
Contributor: What about conv and gemm?
Author: This PR adds support for GG (GEMM+GEMM) and CG (CONV+GEMM) only.
```diff
@@ -350,7 +381,7 @@ def update_inc_file(results, arch, op):

     # Add lookup entry
     endif_marker = get_lookup_endif(arch, op, dtype)
-    key = f"{arch}_{op}_{dtype}"
+    key = f"{arch}_{get_lookup_key_op(op)}_{dtype}"
     value = f"{{{class_name}::{param_name}, {class_name}::{count_name}}}"
     entry = f'{{"{key}", {value}}},'
     content = add_lookup_entry(content, endif_marker, entry)
```
```diff
@@ -449,7 +480,9 @@ def main(args=None):
                         nargs='*',
                         metavar='FILE',
                         help='.debug files produced by tuningRunner.py (reads TSV from stdin if none provided)')
-    parser.add_argument('--op', choices=['gemm', 'conv', 'attention'], help='Operation')
+    parser.add_argument('--op',
+                        choices=['gemm', 'conv', 'attention', 'gemm_gemm', 'conv_gemm'],
+                        help='Operation')
     parser.add_argument('--th',
                         type=float,
                         default=0.93,
```
```diff
@@ -183,12 +183,33 @@ class Options:
     timeout: Optional[int]


+def _is_navi_arch(arch: str) -> bool:
+    """Return True if arch is Navi (gfx11xx or gfx12xx)."""
+    return arch.startswith("gfx1")
+
+
+# Operations that have no f32 tuning support on Navi (gfx11xx/gfx12xx) - empty tuning range
+_F32_NAVI_UNSUPPORTED_OPS = frozenset(
+    {'GemmGemmConfiguration', 'ConvGemmConfiguration', 'AttentionConfiguration'})
+
+
+def _should_skip_f32_on_navi(arch: str, test_vector: str, conf_class: type) -> bool:
+    """Return True if this op is f32 on Navi and has no tuning support (empty range)."""
+    if conf_class.__name__ not in _F32_NAVI_UNSUPPORTED_OPS:
+        return False
+    if not _is_navi_arch(arch):
+        return False
+    # Match -t f32 in the test vector (e.g. "-t f32 -transA" or " -t f32 ")
+    return '-t f32' in test_vector
+
+
 @dataclass
 class TuningResult:
     """Result of tuning a single configuration."""
     test_vector: str
     success: bool
     timed_out: bool = False
+    skipped: bool = False
     gpu_id: int = -1
     duration_seconds: float = 0.0
     timestamp: Optional[str] = None
```

Comment on lines +202 to +203
Member: Doesn't this work for GemmGemm problem configs as well?
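As an illustration of the intended behavior (the configuration class and test-vector fragment below are stand-ins; only the helper itself comes from the PR):

```python
class GemmGemmConfiguration:  # stand-in for the real tuningRunner configuration class
    pass

vec = "-t f32 -transA false"  # illustrative fragment of a test vector
print(_should_skip_f32_on_navi("gfx1100", vec, GemmGemmConfiguration))  # True: f32 GEMM+GEMM on Navi
print(_should_skip_f32_on_navi("gfx908", vec, GemmGemmConfiguration))   # False: gfx9 is not Navi
```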
```diff
@@ -500,6 +521,12 @@ def set_succeeded(self, test_vector: str) -> None:
         self._state.remove(test_vector)
         self._save_locked()

+    def remove(self, test_vector: str) -> None:
+        """Remove test_vector from state (e.g. when skipping without marking failed)."""
+        with self._lock:
+            self._state.remove(test_vector)
+            self._save_locked()
+
     def finalize_interrupted(self) -> None:
         """Mark RUNNING configs as INTERRUPTED on clean shutdown."""
         with self._lock:
```
```diff
@@ -715,11 +742,14 @@ class ETATracker:
     success_times: List[float] = field(default_factory=list)
     ok_count: int = 0
     fail_count: int = 0
+    skip_count: int = 0
     _processed: int = field(default=0, init=False)

     def record(self, result: TuningResult) -> None:
         self._processed += 1
-        if result.success:
+        if result.skipped:
+            self.skip_count += 1
+        elif result.success:
             self.ok_count += 1
             self.success_times.append(result.duration_seconds)
         else:
```
```diff
@@ -760,7 +790,10 @@ def get_postfix_str(self) -> str:
         rate = self._format_rate(median)
         eta = self._format_eta(eta_seconds)

-        return f"ok={self.ok_count}, fail={self.fail_count}, rate={rate}, eta={eta}"
+        postfix = f"ok={self.ok_count}, fail={self.fail_count}"
+        if self.skip_count > 0:
+            postfix += f", skip={self.skip_count}"
+        return f"{postfix}, rate={rate}, eta={eta}"


 @dataclass
```
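The effect on the progress-bar postfix: the skip counter only appears when something was actually skipped, so existing output is unchanged. A minimal sketch of just that string logic (not code from the PR):

```python
def _postfix(ok: int, fail: int, skip: int) -> str:
    postfix = f"ok={ok}, fail={fail}"
    if skip > 0:
        postfix += f", skip={skip}"
    return postfix

print(_postfix(10, 1, 2))  # ok=10, fail=1, skip=2
print(_postfix(10, 1, 0))  # ok=10, fail=1
```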
```diff
@@ -1153,7 +1186,7 @@ def verify_perfconfig(perfconfig: str, config: PerfConfiguration, paths: Paths,
     ]
     rocprof_command = [perfRunner.ROCPROF] + perfRunner.get_metric_args_for_rocprof(
         options.arch) + [
-            '--kernel-trace', '--stats', '-f', 'csv', '-o',
+            '--kernel-trace', '--stats', '--output-format', 'csv', '-o',
             perfRunner.BENCHMARKING_RESULT_FILE_NAME, '--', paths.mlir_paths.cpu_runner_path
         ] + mlir_cpu_runner_args
```
```diff
@@ -1520,6 +1553,13 @@ def execute_tuning_task(test_vector: str) -> TuningResult:

     state_file.set_running(test_vector)

+    if _should_skip_f32_on_navi(ctx.options.chip, test_vector, ctx.conf_class):
+        state_file.remove(test_vector)
+        return TuningResult(test_vector=test_vector,
+                            success=False,
+                            skipped=True,
+                            gpu_id=gpu_id)
+
     timestamp = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
     start_time = time.time()
     compile_threads = ctx.get_compile_threads(gpu_id)
```

Contributor: It would be better to filter out these configs in the beginning and print something like "Skipping N unsupported configs". It would simplify the rest of the code.
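For reference, the reviewer's suggestion could look roughly like this up-front filter (a sketch only; the variable name `test_vectors` is assumed, and this code is not part of the PR):

```python
supported, skipped = [], []
for tv in test_vectors:
    if _should_skip_f32_on_navi(ctx.options.chip, tv, ctx.conf_class):
        skipped.append(tv)
    else:
        supported.append(tv)
if skipped:
    print(f"Skipping {len(skipped)} unsupported configs (f32 on Navi)")
test_vectors = supported
```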
```diff
@@ -1550,6 +1590,16 @@
             results_writer.write_result(result)
             if debug_writer:
                 debug_writer.write_result(result)
+        elif result.skipped:
+            skip_msg = (f"'{result.test_vector}' on GPU {result.gpu_id} "
+                        "(f32 on Navi has no tuning support for this op)")
+            if sys.stderr.isatty():
+                tqdm.write(
+                    f"{_LOG_COLORS[logging.WARNING]}SKIPPED{_COLOR_RESET}: {skip_msg}",
+                    file=sys.stderr,
+                )
+            else:
+                tqdm.write(f"SKIPPED: {skip_msg}", file=sys.stderr)
         else:
             has_errors = True
             logger.error(
```