@@ -252,7 +252,7 @@ def get_model_and_inputs_from_name(
252252def quantize (
253253 model : GraphModule ,
254254 model_name : str ,
255- compile_specs : EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec ,
255+ compile_specs : ArmCompileSpec ,
256256 example_inputs : Tuple [torch .Tensor ],
257257 evaluator_name : str | None ,
258258 evaluator_config : Dict [str , Any ] | None ,
@@ -425,48 +425,39 @@ def get_calibration_data(
425425 return example_inputs
426426
427427
428- def get_compile_spec (
429- target : str ,
430- intermediates : Optional [str ] = None ,
431- system_config : Optional [str ] = None ,
432- memory_mode : Optional [str ] = None ,
433- quantize : bool = False ,
434- config : Optional [str ] = None ,
435- debug_mode : Optional [str ] = None ,
436- direct_drive : bool = False ,
437- ) -> TosaCompileSpec | EthosUCompileSpec | VgfCompileSpec :
428+ def get_compile_spec (args ) -> ArmCompileSpec :
438429 compile_spec = None
439- if target .startswith ("TOSA" ):
440- tosa_spec = TosaSpecification .create_from_string (target )
430+ if args . target .startswith ("TOSA" ):
431+ tosa_spec = TosaSpecification .create_from_string (args . target )
441432 compile_spec = TosaCompileSpec (tosa_spec )
442- elif "ethos-u" in target :
433+ elif "ethos-u" in args . target :
443434 extra_flags = ["--verbose-operators" , "--verbose-cycle-estimate" ]
444- if debug_mode is not None :
435+ if args . enable_debug_mode is not None :
445436 extra_flags .append ("--enable-debug-db" )
446- if direct_drive :
437+ if args . direct_drive :
447438 extra_flags .append ("--separate-io-regions" )
448439 extra_flags .append ("--cop-format=COP2" )
449440 compile_spec = EthosUCompileSpec (
450- target ,
451- system_config = system_config ,
452- memory_mode = memory_mode ,
441+ args . target ,
442+ system_config = args . system_config ,
443+ memory_mode = args . memory_mode ,
453444 extra_flags = extra_flags ,
454- config_ini = config ,
445+ config_ini = args . config ,
455446 )
456- elif "vgf" in target :
457- if quantize :
447+ elif "vgf" in args . target :
448+ if args . quantize :
458449 tosa_spec = TosaSpecification .create_from_string ("TOSA-1.0+INT" )
459450 else :
460451 tosa_spec = TosaSpecification .create_from_string ("TOSA-1.0+FP" )
461452 compile_spec = VgfCompileSpec (tosa_spec )
462453 else :
463- raise RuntimeError (f"Unkown target { target } " )
454+ raise RuntimeError (f"Unkown target { args . target } " )
464455
465- if intermediates is not None :
466- compile_spec .dump_intermediate_artifacts_to (intermediates )
456+ if args . intermediates is not None :
457+ compile_spec .dump_intermediate_artifacts_to (args . intermediates )
467458
468- if debug_mode is not None :
469- mode = ArmCompileSpec .DebugMode [debug_mode .upper ()]
459+ if args . enable_debug_mode is not None :
460+ mode = ArmCompileSpec .DebugMode [args . enable_debug_mode .upper ()]
470461 compile_spec .dump_debug_info (mode )
471462
472463 return compile_spec
@@ -762,16 +753,7 @@ def to_edge_TOSA_delegate(
762753):
763754 # As we can target multiple output encodings, one must
764755 # be specified.
765- compile_spec = get_compile_spec (
766- args .target ,
767- args .intermediates ,
768- args .system_config ,
769- args .memory_mode ,
770- args .quantize ,
771- args .config ,
772- args .enable_debug_mode ,
773- args .direct_drive ,
774- )
756+ compile_spec = get_compile_spec (args )
775757
776758 model_quant = None
777759 if args .quantize :
@@ -876,16 +858,7 @@ def to_edge_no_delegate(
876858 if args .quantize :
877859 # As we can target multiple output encodings, one must
878860 # be specified.
879- compile_spec = get_compile_spec (
880- args .target ,
881- args .intermediates ,
882- args .system_config ,
883- args .memory_mode ,
884- args .quantize ,
885- args .config ,
886- args .enable_debug_mode ,
887- args .direct_drive ,
888- )
861+ compile_spec = get_compile_spec (args )
889862 model , exported_program = quantize_model (
890863 args , model , example_inputs , compile_spec
891864 )
0 commit comments