Skip to content

Commit 722ab96

Browse files
committed
JIT: Implement hot-cold code splitting for improved icache locality
Split JIT stencil code into hot and cold sections to improve instruction cache performance. Hot code (common execution paths) is laid out contiguously across all uops in a trace, while cold code (error handling, rare branches) is placed after all hot code.

Changes:
- _optimizers.py: Physically reorder basic blocks so hot blocks come first, cold blocks last. Insert _JIT_COLD_START symbol at boundary. Add _make_jump() to each arch-specific optimizer for inserting explicit jumps when hot->cold fallthrough edges are broken.
- _stencils.py: Track cold_offset in StencilGroup. Add hot_code_size() and cold_code_size() methods. Extract cold offset from _JIT_COLD_START symbol in object files.
- _writer.py: Split emit functions to handle hot and cold code bodies separately, with adjusted hole patching for cold code offsets. Update StencilGroup C struct to include hot_code_size and cold_code_size.
- _targets.py: Call extract_cold_offset() during stencil build.
- jit.c: New memory layout [hot code|cold code|trampolines|data]. Update _PyJIT_Compile and compile_shim to use split layout.

https://claude.ai/code/session_01QteUqrt8X8Ssxhei1kX2TH
1 parent 300de1e commit 722ab96

File tree

5 files changed

+213
-43
lines changed

5 files changed

+213
-43
lines changed

