@@ -281,59 +281,36 @@ struct GemmQuantTypeConfig
281281 using CDataType = CDataType_;
282282};
283283
284- template <typename T>
285- struct DataTypeTraits ;
286-
287- template <>
288- struct DataTypeTraits <float >
289- {
290- static constexpr const char * name = " fp32" ;
291- };
292-
293- template <>
294- struct DataTypeTraits <double >
295- {
296- static constexpr const char * name = " fp64" ;
297- };
298-
299- template <>
300- struct DataTypeTraits <int32_t >
284+ auto create_args (int argc, char * argv[])
301285{
302- static constexpr const char * name = " int32" ;
303- };
304-
305- template <>
306- struct DataTypeTraits <ck_tile::half_t >
307- {
308- static constexpr const char * name = " fp16" ;
309- };
310-
311- template <>
312- struct DataTypeTraits <ck_tile::bf16_t >
313- {
314- static constexpr const char * name = " bf16" ;
315- };
316-
317- template <>
318- struct DataTypeTraits <ck_tile::fp8_t >
319- {
320- static constexpr const char * name = " fp8" ;
321- };
322-
323- template <>
324- struct DataTypeTraits <ck_tile::bf8_t >
325- {
326- static constexpr const char * name = " bf8" ;
327- };
328-
329- template <>
330- struct DataTypeTraits <ck_tile::pk_int4_t >
331- {
332- static constexpr const char * name = " pk_int4_t" ;
333- };
334-
335- template <>
336- struct DataTypeTraits <ck_tile::int8_t >
337- {
338- static constexpr const char * name = " int8" ;
339- };
286+ ck_tile::ArgParser arg_parser;
287+ arg_parser.insert (" m" , " 3840" , " m dimension" )
288+ .insert (" n" , " 4096" , " n dimension" )
289+ .insert (" k" , " 2048" , " k dimension" )
290+ .insert (" a_layout" , " R" , " A tensor data layout - Row by default" )
291+ .insert (" b_layout" , " C" , " B tensor data layout - Column by default" )
292+ .insert (" bq_layout" , " C" , " Bq tensor data layout - Column by default" )
293+ .insert (" c_layout" , " R" , " C tensor data layout - Row by default" )
294+ .insert (" stride_a" , " 0" , " Tensor A stride" )
295+ .insert (" stride_q" , " 0" , " Tensor AQ stride" )
296+ .insert (" stride_b" , " 0" , " Tensor B stride" )
297+ .insert (" stride_c" , " 0" , " Tensor C stride" )
298+ .insert (" v" , " 1" , " 0. No validation, 1. Validation on CPU, 2. Validation on GPU" )
299+ .insert (" prec" ,
300+ " fp8" ,
301+ " data type. For AQuant: fp8/bf8/i4fp8/i4bf8, For Bquant: fp8/bf8/fp8i4/bf8i4" )
302+ .insert (" warmup" , " 50" , " number of iterations before benchmark the kernel" )
303+ .insert (" repeat" , " 1000" , " number of iterations to benchmark the kernel" )
304+ .insert (" timer" , " gpu" , " gpu:gpu timer, cpu:cpu timer" )
305+ .insert (" split_k" , " 1" , " splitK value" )
306+ .insert (" init" , " 0" , " 0:random, 1:linear, 2:constant(1)" )
307+ .insert (" flush_cache" , " true" , " flush cache before running the kernel, defaults to true" )
308+ .insert (" rotating_count" , " 1000" , " rotating count, defaults to 1" )
309+ .insert (" quant_mode" , " bquant" , " Choose aquant (default), bquant, tensor or rowcol" )
310+ .insert (" group_size" ,
311+ " 1x1x128" ,
312+ " Quantization group size as MxNxK, e.g., 1x1x128, 1x32x128, 1x64x128" );
313+
314+ bool result = arg_parser.parse (argc, argv);
315+ return std::make_tuple (result, arg_parser);
316+ }
0 commit comments