Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions slothy/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,23 @@ def with_llvm_mca_after(self):
"""
return self._with_llvm_mca_after

@property
def emit_clobbered_callee_saves_comment(self):
"""If True, prepend a comment to the output listing callee-saved registers
(per the ISA calling convention) that are clobbered by the optimized code.

This is useful for understanding the ABI cost of a snippet and for
constructing correct save/restore sequences around it.

Requires the ISA model to define callee_saved_registers(). When the ISA
model does not provide this method the option has no effect.
"""
return self._emit_clobbered_callee_saves_comment

@emit_clobbered_callee_saves_comment.setter
def emit_clobbered_callee_saves_comment(self, val):
self._emit_clobbered_callee_saves_comment = val

@property
def compiler_binary(self):
"""The compiler binary to be used.
Expand Down Expand Up @@ -619,6 +636,7 @@ def has_objective(self):
self.constraints.move_stalls_to_bottom is True,
self.constraints.minimize_register_usage is not None,
self.constraints.minimize_use_of_extra_registers is not None,
self.constraints.prefer_caller_save_registers is True,
self.target.has_min_max_objective(self),
]
)
Expand Down Expand Up @@ -1197,6 +1215,35 @@ def minimize_spills(self):
"""
return self._minimize_spills

@property
def prefer_caller_save_registers(self):
"""Bias register allocation toward caller-save registers.

Adds solver hints preferring registers that already appear in the
input code (regardless of whether they are caller- or callee-saved),
and discouraging callee-saved registers when new registers must be
introduced. Additionally, a secondary solver objective minimises the
number of distinct callee-saved registers used, applied after the
primary stall-minimisation objective.

Because the preference is expressed through hints and a secondary
objective, optimal scheduling is never sacrificed: if callee-saved
registers are the only way to achieve minimum stalls, they will still
be used.

This option subsumes hints.rename_hint_orig_rename: when both are
enabled the orig-rename hint is already covered, so enabling
rename_hint_orig_rename on top is harmless but redundant.

Requires the ISA model to define callee_saved_registers(). When the
ISA model does not provide this method the option has no effect.
"""
return self._prefer_caller_save_registers

@prefer_caller_save_registers.setter
def prefer_caller_save_registers(self, val):
self._prefer_caller_save_registers = val

@property
def max_displacement(self):
"""The maximum relative displacement of an instruction.
Expand Down Expand Up @@ -1236,6 +1283,7 @@ def __init__(self):
self.minimize_register_usage = None
self.minimize_use_of_extra_registers = None
self.allow_extra_registers = {}
self._prefer_caller_save_registers = False

self._stalls_allowed = 0
self._stalls_maximum_attempt = 512
Expand Down Expand Up @@ -1336,8 +1384,16 @@ def order_hint_orig_order(self):

@property
def rename_hint_orig_rename(self):
"""Hint at using the initial program order for the
program order variables."""
"""Hint the solver to keep each instruction's original output register.

For each instruction, adds a solver hint favouring the register that
was used in the input code.

Note: when constraints.prefer_caller_save_registers is also enabled,
that option already covers this hint (it biases all original registers,
not just each instruction's own output). Enabling both is harmless
but redundant for the hint on the exact original register.
"""
return self._rename_hint_orig_rename

@property
Expand Down Expand Up @@ -1439,6 +1495,7 @@ def __init__(self, Arch, Target, logger):
self._llvm_mca_issue_width_overwrite = False
self._with_llvm_mca_before = False
self._with_llvm_mca_after = False
self._emit_clobbered_callee_saves_comment = True
self._max_solutions = 64
self._timeout = None
self._retry_timeout = None
Expand Down
72 changes: 72 additions & 0 deletions slothy/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,17 @@ def output_renamings(self, v):
assert self._output_renamings is None
self._output_renamings = v

@property
def clobbered_callee_saved(self):
"""Set of callee-saved registers (per the ISA calling convention) that
are clobbered by the optimized code. Empty when the ISA model does not
define callee_saved_registers()."""
return self._clobbered_callee_saved

@clobbered_callee_saved.setter
def clobbered_callee_saved(self, val):
self._clobbered_callee_saved = val

@property
def stalls(self):
"""The number of stalls in the optimization result.
Expand Down Expand Up @@ -1561,6 +1572,7 @@ def __init__(self, config):
self._optimization_user_time = None
self._spills = {}
self._restores = {}
self._clobbered_callee_saved = set()

self.lock()

Expand Down Expand Up @@ -2204,6 +2216,7 @@ def _extract_result(self):

self._extract_spills()
self._extract_code()
self._result.clobbered_callee_saved = self._compute_clobbered_callee_saved()
self._result.selfcheck_with_fixup(self.logger.getChild("selfcheck"))
self._result.offset_fixup(self.logger.getChild("fixup"))

Expand Down Expand Up @@ -2474,6 +2487,17 @@ def get_code_line(line_no):
for s in self._result.code:
self.logger.result.debug("> " + s.to_string())

def _compute_clobbered_callee_saved(self):
if not hasattr(self.arch.RegisterType, "callee_saved_registers"):
return set()
callee_saved = set(self.arch.RegisterType.callee_saved_registers())
used = set()
for t in self._get_nodes():
used.update(t.inst.args_out)
used.update(t.inst.args_in)
used.update(t.inst.args_in_out)
return callee_saved & used

def _add_path_constraint(self, consumer, producer, cb):
"""Add model constraint cb() relating to the pair of producer-consumer
instructions
Expand Down Expand Up @@ -2735,13 +2759,38 @@ def _allow_renaming(_):