Python/jit.c

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -642,23 +642,27 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
642642
{
643643
const StencilGroup *group;
644644
// Loop once to find the total compiled size:
645-
size_t code_size = 0;
645+
size_t hot_code_size = 0;
646+
size_t cold_code_size = 0;
646647
size_t data_size = 0;
647648
jit_state state = {0};
648649
for (size_t i = 0; i < length; i++) {
649650
const _PyUOpInstruction *instruction = &trace[i];
650651
group = &stencil_groups[instruction->opcode];
651-
state.instruction_starts[i] = code_size;
652-
code_size += group->code_size;
652+
state.instruction_starts[i] = hot_code_size;
653+
hot_code_size += group->hot_code_size;
654+
cold_code_size += group->cold_code_size;
653655
data_size += group->data_size;
654656
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
655657
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
656658
}
657659
group = &stencil_groups[_FATAL_ERROR_r00];
658-
code_size += group->code_size;
660+
hot_code_size += group->hot_code_size;
661+
cold_code_size += group->cold_code_size;
659662
data_size += group->data_size;
660663
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
661664
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
665+
size_t code_size = hot_code_size + cold_code_size;
662666
// Calculate the size of the trampolines required by the whole trace
663667
for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) {
664668
state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE;
@@ -684,29 +688,36 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
684688
OPT_STAT_ADD(jit_got_size, state.got_symbols.size);
685689
OPT_STAT_ADD(jit_padding_size, padding);
686690
OPT_HIST(total_size, trace_total_memory_hist);
687-
// Update the offsets of each instruction:
691+
// Update the offsets of each instruction (hot code comes first):
688692
for (size_t i = 0; i < length; i++) {
689693
state.instruction_starts[i] += (uintptr_t)memory;
690694
}
691-
// Loop again to emit the code:
695+
// Memory layout: [hot code][cold code][trampolines][padding][data][GOT][padding]
696+
// Hot code for all uops is laid out contiguously first, improving
697+
// instruction cache locality. Cold code (error handling, rare paths)
698+
// is placed after all hot code.
692699
unsigned char *code = memory;
700+
unsigned char *cold_code = memory + hot_code_size;
693701
state.trampolines.mem = memory + code_size;
694702
unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
695703
assert(trace[0].opcode == _START_EXECUTOR_r00 || trace[0].opcode == _COLD_EXIT_r00 || trace[0].opcode == _COLD_DYNAMIC_EXIT_r00);
696704
state.got_symbols.mem = data + data_size;
697705
for (size_t i = 0; i < length; i++) {
698706
const _PyUOpInstruction *instruction = &trace[i];
699707
group = &stencil_groups[instruction->opcode];
700-
group->emit(code, data, executor, instruction, &state);
701-
code += group->code_size;
708+
group->emit(code, cold_code, data, executor, instruction, &state);
709+
code += group->hot_code_size;
710+
cold_code += group->cold_code_size;
702711
data += group->data_size;
703712
}
704713
// Protect against accidental buffer overrun into data:
705714
group = &stencil_groups[_FATAL_ERROR_r00];
706-
group->emit(code, data, executor, NULL, &state);
707-
code += group->code_size;
715+
group->emit(code, cold_code, data, executor, NULL, &state);
716+
code += group->hot_code_size;
717+
cold_code += group->cold_code_size;
708718
data += group->data_size;
709-
assert(code == memory + code_size);
719+
assert(code == memory + hot_code_size);
720+
assert(cold_code == memory + code_size);
710721
assert(data == memory + code_size + state.trampolines.size + code_padding + data_size);
711722
if (mark_executable(memory, total_size)) {
712723
jit_free(memory, total_size);
@@ -728,14 +739,17 @@ compile_shim(void)
728739
{
729740
_PyExecutorObject dummy;
730741
const StencilGroup *group;
731-
size_t code_size = 0;
742+
size_t hot_code_size = 0;
743+
size_t cold_code_size = 0;
732744
size_t data_size = 0;
733745
jit_state state = {0};
734746
group = &shim;
735-
code_size += group->code_size;
747+
hot_code_size += group->hot_code_size;
748+
cold_code_size += group->cold_code_size;
736749
data_size += group->data_size;
737750
combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
738751
combine_symbol_mask(group->got_mask, state.got_symbols.mask);
752+
size_t code_size = hot_code_size + cold_code_size;
739753
// Round up to the nearest page:
740754
size_t page_size = get_page_size();
741755
assert((page_size & (page_size - 1)) == 0);
@@ -747,17 +761,20 @@ compile_shim(void)
747761
return NULL;
748762
}
749763
unsigned char *code = memory;
764+
unsigned char *cold_code = memory + hot_code_size;
750765
state.trampolines.mem = memory + code_size;
751766
unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
752767
state.got_symbols.mem = data + data_size;
753768
// Compile the shim, which handles converting between the native
754769
// calling convention and the calling convention used by jitted code
755770
// (which may be different for efficiency reasons).
756771
group = &shim;
757-
group->emit(code, data, &dummy, NULL, &state);
758-
code += group->code_size;
772+
group->emit(code, cold_code, data, &dummy, NULL, &state);
773+
code += group->hot_code_size;
774+
cold_code += group->cold_code_size;
759775
data += group->data_size;
760-
assert(code == memory + code_size);
776+
assert(code == memory + hot_code_size);
777+
assert(cold_code == memory + code_size);
761778
assert(data == memory + code_size + state.trampolines.size + code_padding + data_size);
762779
if (mark_executable(memory, total_size)) {
763780
jit_free(memory, total_size);

Tools/jit/_optimizers.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,72 @@ def _blocks(self) -> typing.Generator[_Block, None, None]:
310310
yield block
311311
block = block.link
312312

313+
def _make_jump(self, target: str) -> Instruction:
314+
"""Create an unconditional jump instruction to the given label."""
315+
raise NotImplementedError
316+
317+
def _reorder_hot_cold(self) -> None:
318+
"""Reorder blocks so all hot blocks come first, then all cold blocks.
319+
320+
This improves instruction cache locality by keeping the common execution
321+
paths contiguous in memory. A _JIT_COLD_START label is inserted at the
322+
boundary between hot and cold code.
323+
324+
For any hot block that previously fell through to a cold block, an
325+
explicit jump is inserted to maintain correctness.
326+
"""
327+
continuation = self._lookup_label(f"{self.label_prefix}_JIT_CONTINUE")
328+
hot_blocks: list[_Block] = []
329+
cold_blocks: list[_Block] = []
330+
# Blocks at and after _JIT_CONTINUE contain metadata; keep them at end
331+
post_continuation: list[_Block] = []
332+
past_continuation = False
333+
for block in self._blocks():
334+
if block is continuation:
335+
past_continuation = True
336+
if past_continuation:
337+
post_continuation.append(block)
338+
elif block.hot:
339+
hot_blocks.append(block)
340+
else:
341+
cold_blocks.append(block)
342+
343+
if not cold_blocks:
344+
# Nothing to reorder — all code is hot
345+
return
346+
347+
# For any hot block that falls through to a cold block, insert an
348+
# explicit jump to maintain correctness after reordering:
349+
for block in hot_blocks:
350+
if (
351+
block.fallthrough
352+
and block.link is not None
353+
and not block.link.hot
354+
and block.link is not continuation
355+
):
356+
# The fallthrough target is cold and will be moved away.
357+
# Insert a jump to the cold block's label:
358+
target_block = block.link.resolve()
359+
if target_block.label:
360+
jump = self._make_jump(target_block.label)
361+
block.instructions.append(jump)
362+
block.fallthrough = False
363+
364+
# Rebuild the linked list: hot -> cold -> post_continuation
365+
all_blocks = hot_blocks + cold_blocks + post_continuation
366+
for i, block in enumerate(all_blocks):
367+
block.link = all_blocks[i + 1] if i + 1 < len(all_blocks) else None
368+
369+
# Insert the cold start label between hot and cold sections.
370+
# Use the symbol_prefix so it appears in the object file's symbol table:
371+
if hot_blocks and cold_blocks:
372+
cold_label = f"{self.symbol_prefix}_JIT_COLD_START"
373+
cold_blocks[0].noninstructions.insert(0, f"{cold_label}:")
374+
375+
# Update the root to the first block
376+
if all_blocks:
377+
self._root = all_blocks[0]
378+
313379
def _body(self) -> str:
314380
lines = ["#" + line for line in self.text.splitlines()]
315381
hot = True
@@ -563,6 +629,7 @@ def run(self) -> None:
563629
self._invert_hot_branches()
564630
self._remove_redundant_jumps()
565631
self._remove_unreachable()
632+
self._reorder_hot_cold()
566633
self._fixup_external_labels()
567634
self._fixup_constants()
568635
self.path.write_text(self._body())
@@ -580,6 +647,9 @@ class OptimizerAArch64(Optimizer): # pylint: disable = too-few-public-methods
580647
rf"\s*(?P<instruction>{'|'.join(_branch_patterns)})\s+(.+,\s+)*(?P<target>[\w.]+)"
581648
)
582649

650+
def _make_jump(self, target: str) -> Instruction:
    """Return an unconditional AArch64 ``b`` branch to the label *target*."""
    mnemonic = "b"
    text = f"\tb {target}"
    return Instruction(InstructionKind.JUMP, mnemonic, text, target)
652+
583653
# https://developer.arm.com/documentation/ddi0406/b/Application-Level-Architecture/Instruction-Details/Alphabetical-list-of-instructions/BL--BLX--immediate-
584654
_re_call = re.compile(r"\s*blx?\s+(?P<target>[\w.]+)")
585655
# https://developer.arm.com/documentation/ddi0602/2025-03/Base-Instructions/B--Branch-
@@ -644,6 +714,9 @@ class OptimizerX86(Optimizer): # pylint: disable = too-few-public-methods
644714
_re_branch = re.compile(
645715
rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w.]+)"
646716
)
717+
718+
def _make_jump(self, target: str) -> Instruction:
    """Return an unconditional x86 ``jmp`` to the label *target*."""
    mnemonic = "jmp"
    text = f"\tjmp {target}"
    return Instruction(InstructionKind.JUMP, mnemonic, text, target)
