Skip to content

Commit fde943a

Browse files
authored
Arm backend: Simplify get_compile_spec signature (#18003)
Reduce the number of args to `get_compile_spec` from eight to one by forwarding program args object instead of eight of its members. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent b7da7d5 commit fde943a

1 file changed

Lines changed: 20 additions & 47 deletions

File tree

examples/arm/aot_arm_compiler.py

Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def get_model_and_inputs_from_name(
252252
def quantize(
253253
model: GraphModule,
254254
model_name: str,
255-
compile_specs: EthosUCompileSpec | VgfCompileSpec | TosaCompileSpec,
255+
compile_specs: ArmCompileSpec,
256256
example_inputs: Tuple[torch.Tensor],
257257
evaluator_name: str | None,
258258
evaluator_config: Dict[str, Any] | None,
@@ -425,48 +425,39 @@ def get_calibration_data(
425425
return example_inputs
426426

427427

428-
def get_compile_spec(
429-
target: str,
430-
intermediates: Optional[str] = None,
431-
system_config: Optional[str] = None,
432-
memory_mode: Optional[str] = None,
433-
quantize: bool = False,
434-
config: Optional[str] = None,
435-
debug_mode: Optional[str] = None,
436-
direct_drive: bool = False,
437-
) -> TosaCompileSpec | EthosUCompileSpec | VgfCompileSpec:
428+
def get_compile_spec(args) -> ArmCompileSpec:
438429
compile_spec = None
439-
if target.startswith("TOSA"):
440-
tosa_spec = TosaSpecification.create_from_string(target)
430+
if args.target.startswith("TOSA"):
431+
tosa_spec = TosaSpecification.create_from_string(args.target)
441432
compile_spec = TosaCompileSpec(tosa_spec)
442-
elif "ethos-u" in target:
433+
elif "ethos-u" in args.target:
443434
extra_flags = ["--verbose-operators", "--verbose-cycle-estimate"]
444-
if debug_mode is not None:
435+
if args.enable_debug_mode is not None:
445436
extra_flags.append("--enable-debug-db")
446-
if direct_drive:
437+
if args.direct_drive:
447438
extra_flags.append("--separate-io-regions")
448439
extra_flags.append("--cop-format=COP2")
449440
compile_spec = EthosUCompileSpec(
450-
target,
451-
system_config=system_config,
452-
memory_mode=memory_mode,
441+
args.target,
442+
system_config=args.system_config,
443+
memory_mode=args.memory_mode,
453444
extra_flags=extra_flags,
454-
config_ini=config,
445+
config_ini=args.config,
455446
)
456-
elif "vgf" in target:
457-
if quantize:
447+
elif "vgf" in args.target:
448+
if args.quantize:
458449
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
459450
else:
460451
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP")
461452
compile_spec = VgfCompileSpec(tosa_spec)
462453
else:
463-
raise RuntimeError(f"Unkown target {target}")
454+
raise RuntimeError(f"Unkown target {args.target}")
464455

465-
if intermediates is not None:
466-
compile_spec.dump_intermediate_artifacts_to(intermediates)
456+
if args.intermediates is not None:
457+
compile_spec.dump_intermediate_artifacts_to(args.intermediates)
467458

468-
if debug_mode is not None:
469-
mode = ArmCompileSpec.DebugMode[debug_mode.upper()]
459+
if args.enable_debug_mode is not None:
460+
mode = ArmCompileSpec.DebugMode[args.enable_debug_mode.upper()]
470461
compile_spec.dump_debug_info(mode)
471462

472463
return compile_spec
@@ -762,16 +753,7 @@ def to_edge_TOSA_delegate(
762753
):
763754
# As we can target multiple output encodings, one must
764755
# be specified.
765-
compile_spec = get_compile_spec(
766-
args.target,
767-
args.intermediates,
768-
args.system_config,
769-
args.memory_mode,
770-
args.quantize,
771-
args.config,
772-
args.enable_debug_mode,
773-
args.direct_drive,
774-
)
756+
compile_spec = get_compile_spec(args)
775757

776758
model_quant = None
777759
if args.quantize:
@@ -876,16 +858,7 @@ def to_edge_no_delegate(
876858
if args.quantize:
877859
# As we can target multiple output encodings, one must
878860
# be specified.
879-
compile_spec = get_compile_spec(
880-
args.target,
881-
args.intermediates,
882-
args.system_config,
883-
args.memory_mode,
884-
args.quantize,
885-
args.config,
886-
args.enable_debug_mode,
887-
args.direct_drive,
888-
)
861+
compile_spec = get_compile_spec(args)
889862
model, exported_program = quantize_model(
890863
args, model, example_inputs, compile_spec
891864
)

0 commit comments

Comments
 (0)