self.logger.debug("Adding variables for register allocation...")

# One Boolean per callee-saved register: True iff that register is used
# anywhere in the optimized output. Linked to register_usage_vars via
# MaxEquality in _add_constraints_register_renaming and minimized as a
# secondary objective in _add_objective, so the solver prefers
# caller-save registers without ever sacrificing stall minimization.
callee_saved = set()
self._model._callee_saved_used = {}
if self.config.constraints.prefer_caller_save_registers and hasattr(
self.arch.RegisterType, "callee_saved_registers"
):
callee_saved = set(self.arch.RegisterType.callee_saved_registers())
self._model._callee_saved_used = {
reg: self._NewBoolVar(f"callee_saved_used[{reg}]")
for reg in self.arch.RegisterType.callee_saved_registers()
}

if self.config.constraints.minimize_register_usage is not None:
ty = self.config.constraints.minimize_register_usage
regs = self.arch.RegisterType.list_registers(ty)
self._model._register_used = {
reg: self._NewBoolVar(f"reg_used[{reg}]") for reg in regs
}

# Collect registers appearing in the original (pre-renaming) code so
# that hints can prefer them. Input/output pseudo-nodes are excluded
# because their registers are fixed by the interface, not by the code.
orig_regs = set()
for t in self._get_nodes():
orig_regs.update(t.inst.args_out)
orig_regs.update(t.inst.args_in)
orig_regs.update(t.inst.args_in_out)

# Create variables for register renaming

for t in self._get_nodes(allnodes=True):
Expand Down Expand Up @@ -2822,6 +2871,13 @@ def _allow_renaming(_):
if arg_out in candidates_restricted:
self._AddHint(var_dict[arg_out], True)

if self.config.constraints.prefer_caller_save_registers:
for out_reg, var in var_dict.items():
if out_reg in orig_regs:
self._AddHint(var, True)
elif out_reg in callee_saved:
self._AddHint(var, False)

# For convenience, also add references to the variables governing the
# register renaming for input and input/output arguments.
for t in self._get_nodes(allnodes=True):
Expand Down Expand Up @@ -3140,6 +3196,13 @@ def _add_constraints_register_renaming(self):
else:
self._Add(self._model._register_used[reg] == False) # noqa: E712

for reg, used_var in self._model._callee_saved_used.items():
arr = self._model.register_usage_vars.get(reg, [])
if len(arr) > 0:
self._AddMaxEquality(used_var, arr)
else:
self._Add(used_var == False) # noqa: E712

for t in self._get_nodes(allnodes=True):
can_spill = True
if t.is_virtual is True:
Expand Down Expand Up @@ -3975,6 +4038,15 @@ def _add_objective(self, force_objective=False):
else:
maxlist = lst

# Secondary objective: prefer caller-save registers.
# This runs only in the force_objective (step-2 re-optimization) path,
# where stalls are already fixed via a hard constraint.
if force_objective and self._model._callee_saved_used and len(maxlist) == 0:
callee_vars = list(self._model._callee_saved_used.values())
if len(minlist) == 0:
minlist = callee_vars
name = "minimize callee-saved register usage"

self._model.objective_printer = printer
self._model.objective_vars = objective_vars

Expand Down
21 changes: 16 additions & 5 deletions slothy/core/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def periodic(body: list, logger: any, conf: any) -> any:
# the heuristics for linear optimization.
if not conf.sw_pipelining.enabled:
res = Heuristics.linear(body, logger=logger, conf=conf)
return [], res.code, [], 0
return [], res.code, [], 0, res.clobbered_callee_saved

if conf.sw_pipelining.halving_heuristic:
return Heuristics._periodic_halving(body, logger, conf)
Expand All @@ -413,6 +413,8 @@ def periodic(body: list, logger: any, conf: any) -> any:

# Second step: Separately optimize preamble and postamble

clobbered = result.clobbered_callee_saved

preamble = result.preamble
if conf.sw_pipelining.optimize_preamble:
logger.debug("Optimize preamble...")
Expand All @@ -425,6 +427,7 @@ def periodic(body: list, logger: any, conf: any) -> any:
preamble, conf=c, logger=logger.getChild("preamble")
)
preamble = res_preamble.code
clobbered |= res_preamble.clobbered_callee_saved