647720
# https://www.felixcloutier.com/x86/call
648721
_re_call = re.compile(r"\s*callq?\s+(?P<target>[\w.]+)")
649722
# https://www.felixcloutier.com/x86/jmp

Tools/jit/_stencils.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ class StencilGroup:
256256

257257
code: Stencil = dataclasses.field(default_factory=Stencil, init=False)
258258
data: Stencil = dataclasses.field(default_factory=Stencil, init=False)
259+
# Byte offset within code.body where cold code begins.
260+
# If 0, there is no hot/cold split (all code is hot):
261+
cold_offset: int = dataclasses.field(default=0, init=False)
259262
symbols: dict[int | str, tuple[HoleValue, int]] = dataclasses.field(
260263
default_factory=dict, init=False
261264
)
@@ -265,6 +268,14 @@ class StencilGroup:
265268
_trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
266269
_got_entries: set[int] = dataclasses.field(default_factory=set, init=False)
267270

271+
def extract_cold_offset(self) -> None:
    """Record where cold code starts, if a _JIT_COLD_START symbol exists.

    When the symbol is absent, cold_offset is left untouched. When it is
    present, it must refer to the code section; its offset is then stored
    in cold_offset.
    """
    entry = self.symbols.get("_JIT_COLD_START")
    if entry is None:
        return
    section, position = entry
    assert section is HoleValue.CODE
    self.cold_offset = position
278+
268279
def convert_labels_to_relocations(self) -> None:
269280
for name, hole_plus in self.symbols.items():
270281
if isinstance(name, str) and "_JIT_RELOCATION_" in name:
@@ -420,9 +431,26 @@ def _get_trampoline_mask(self) -> str:
420431
def _get_got_mask(self) -> str:
421432
return self._get_symbol_mask(self._got_entries)
422433

434+
def hot_code_size(self) -> int:
    """Return the byte size of the hot part of the code body.

    This is cold_offset when that is nonzero; otherwise the whole of
    code.body counts as hot.
    """
    return self.cold_offset or len(self.code.body)
439+
440+
def cold_code_size(self) -> int:
    """Return the byte size of the cold part of the code body.

    This is whatever follows cold_offset in code.body, or 0 when
    cold_offset is zero (no hot/cold split).
    """
    boundary = self.cold_offset
    return len(self.code.body) - boundary if boundary else 0
445+
423446
def as_c(self, opname: str) -> str:
    """Render this stencil group as a C StencilGroup initializer.

    The fields are, in order: the emit function for *opname*, the hot code
    size, the cold code size, the data size, the trampoline mask and the
    GOT mask.
    """
    hot_size = self.hot_code_size()
    cold_size = self.cold_code_size()
    data_size = len(self.data.body)
    trampoline_mask = self._get_trampoline_mask()
    got_mask = self._get_got_mask()
    sizes = f"{hot_size}, {cold_size}, {data_size}"
    masks = f"{trampoline_mask}, {got_mask}"
    return f"{{emit_{opname}, {sizes}, {masks}}}"
426454

427455

428456
def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:

Tools/jit/_targets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]:
219219
tasks.append(group.create_task(coro, name=opname))
220220
stencil_groups = {task.get_name(): task.result() for task in tasks}
221221
for stencil_group in stencil_groups.values():
222+
stencil_group.extract_cold_offset()
222223
stencil_group.convert_labels_to_relocations()
223224
stencil_group.process_relocations(self.known_symbols)
224225
return stencil_groups

0 commit comments

Comments
 (0)