postamble = result.postamble
if conf.sw_pipelining.optimize_postamble:
Expand All @@ -436,8 +439,9 @@ def periodic(body: list, logger: any, conf: any) -> any:
postamble, conf=c, logger=logger.getChild("postamble")
)
postamble = res_postamble.code
clobbered |= res_postamble.clobbered_callee_saved

return preamble, kernel, postamble, num_exceptional_iterations
return preamble, kernel, postamble, num_exceptional_iterations, clobbered

@staticmethod
def linear(body: list, logger: any, conf: any) -> any:
Expand Down Expand Up @@ -1184,6 +1188,10 @@ def is_pre(i):
# consider [B;A] as a non-periodic snippet, which may still lead to stalls at the
# loop boundary.

clobbered = set()
if not conf.sw_pipelining.halving_heuristic_split_only:
clobbered = res_halving_0.clobbered_callee_saved

if conf.sw_pipelining.halving_heuristic_periodic:
c = conf.copy()
c.inputs_are_outputs = True
Expand All @@ -1192,9 +1200,11 @@ def is_pre(i):
c.sw_pipelining.allow_pre = False # - no early instructions
c.sw_pipelining.allow_post = False # - no late instructions
# Just make sure to consider loop boundary
kernel = Heuristics.optimize_binsearch(
res_periodic = Heuristics.optimize_binsearch(
kernel, logger.getChild("periodic heuristic"), conf=c
).code
)
kernel = res_periodic.code
clobbered |= res_periodic.clobbered_callee_saved
elif not conf.sw_pipelining.halving_heuristic_split_only:
c = conf.copy()
c.outputs = new_kernel_deps
Expand All @@ -1205,6 +1215,7 @@ def is_pre(i):
kernel, logger.getChild("heuristic"), conf=c
)
final_kernel = res_halving_1.code
clobbered |= res_halving_1.clobbered_callee_saved

reordering2 = res_halving_1.reordering_with_bubbles

Expand Down Expand Up @@ -1252,4 +1263,4 @@ def get_reordering2(i):
kernel = final_kernel

num_exceptional_iterations = 1
return preamble, kernel, postamble, num_exceptional_iterations
return preamble, kernel, postamble, num_exceptional_iterations, clobbered
21 changes: 19 additions & 2 deletions slothy/core/slothy.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ def __init__(self, arch, target, logger=None):
self.last_result = None
self.success = None

def _prepend_clobber_comment(self, code, clobbered, indentation):
"""Prepend a callee-saved clobber comment to code if the feature is on."""
if not (self.config.emit_clobbered_callee_saves_comment and clobbered):
return code
comment_text = "Clobbered callee-saved registers: " + ", ".join(
sorted(clobbered)
)
clobber_line = SourceLine("").add_comment(comment_text)
clobber_line = SourceLine.apply_indentation([clobber_line], indentation)[0]
return [clobber_line] + code

def _get_version(self):
try:
from importlib.metadata import version
Expand Down Expand Up @@ -394,7 +405,11 @@ def optimize(
pre, body, post, "ORIGINAL", indentation
)

early, core, late, num_exceptional = Heuristics.periodic(body, logger, c)
early, core, late, num_exceptional, clobbered = Heuristics.periodic(
body, logger, c
)

core = self._prepend_clobber_comment(core, clobbered, indentation)

if self.config.with_llvm_mca_before is True:
core = core + orig_stats
Expand Down Expand Up @@ -632,10 +647,12 @@ def optimize_loop(
early, body, late, "ORIGINAL", indentation
)

preamble_code, kernel_code, postamble_code, num_exceptional = (
preamble_code, kernel_code, postamble_code, num_exceptional, clobbered = (
Heuristics.periodic(body, logger, c)
)

kernel_code = self._prepend_clobber_comment(kernel_code, clobbered, indentation)

# Remove branch instructions from preamble and postamble
postamble_code = [
line for line in postamble_code if not line.tags.get("branch")
Expand Down
5 changes: 5 additions & 0 deletions slothy/targets/arm_v81m/arch_v81m.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def __str__(self):
def __repr__(self):
return self.name

@staticmethod
def callee_saved_registers():
# AAPCS32 / Helium calling convention: r4-r11 GPR, q4-q7 MVE
return [f"r{i}" for i in range(4, 12)] + [f"q{i}" for i in range(4, 8)]

@staticmethod
def is_renamed(ty):
"""Indicate if register type should be subject to renaming"""
Expand Down
5 changes: 5 additions & 0 deletions slothy/targets/riscv/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def _list_registers(

list_registers = staticmethod(_list_registers)

@staticmethod
def callee_saved_registers():
# RISC-V calling convention: s0-s11 = x8, x9, x18-x27
return ["x8", "x9"] + [f"x{i}" for i in range(18, 28)]

@staticmethod
def find_type(r):
"""Find type of architectural register"""
Expand Down
Loading
Loading